diff --git a/const.go b/const.go index 5697b31..5c46934 100644 --- a/const.go +++ b/const.go @@ -169,7 +169,8 @@ const ( PREPARE_NO_VTAB PrepareFlag = 0x04 ) -// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags]. +// FunctionFlag is a flag that can be passed to +// [Conn.CreateFunction] and [Conn.CreateWindowFunction]. // // https://sqlite.org/c3ref/c_deterministic.html type FunctionFlag uint32 @@ -181,6 +182,23 @@ const ( INNOCUOUS FunctionFlag = 0x000200000 ) +// StmtStatus name counter values associated with the [Stmt.Status] method. +// +// https://sqlite.org/c3ref/c_stmtstatus_counter.html +type StmtStatus uint32 + +const ( + STMTSTATUS_FULLSCAN_STEP StmtStatus = 1 + STMTSTATUS_SORT StmtStatus = 2 + STMTSTATUS_AUTOINDEX StmtStatus = 3 + STMTSTATUS_VM_STEP StmtStatus = 4 + STMTSTATUS_REPREPARE StmtStatus = 5 + STMTSTATUS_RUN StmtStatus = 6 + STMTSTATUS_FILTER_MISS StmtStatus = 7 + STMTSTATUS_FILTER_HIT StmtStatus = 8 + STMTSTATUS_MEMUSED StmtStatus = 99 +) + // Datatype is a fundamental datatype of SQLite. // // https://sqlite.org/c3ref/c_blob.html diff --git a/driver/driver.go b/driver/driver.go index 564786c..ef08aa7 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -250,7 +250,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e s.Close() return nil, util.TailErr } - return &stmt{s, c.Conn}, nil + return &stmt{s}, nil } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { @@ -281,8 +281,7 @@ func (*conn) CheckNamedValue(arg *driver.NamedValue) error { } type stmt struct { - Stmt *sqlite3.Stmt - Conn *sqlite3.Conn + *sqlite3.Stmt } var ( @@ -292,10 +291,6 @@ var ( _ driver.NamedValueChecker = &stmt{} ) -func (s *stmt) Close() error { - return s.Stmt.Close() -} - func (s *stmt) NumInput() int { n := s.Stmt.BindCount() for i := 1; i <= n; i++ { @@ -322,15 +317,15 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive return nil, err } - old := s.Conn.SetInterrupt(ctx) - defer s.Conn.SetInterrupt(old) + old := s.Stmt.Conn().SetInterrupt(ctx) + defer s.Stmt.Conn().SetInterrupt(old) err = s.Stmt.Exec() if err != nil { return nil, err } - return newResult(s.Conn), nil + return newResult(s.Stmt.Conn()), nil } func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -338,7 +333,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv if err != nil { return nil, err } - return &rows{ctx, s.Stmt, s.Conn}, nil + return &rows{ctx, s.Stmt}, nil } func (s *stmt) setupBindings(args []driver.NamedValue) error { @@ -442,10 +437,10 @@ func (r resultRowsAffected) RowsAffected() (int64, error) { type rows struct { ctx context.Context Stmt *sqlite3.Stmt - Conn *sqlite3.Conn } func (r *rows) Close() error { + r.Stmt.ClearBindings() return r.Stmt.Reset() } @@ -469,8 +464,8 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { } func (r *rows) Next(dest []driver.Value) error { - old := r.Conn.SetInterrupt(r.ctx) - defer r.Conn.SetInterrupt(old) + old := r.Stmt.Conn().SetInterrupt(r.ctx) + defer r.Stmt.Conn().SetInterrupt(old) if !r.Stmt.Step() { if err := r.Stmt.Err(); err != nil { diff --git a/embed/exports.txt b/embed/exports.txt index 88310f0..ddd284d 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -74,7 +74,9 @@ sqlite3_result_value sqlite3_result_zeroblob64 sqlite3_set_auxdata_go sqlite3_step +sqlite3_stmt_busy sqlite3_stmt_readonly +sqlite3_stmt_status sqlite3_uri_key sqlite3_uri_parameter sqlite3_user_data diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 47c4d10..3f519dc 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 87d173d..4910f25 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -172,9 +172,9 @@ func (t *table) newReader() *csv.Reader { type cursor struct { table *table - rowID int64 - row []string csv *csv.Reader + row []string + rowID int64 } func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { diff --git a/ext/lines/lines.go b/ext/lines/lines.go index 638715e..1f7e360 100644 --- a/ext/lines/lines.go +++ b/ext/lines/lines.go @@ -52,11 +52,11 @@ func (l lines) Open() (sqlite3.VTabCursor, error) { } type cursor struct { - reader bool scanner *bufio.Scanner closer io.Closer rowID int64 eof bool + reader bool } func (c *cursor) Close() (err error) { diff --git a/ext/statement/stmt.go b/ext/statement/stmt.go index 55d5df2..81ef4e3 100644 --- a/ext/statement/stmt.go +++ b/ext/statement/stmt.go @@ -26,12 +26,10 @@ func Register(db *sqlite3.Conn) { sql = sql[1 : len-1] } - table := &table{ - db: db, - sql: sql, - } - err = table.declare() + table := &table{sql: sql} + err = table.declare(db) if err != nil { + table.Close() return nil, err } return table, nil @@ -41,42 +39,40 @@ func Register(db *sqlite3.Conn) { } type table struct { - db *sqlite3.Conn - sql string - inputs int - outputs int + stmt *sqlite3.Stmt + sql string + inuse bool } -func (t *table) declare() error { - stmt, tail, err := t.db.Prepare(t.sql) +func (t *table) declare(db *sqlite3.Conn) (err error) { + var tail string + t.stmt, tail, err = db.Prepare(t.sql) if err != nil { return err } - defer stmt.Close() if tail != "" { return fmt.Errorf("statement: multiple statements") } - if !stmt.ReadOnly() { + if !t.stmt.ReadOnly() { return fmt.Errorf("statement: statement must be read only") } - t.inputs = stmt.BindCount() - t.outputs = stmt.ColumnCount() - var sep = "" var str strings.Builder str.WriteString(`CREATE TABLE x(`) - for i := 0; i < t.outputs; i++ { + outputs := t.stmt.ColumnCount() + for i := 0; i < outputs; i++ { str.WriteString(sep) - name := stmt.ColumnName(i) + name := t.stmt.ColumnName(i) str.WriteString(sqlite3.QuoteIdentifier(name)) str.WriteByte(' ') - str.WriteString(stmt.ColumnDeclType(i)) + str.WriteString(t.stmt.ColumnDeclType(i)) sep = "," } - for i := 1; i <= t.inputs; i++ { + inputs := t.stmt.BindCount() + for i := 1; i <= inputs; i++ { str.WriteString(sep) - name := stmt.BindName(i) + name := t.stmt.BindName(i) if name == "" { str.WriteString("[") str.WriteString(strconv.Itoa(i)) @@ -87,22 +83,24 @@ func (t *table) declare() error { } sep = "," } - str.WriteByte(')') - return t.db.DeclareVtab(str.String()) + return db.DeclareVtab(str.String()) +} + +func (t *table) Close() error { + return t.stmt.Close() } func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { - idx.OrderByConsumed = false - idx.EstimatedCost = 1 - idx.EstimatedRows = 1 + idx.EstimatedCost = 1000 var argvIndex = 1 var needIndex bool var listIndex []int + outputs := t.stmt.ColumnCount() for i, cst := range idx.Constraint { // Skip if this is a constraint on one of our output columns. - if cst.Column < t.outputs { + if cst.Column < outputs { continue } @@ -114,7 +112,7 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { // The non-zero argvIdx values must be contiguous. // If they're not, build a list and serialize it through IdxStr. - nextIndex := cst.Column - t.outputs + 1 + nextIndex := cst.Column - outputs + 1 idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ ArgvIndex: argvIndex, Omit: true, @@ -136,10 +134,15 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { return nil } -func (t *table) Open() (sqlite3.VTabCursor, error) { - stmt, _, err := t.db.Prepare(t.sql) - if err != nil { - return nil, err +func (t *table) Open() (_ sqlite3.VTabCursor, err error) { + stmt := t.stmt + if !t.inuse { + t.inuse = true + } else { + stmt, _, err = t.stmt.Conn().Prepare(t.sql) + if err != nil { + return nil, err + } } return &cursor{table: t, stmt: stmt}, nil } @@ -153,26 +156,29 @@ type cursor struct { stmt *sqlite3.Stmt arg []sqlite3.Value rowID int64 - done bool } func (c *cursor) Close() error { + if c.stmt == c.table.stmt { + c.table.inuse = false + c.stmt.ClearBindings() + return c.stmt.Reset() + } return c.stmt.Close() } func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { c.arg = arg c.rowID = 0 - if err := c.stmt.ClearBindings(); err != nil { - return err - } + c.stmt.ClearBindings() if err := c.stmt.Reset(); err != nil { return err } var list []int if idxStr != "" { - err := json.Unmarshal([]byte(idxStr), &list) + buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr)) + err := json.Unmarshal(buf, &list) if err != nil { return err } @@ -196,12 +202,11 @@ func (c *cursor) Next() error { c.rowID++ return nil } - c.done = true return c.stmt.Err() } func (c *cursor) EOF() bool { - return c.done + return !c.stmt.Busy() } func (c *cursor) RowID() (int64, error) { @@ -209,10 +214,11 @@ func (c *cursor) RowID() (int64, error) { } func (c *cursor) Column(ctx *sqlite3.Context, col int) error { - if col < c.table.outputs { + switch outputs := c.stmt.ColumnCount(); { + case col < outputs: ctx.ResultValue(c.stmt.ColumnValue(col)) - } else if col-c.table.outputs < len(c.arg) { - ctx.ResultValue(c.arg[col-c.table.outputs]) + case col-outputs < len(c.arg): + ctx.ResultValue(c.arg[col-outputs]) } return nil } diff --git a/func.go b/func.go index 2a8b5e4..fbf0812 100644 --- a/func.go +++ b/func.go @@ -32,7 +32,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { // CreateFunction defines a new scalar SQL function. // // https://sqlite.org/c3ref/create_function.html -func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error { +func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { defer c.arena.mark()() namePtr := c.arena.string(name) funcPtr := util.AddHandle(c.ctx, fn) @@ -42,6 +42,9 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func( return c.error(r) } +// ScalarFunction is the type of a scalar SQL function. +type ScalarFunction func(ctx Context, arg ...Value) + // CreateWindowFunction defines a new aggregate or aggregate window SQL function. // If fn returns a [WindowFunction], then an aggregate window function is created. // If fn returns an [io.Closer], it will be called to free resources. @@ -95,7 +98,7 @@ func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { db := ctx.Value(connKey{}).(*Conn) - fn := userDataHandle(db, pCtx).(func(ctx Context, arg ...Value)) + fn := userDataHandle(db, pCtx).(ScalarFunction) fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...) } diff --git a/stmt.go b/stmt.go index 3f5da0d..85ac447 100644 --- a/stmt.go +++ b/stmt.go @@ -34,6 +34,22 @@ func (s *Stmt) Close() error { return s.c.error(r) } +// Conn returns the database connection to which the prepared statement belongs. +// +// https://sqlite.org/c3ref/db_handle.html +func (s *Stmt) Conn() *Conn { + return s.c +} + +// ReadOnly returns true if and only if the statement +// makes no direct changes to the content of the database file. +// +// https://sqlite.org/c3ref/stmt_readonly.html +func (s *Stmt) ReadOnly() bool { + r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle)) + return r != 0 +} + // Reset resets the prepared statement object. // // https://sqlite.org/c3ref/reset.html @@ -43,12 +59,12 @@ func (s *Stmt) Reset() error { return s.c.error(r) } -// ClearBindings resets all bindings on the prepared statement. +// Busy determines if a prepared statement has been reset. // -// https://sqlite.org/c3ref/clear_bindings.html -func (s *Stmt) ClearBindings() error { - r := s.c.call("sqlite3_clear_bindings", uint64(s.handle)) - return s.c.error(r) +// https://sqlite.org/c3ref/stmt_busy.html +func (s *Stmt) Busy() bool { + r := s.c.call("sqlite3_stmt_busy", uint64(s.handle)) + return r != 0 } // Step evaluates the SQL statement. @@ -90,13 +106,25 @@ func (s *Stmt) Exec() error { return s.Reset() } -// ReadOnly returns true if and only if the statement -// makes no direct changes to the content of the database file. +// Status monitors the performance characteristics of prepared statements. // -// https://sqlite.org/c3ref/stmt_readonly.html -func (s *Stmt) ReadOnly() bool { - r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle)) - return r != 0 +// https://sqlite.org/c3ref/stmt_status.html +func (s *Stmt) Status(op StmtStatus, reset bool) int { + var i uint64 + if reset { + i = 1 + } + r := s.c.call("sqlite3_stmt_status", uint64(s.handle), + uint64(op), i) + return int(r) +} + +// ClearBindings resets all bindings on the prepared statement. +// +// https://sqlite.org/c3ref/clear_bindings.html +func (s *Stmt) ClearBindings() error { + r := s.c.call("sqlite3_clear_bindings", uint64(s.handle)) + return s.c.error(r) } // BindCount returns the number of SQL parameters in the prepared statement. diff --git a/tests/stmt_test.go b/tests/stmt_test.go index c475f23..a39acc7 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -586,6 +586,10 @@ func TestStmt_ColumnTime(t *testing.T) { t.Errorf("want error") } } + + if got := stmt.Status(sqlite3.STMTSTATUS_RUN, true); got != 1 { + t.Errorf("got %d, want 1", got) + } } func TestStmt_Error(t *testing.T) { diff --git a/vfs/os_std_access.go b/vfs/os_std_access.go index b1ca611..1621c09 100644 --- a/vfs/os_std_access.go +++ b/vfs/os_std_access.go @@ -7,12 +7,6 @@ import ( "os" ) -const ( - _S_IREAD = 0400 - _S_IWRITE = 0200 - _S_IEXEC = 0100 -) - func osAccess(path string, flags AccessFlag) error { fi, err := os.Stat(path) if err != nil { @@ -22,12 +16,18 @@ func osAccess(path string, flags AccessFlag) error { return nil } - var want fs.FileMode = _S_IREAD + const ( + S_IREAD = 0400 + S_IWRITE = 0200 + S_IEXEC = 0100 + ) + + var want fs.FileMode = S_IREAD if flags == ACCESS_READWRITE { - want |= _S_IWRITE + want |= S_IWRITE } if fi.IsDir() { - want |= _S_IEXEC + want |= S_IEXEC } if fi.Mode()&want != want { return fs.ErrPermission diff --git a/vtab.go b/vtab.go index 928a6d9..f5bddf7 100644 --- a/vtab.go +++ b/vtab.go @@ -143,7 +143,7 @@ type VTabRenamer interface { type VTabOverloader interface { VTab // https://sqlite.org/vtab.html#xfindfunction - FindFunction(arg int, name string) (func(ctx Context, arg ...Value), IndexConstraintOp) + FindFunction(arg int, name string) (ScalarFunction, IndexConstraintOp) } // A VTabChecker allows a virtual table to report errors @@ -161,6 +161,11 @@ type VTabChecker interface { // A VTabTx allows a virtual table to implement // transactions with two-phase commit. +// +// Anything that is required as part of a commit that may fail +// should be performed in the Sync() callback. +// Current versions of SQLite ignore any errors +// returned by Commit() and Rollback(). type VTabTx interface { VTab // https://sqlite.org/vtab.html#xBegin