distribution/vendor/github.com/go-llsqlite/crawshaw/sqlitex/exec.go

528 lines
14 KiB
Go

// Copyright (c) 2018 David Crawshaw <david@zentus.com>
// Copyright (c) 2021 Ross Light <rosss@zombiezen.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
//
// SPDX-License-Identifier: ISC
// Package sqlitex provides utilities for working with SQLite.
package sqlitex
import (
"fmt"
"io"
"reflect"
"strings"
"github.com/go-llsqlite/crawshaw"
"io/fs"
)
// ExecOptions is the set of optional arguments for executing a statement.
// The execution functions in this file accept a nil *ExecOptions, in which
// case no arguments are bound and result rows are discarded.
type ExecOptions struct {
// Args is the set of positional arguments to bind to the statement.
// The first element in the slice is ?1.
// See https://sqlite.org/lang_expr.html for more details.
// Supplying more elements than the statement has parameters is an error.
//
// Basic reflection on Args is used to map:
//
// integers to BindInt64
// floats to BindFloat
// []byte to BindBytes
// string to BindText
// bool to BindBool
//
// A nil element is bound as NULL.
// All other kinds are printed using fmt.Sprint(v) and passed to BindText.
Args []interface{}
// Named is the set of named arguments to bind to the statement. Keys must
// start with ':', '@', or '$'. See https://sqlite.org/lang_expr.html for more
// details.
//
// Basic reflection on Named is used to map:
//
// integers to BindInt64
// floats to BindFloat
// []byte to BindBytes
// string to BindText
// bool to BindBool
//
// A nil value is bound as NULL.
// All other kinds are printed using fmt.Sprint(v) and passed to BindText.
Named map[string]interface{}
// ResultFunc is called for each result row.
// If ResultFunc returns an error then iteration ceases
// and the execution function returns the error value.
// If ResultFunc is nil, rows are stepped through without being reported.
ResultFunc func(stmt *sqlite.Stmt) error
// AllowUnused permits unused parameters. SQLite normally treats these as
// null anyway, so this reverts to the default behaviour.
// Within this file it is consulted only by ExecuteScript: when true, keys of
// Named that match no parameter of any statement in the script are not
// reported as an error.
AllowUnused bool
}
// Exec prepares query on conn (through the Conn statement cache), binds args
// to its numbered parameters, and steps the statement to completion.
//
// resultFn is invoked once per result row and may read the row with the
// stmt.Column* methods. A non-nil error from resultFn stops iteration and is
// returned by Exec.
//
// Each element of args is bound with a Stmt Bind* method chosen by basic
// reflection:
//
//	integers to BindInt64
//	floats to BindFloat
//	[]byte to BindBytes
//	string to BindText
//	bool to BindBool
//
// Any other kind is rendered with fmt.Sprintf("%v", v) and bound with
// BindText.
//
// Because Exec goes through the Stmt prepare mechanism rather than handing a
// Go closure to cgo, it interacts well with Go's type system. Because it
// uses Conn.Prepare, repeated calls with the same query text reuse the cached
// statement object.
//
// The statement is always reset before Exec returns; a reset error is
// reported only when execution itself succeeded.
//
// Deprecated: Use Execute.
// Exec skips some argument checks for compatibility with crawshaw.io/sqlite.
func Exec(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) error {
	stmt, prepErr := conn.Prepare(query)
	if prepErr != nil {
		return annotateErr(prepErr)
	}
	opts := &ExecOptions{Args: args, ResultFunc: resultFn}
	// Flags of 0: neither missing nor extra arguments are rejected.
	runErr := exec(stmt, 0, opts)
	resetErr := stmt.Reset()
	if runErr != nil {
		return runErr
	}
	return resetErr
}
// Execute runs a single SQLite query on conn with the arguments and row
// callback described by opts (which may be nil).
//
// Missing and unknown arguments are rejected. The statement is obtained from
// Conn.Prepare, so later calls to Execute with the same query text reuse the
// cached statement object. The statement is always reset before returning; a
// reset error is reported only when execution itself succeeded.
func Execute(conn *sqlite.Conn, query string, opts *ExecOptions) error {
	stmt, prepErr := conn.Prepare(query)
	if prepErr != nil {
		return annotateErr(prepErr)
	}
	runErr := exec(stmt, forbidMissing|forbidExtra, opts)
	if resetErr := stmt.Reset(); runErr == nil {
		return resetErr
	}
	return runErr
}
// ExecFS is an alias for ExecuteFS: it forwards all of its arguments
// unchanged and returns ExecuteFS's result.
//
// Deprecated: Call ExecuteFS directly.
func ExecFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
return ExecuteFS(conn, fsys, filename, opts)
}
// ExecuteFS executes the single statement in the given SQL file.
//
// The file named filename is read from fsys, surrounding whitespace is
// trimmed, and the text is prepared with Conn.Prepare, so subsequent calls to
// ExecuteFS with the same statement will reuse the cached statement object.
// opts may be nil; missing and unknown arguments are rejected.
//
// A failure to read the file is returned wrapped with the prefix "exec: ".
// Every later failure (prepare, execution, or reset) is returned wrapped with
// the prefix "exec <filename>: ".
func ExecuteFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
	query, err := readString(fsys, filename)
	if err != nil {
		return fmt.Errorf("exec: %w", err)
	}
	stmt, err := conn.Prepare(strings.TrimSpace(query))
	if err != nil {
		return fmt.Errorf("exec %s: %w", filename, err)
	}
	err = exec(stmt, forbidMissing|forbidExtra, opts)
	// Reset is called even when exec failed; the execution error takes
	// precedence over the reset error.
	resetErr := stmt.Reset()
	if err != nil {
		// Don't strip the error query: we already do this inside exec.
		return fmt.Errorf("exec %s: %w", filename, err)
	}
	if resetErr != nil {
		// err is nil on this path, so the reset error is the one to wrap.
		return fmt.Errorf("exec %s: %w", filename, resetErr)
	}
	return nil
}
// ExecTransient runs an SQLite query without caching the underlying
// statement. Its interface is exactly the same as Exec's, and it is the
// spiritual equivalent of sqlite3_exec.
//
// The query must contain exactly one statement: any bytes left over after
// the first statement cause an error. The statement is finalized before
// returning, and a finalize error is reported only when no earlier error
// occurred.
//
// Deprecated: Use ExecuteTransient.
// ExecTransient skips some argument checks for compatibility with crawshaw.io/sqlite.
func ExecTransient(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) (err error) {
	stmt, trailing, err := conn.PrepareTransient(query)
	if err != nil {
		return annotateErr(err)
	}
	defer func() {
		// Writes to the named result so a finalize failure is not lost.
		if finalizeErr := stmt.Finalize(); err == nil {
			err = finalizeErr
		}
	}()
	if trailing != 0 {
		return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
	}
	opts := &ExecOptions{Args: args, ResultFunc: resultFn}
	// Flags of 0: neither missing nor extra arguments are rejected.
	return exec(stmt, 0, opts)
}
// ExecuteTransient runs an SQLite query without caching the underlying
// statement. It is the spiritual equivalent of sqlite3_exec:
// https://www.sqlite.org/c3ref/exec.html
//
// opts may be nil; missing and unknown arguments are rejected. The query
// must contain exactly one statement: any bytes left over after the first
// statement cause an error. The statement is finalized before returning, and
// a finalize error is reported only when no earlier error occurred.
func ExecuteTransient(conn *sqlite.Conn, query string, opts *ExecOptions) (err error) {
	stmt, trailing, err := conn.PrepareTransient(query)
	if err != nil {
		return annotateErr(err)
	}
	defer func() {
		// Writes to the named result so a finalize failure is not lost.
		if finalizeErr := stmt.Finalize(); err == nil {
			err = finalizeErr
		}
	}()
	if trailing != 0 {
		return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
	}
	return exec(stmt, forbidMissing|forbidExtra, opts)
}
// ExecTransientFS is an alias for ExecuteTransientFS: it forwards all of its
// arguments unchanged and returns ExecuteTransientFS's result.
//
// Deprecated: Call ExecuteTransientFS directly.
func ExecTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
return ExecuteTransientFS(conn, fsys, filename, opts)
}
// ExecuteTransientFS executes the single statement in the given SQL file without
// caching the underlying query.
//
// The file named filename is read from fsys, surrounding whitespace is
// trimmed, and the text is prepared with Conn.PrepareTransient. opts may be
// nil; missing and unknown arguments are rejected.
//
// A failure to read the file is returned wrapped with the prefix "exec: ".
// Every later failure (prepare, execution, or reset) is returned wrapped with
// the prefix "exec <filename>: ".
func ExecuteTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
	query, err := readString(fsys, filename)
	if err != nil {
		return fmt.Errorf("exec: %w", err)
	}
	// The trailing-byte count is not inspected here.
	stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
	if err != nil {
		return fmt.Errorf("exec %s: %w", filename, err)
	}
	// The statement is not cached, so release it on every return path.
	// Finalize's result is discarded here.
	defer stmt.Finalize()
	err = exec(stmt, forbidMissing|forbidExtra, opts)
	// Reset is called even when exec failed; the execution error takes
	// precedence over the reset error.
	resetErr := stmt.Reset()
	if err != nil {
		// Don't strip the error query: we already do this inside exec.
		return fmt.Errorf("exec %s: %w", filename, err)
	}
	if resetErr != nil {
		// err is nil on this path, so the reset error is the one to wrap.
		return fmt.Errorf("exec %s: %w", filename, resetErr)
	}
	return nil
}
// PrepareTransientFS reads the SQL file named filename from fsys, trims
// surrounding whitespace, and prepares it as a statement that is not cached
// by the Conn, so repeated calls with the same query create new Stmts.
//
// The caller owns the returned Stmt and must call Finalize on it once it is
// no longer needed. On error the returned Stmt is nil.
func PrepareTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string) (*sqlite.Stmt, error) {
	text, readErr := readString(fsys, filename)
	if readErr != nil {
		return nil, fmt.Errorf("prepare: %w", readErr)
	}
	text = strings.TrimSpace(text)
	stmt, _, prepErr := conn.PrepareTransient(text)
	if prepErr != nil {
		return nil, fmt.Errorf("prepare %s: %w", filename, prepErr)
	}
	return stmt, nil
}
// Flags accepted by exec and setNamed that control argument validation.
// They are bit values and may be combined with |.
const (
// forbidMissing makes exec fail when any statement parameter has not been
// bound, and makes setNamed fail when a named statement parameter has no
// entry in a non-empty Named map.
forbidMissing = 1 << iota
// forbidExtra makes setNamed fail when the Named map contains a key that
// does not match any parameter name of the statement.
forbidExtra
)
// exec binds the arguments in opts to stmt and steps it until no rows remain.
//
// Positional arguments are bound first (Args[0] to parameter 1, and so on),
// then named arguments via setNamed. Supplying more positional arguments
// than the statement has parameters is always an error. When flags contains
// forbidMissing, any parameter left unbound is an error as well. flags is
// also passed through to setNamed.
//
// For every result row, opts.ResultFunc is called if opts and ResultFunc are
// both non-nil; its first non-nil error ends iteration and is returned.
// exec neither resets nor finalizes stmt; that is left to the caller.
func exec(stmt *sqlite.Stmt, flags uint8, opts *ExecOptions) (err error) {
	total := stmt.BindParamCount()
	bound := newBitset(total)

	if opts != nil {
		if nargs := len(opts.Args); nargs > total {
			return fmt.Errorf("sqlitex.Exec: %w (len(Args) > BindParamCount(); %d > %d)",
				sqlite.ResultRange.ToError(), nargs, total)
		}
		for idx := range opts.Args {
			// Bitset positions are 0-based; SQLite parameters are 1-based.
			bound.set(idx)
			setArg(stmt, idx+1, reflect.ValueOf(opts.Args[idx]))
		}
		if namedErr := setNamed(stmt, bound, flags, opts.Named); namedErr != nil {
			return namedErr
		}
	}

	if mustBindAll := flags&forbidMissing != 0; mustBindAll && !bound.hasAll(total) {
		pos := bound.firstMissing() + 1
		label := stmt.BindParamName(pos)
		if label == "" {
			// Unnamed parameter: refer to it by its position.
			label = fmt.Sprintf("?%d", pos)
		}
		return fmt.Errorf("sqlitex.Exec: missing argument for %s", label)
	}

	for {
		more, stepErr := stmt.Step()
		if stepErr != nil {
			return stepErr
		}
		if !more {
			return nil
		}
		if opts == nil || opts.ResultFunc == nil {
			continue
		}
		if rowErr := opts.ResultFunc(stmt); rowErr != nil {
			return rowErr
		}
	}
}
// setArg binds v to parameter i of stmt, choosing the Bind* method from the
// reflected kind of v:
//
//   - signed and unsigned integers (including uintptr): BindInt64
//   - floats: BindFloat
//   - strings: BindText
//   - bools: BindBool
//   - the zero reflect.Value (e.g. from a nil interface): BindNull
//   - slices whose element kind is uint8: BindBytes
//   - anything else: BindText of fmt.Sprint of the value
//
// Unsigned values are converted with int64(...), so values above
// math.MaxInt64 wrap around.
func setArg(stmt *sqlite.Stmt, i int, v reflect.Value) {
	switch v.Kind() {
	case reflect.Invalid:
		stmt.BindNull(i)
		return
	case reflect.Bool:
		stmt.BindBool(i, v.Bool())
		return
	case reflect.String:
		stmt.BindText(i, v.String())
		return
	case reflect.Float32, reflect.Float64:
		stmt.BindFloat(i, v.Float())
		return
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		stmt.BindInt64(i, v.Int())
		return
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		stmt.BindInt64(i, int64(v.Uint()))
		return
	case reflect.Slice:
		if v.Type().Elem().Kind() == reflect.Uint8 {
			stmt.BindBytes(i, v.Bytes())
			return
		}
	}
	// Fallback for every kind not handled above, including non-byte slices.
	stmt.BindText(i, fmt.Sprint(v.Interface()))
}
// setNamed binds the entries of args to the named parameters of stmt and
// records each bound parameter (0-based) in provided.
//
// An empty args map is a no-op. Otherwise every parameter of stmt that has a
// name is looked up in args:
//
//   - if found, the value is bound with setArg;
//   - if not found and flags contains forbidMissing, an error is returned;
//   - if not found otherwise, the parameter is skipped.
//
// When flags contains forbidExtra, any key of args that matched no parameter
// name is an error; the lexicographically smallest such key is named in the
// message.
func setNamed(stmt *sqlite.Stmt, provided bitset, flags uint8, args map[string]interface{}) error {
	if len(args) == 0 {
		return nil
	}
	rejectExtra := flags&forbidExtra != 0
	rejectMissing := flags&forbidMissing != 0

	// pending holds the keys not yet matched to a parameter. It stays nil
	// when extras are allowed; deleting from a nil map is a no-op.
	var pending map[string]struct{}
	if rejectExtra {
		pending = make(map[string]struct{}, len(args))
		for key := range args {
			pending[key] = struct{}{}
		}
	}

	n := stmt.BindParamCount()
	for pos := 1; pos <= n; pos++ {
		name := stmt.BindParamName(pos)
		if name == "" {
			continue
		}
		value, ok := args[name]
		switch {
		case ok:
			delete(pending, name)
			provided.set(pos - 1)
			setArg(stmt, pos, reflect.ValueOf(value))
		case rejectMissing:
			// TODO(maybe): Check provided as well?
			return fmt.Errorf("missing parameter %s", name)
		}
	}

	if len(pending) == 0 {
		return nil
	}
	return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(pending))
}
// annotateErr wraps err so that its message is prefixed with "sqlitex.Exec: ".
// The error is wrapped with %w, so the original remains reachable through
// errors.Unwrap, errors.Is and errors.As.
func annotateErr(err error) error {
// TODO(maybe)
// if err, isError := err.(sqlite.Error); isError {
// if err.Loc == "" {
// err.Loc = "Exec"
// } else {
// err.Loc = "Exec: " + err.Loc
// }
// return err
// }
return fmt.Errorf("sqlitex.Exec: %w", err)
}
// ExecScript executes a script of SQL statements.
// It is the same as calling ExecuteScript with nil options, so no arguments
// are bound to any statement in the script.
func ExecScript(conn *sqlite.Conn, queries string) (err error) {
return ExecuteScript(conn, queries, nil)
}
// ExecuteScript executes a script of SQL statements.
// The script is wrapped in a SAVEPOINT transaction,
// which is rolled back on any error.
//
// Statements are prepared and run one at a time, each with the same opts
// (which may be nil). opts.Args and opts.Named are therefore applied to every
// statement, and a statement with a parameter left unbound is an error.
// After the whole script has run, a key of opts.Named that matched no
// parameter name in any statement is reported as an error unless
// opts.AllowUnused is set.
//
// NOTE(review): this comment previously stated that opts.ResultFunc is
// ignored, but opts is passed unchanged to exec, which calls a non-nil
// ResultFunc for every result row. Confirm which behavior is intended.
func ExecuteScript(conn *sqlite.Conn, queries string, opts *ExecOptions) (err error) {
// Save is defined outside this file. The function it returns is deferred
// with a pointer to the named result so it can see the final error.
defer Save(conn)(&err)
// unused tracks the Named keys that have not yet matched a parameter name.
unused := make(map[string]struct{})
if opts != nil {
for k := range opts.Named {
unused[k] = struct{}{}
}
}
for {
queries = strings.TrimSpace(queries)
if queries == "" {
break
}
var stmt *sqlite.Stmt
var trailingBytes int
// Prepares the first statement of the remaining script; trailingBytes is
// the length of the text that follows it.
stmt, trailingBytes, err = conn.PrepareTransient(queries)
if err != nil {
return err
}
for i, n := 1, stmt.BindParamCount(); i <= n; i++ {
if name := stmt.BindParamName(i); name != "" {
delete(unused, name)
}
}
// Advance past the statement that was just prepared.
usedBytes := len(queries) - trailingBytes
queries = queries[usedBytes:]
// forbidExtra is not set here: a Named key may belong to a different
// statement of the script, so extras are checked once after the loop.
err = exec(stmt, forbidMissing, opts)
// Finalize's result is discarded here.
stmt.Finalize()
if err != nil {
return err
}
}
// unused can only be non-empty when opts is non-nil, so the dereference
// after && is safe.
if len(unused) > 0 && !opts.AllowUnused {
return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
}
return nil
}
// ExecScriptFS is an alias for ExecuteScriptFS: it forwards all of its
// arguments unchanged and returns ExecuteScriptFS's result.
//
// Deprecated: Call ExecuteScriptFS directly.
func ExecScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
return ExecuteScriptFS(conn, fsys, filename, opts)
}
// ExecuteScriptFS reads the file named filename from fsys and runs its
// contents as a script of SQL statements through ExecuteScript.
// The script is wrapped in a SAVEPOINT transaction,
// which is rolled back on any error.
//
// A failure to read the file is returned wrapped with the prefix "exec: ";
// a failure while running the script is returned wrapped with the prefix
// "exec <filename>: ".
func ExecuteScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
	script, err := readString(fsys, filename)
	if err != nil {
		return fmt.Errorf("exec: %w", err)
	}
	runErr := ExecuteScript(conn, script, opts)
	if runErr == nil {
		return nil
	}
	return fmt.Errorf("exec %s: %w", filename, runErr)
}
// bitset is a set of small non-negative integers stored one bit per member
// in 64-bit words: integer n is bit n%64 of word n/64.
type bitset []uint64

// newBitset returns an empty bitset with room for the integers in [0, n).
func newBitset(n int) bitset {
	return make(bitset, (n+63)/64)
}

// hasAll reports whether the bitset is a superset of [0, n).
func (bs bitset) hasAll(n int) bool {
	nwords := (n + 63) / 64
	if len(bs) < nwords {
		return false
	}
	// Words that lie entirely inside [0, n) must have every bit set.
	fullWords := n / 64
	for _, w := range bs[:fullWords] {
		if w != ^uint64(0) {
			return false
		}
	}
	if fullWords == nwords {
		return true
	}
	// The last word is only partially covered: check its low n%64 bits.
	mask := uint64(1)<<(n%64) - 1
	return bs[nwords-1]&mask == mask
}

// firstMissing returns the smallest integer that is not in the set.
// If every bit of every word is set, it returns len(bs)*64.
func (bs bitset) firstMissing() int {
	for i, w := range bs {
		if w == ^uint64(0) {
			continue
		}
		for j := 0; j < 64; j++ {
			if w&(1<<j) == 0 {
				return i*64 + j
			}
		}
	}
	return len(bs) * 64
}

// set adds n to the set. It panics if n is outside the capacity the bitset
// was created with.
func (bs bitset) set(n int) {
	bs[n/64] |= 1 << (n % 64)
}

// String returns the set as binary digits, most significant word first, so
// the rightmost digit is member 0. Every word is zero-padded to 64 digits,
// which keeps the digit for member n at a fixed position.
func (bs bitset) String() string {
	sb := new(strings.Builder)
	for i := len(bs) - 1; i >= 0; i-- {
		fmt.Fprintf(sb, "%064b", bs[i])
	}
	return sb.String()
}
// minStringInSet returns the lexicographically smallest key of set, or ""
// when set is empty.
//
// A separate found flag marks whether a candidate has been seen, so that an
// empty-string key is treated as a real member (and is therefore always the
// minimum) regardless of map iteration order.
func minStringInSet(set map[string]struct{}) string {
	smallest, found := "", false
	for k := range set {
		if !found || k < smallest {
			smallest, found = k, true
		}
	}
	return smallest
}
// readString returns the entire contents of the file named filename in fsys
// as a string.
//
// An error from opening the file is returned as is. An error while reading
// is returned wrapped with the prefix "<filename>: ". The error from closing
// the file is not reported.
func readString(fsys fs.FS, filename string) (string, error) {
	file, err := fsys.Open(filename)
	if err != nil {
		return "", err
	}
	defer file.Close()

	var buf strings.Builder
	if _, copyErr := io.Copy(&buf, file); copyErr != nil {
		return "", fmt.Errorf("%s: %w", filename, copyErr)
	}
	return buf.String(), nil
}