Reuse statement, API.

This commit is contained in:
Nuno Cruces
2023-12-04 12:37:53 +00:00
parent 8a0baedc10
commit cd40213898
12 changed files with 144 additions and 83 deletions

View File

@@ -169,7 +169,8 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
// FunctionFlag is a flag that can be passed to
// [Conn.CreateFunction] and [Conn.CreateWindowFunction].
//
// https://sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
@@ -181,6 +182,23 @@ const (
INNOCUOUS FunctionFlag = 0x000200000
)
// StmtStatus name counter values associated with the [Stmt.Status] method.
//
// https://sqlite.org/c3ref/c_stmtstatus_counter.html
type StmtStatus uint32
const (
STMTSTATUS_FULLSCAN_STEP StmtStatus = 1
STMTSTATUS_SORT StmtStatus = 2
STMTSTATUS_AUTOINDEX StmtStatus = 3
STMTSTATUS_VM_STEP StmtStatus = 4
STMTSTATUS_REPREPARE StmtStatus = 5
STMTSTATUS_RUN StmtStatus = 6
STMTSTATUS_FILTER_MISS StmtStatus = 7
STMTSTATUS_FILTER_HIT StmtStatus = 8
STMTSTATUS_MEMUSED StmtStatus = 99
)
// Datatype is a fundamental datatype of SQLite.
//
// https://sqlite.org/c3ref/c_blob.html

View File

@@ -250,7 +250,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
s.Close()
return nil, util.TailErr
}
return &stmt{s, c.Conn}, nil
return &stmt{s}, nil
}
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
@@ -281,8 +281,7 @@ func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
}
type stmt struct {
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
*sqlite3.Stmt
}
var (
@@ -292,10 +291,6 @@ var (
_ driver.NamedValueChecker = &stmt{}
)
func (s *stmt) Close() error {
return s.Stmt.Close()
}
func (s *stmt) NumInput() int {
n := s.Stmt.BindCount()
for i := 1; i <= n; i++ {
@@ -322,15 +317,15 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
return nil, err
}
old := s.Conn.SetInterrupt(ctx)
defer s.Conn.SetInterrupt(old)
old := s.Stmt.Conn().SetInterrupt(ctx)
defer s.Stmt.Conn().SetInterrupt(old)
err = s.Stmt.Exec()
if err != nil {
return nil, err
}
return newResult(s.Conn), nil
return newResult(s.Stmt.Conn()), nil
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -338,7 +333,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
if err != nil {
return nil, err
}
return &rows{ctx, s.Stmt, s.Conn}, nil
return &rows{ctx, s.Stmt}, nil
}
func (s *stmt) setupBindings(args []driver.NamedValue) error {
@@ -442,10 +437,10 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
type rows struct {
ctx context.Context
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
}
func (r *rows) Close() error {
r.Stmt.ClearBindings()
return r.Stmt.Reset()
}
@@ -469,8 +464,8 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
}
func (r *rows) Next(dest []driver.Value) error {
old := r.Conn.SetInterrupt(r.ctx)
defer r.Conn.SetInterrupt(old)
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)
if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {

View File

@@ -74,7 +74,9 @@ sqlite3_result_value
sqlite3_result_zeroblob64
sqlite3_set_auxdata_go
sqlite3_step
sqlite3_stmt_busy
sqlite3_stmt_readonly
sqlite3_stmt_status
sqlite3_uri_key
sqlite3_uri_parameter
sqlite3_user_data

Binary file not shown.

View File

@@ -172,9 +172,9 @@ func (t *table) newReader() *csv.Reader {
type cursor struct {
table *table
rowID int64
row []string
csv *csv.Reader
row []string
rowID int64
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {

View File

@@ -52,11 +52,11 @@ func (l lines) Open() (sqlite3.VTabCursor, error) {
}
type cursor struct {
reader bool
scanner *bufio.Scanner
closer io.Closer
rowID int64
eof bool
reader bool
}
func (c *cursor) Close() (err error) {

View File

@@ -26,12 +26,10 @@ func Register(db *sqlite3.Conn) {
sql = sql[1 : len-1]
}
table := &table{
db: db,
sql: sql,
}
err = table.declare()
table := &table{sql: sql}
err = table.declare(db)
if err != nil {
table.Close()
return nil, err
}
return table, nil
@@ -41,42 +39,40 @@ func Register(db *sqlite3.Conn) {
}
type table struct {
db *sqlite3.Conn
sql string
inputs int
outputs int
stmt *sqlite3.Stmt
sql string
inuse bool
}
func (t *table) declare() error {
stmt, tail, err := t.db.Prepare(t.sql)
func (t *table) declare(db *sqlite3.Conn) (err error) {
var tail string
t.stmt, tail, err = db.Prepare(t.sql)
if err != nil {
return err
}
defer stmt.Close()
if tail != "" {
return fmt.Errorf("statement: multiple statements")
}
if !stmt.ReadOnly() {
if !t.stmt.ReadOnly() {
return fmt.Errorf("statement: statement must be read only")
}
t.inputs = stmt.BindCount()
t.outputs = stmt.ColumnCount()
var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
for i := 0; i < t.outputs; i++ {
outputs := t.stmt.ColumnCount()
for i := 0; i < outputs; i++ {
str.WriteString(sep)
name := stmt.ColumnName(i)
name := t.stmt.ColumnName(i)
str.WriteString(sqlite3.QuoteIdentifier(name))
str.WriteByte(' ')
str.WriteString(stmt.ColumnDeclType(i))
str.WriteString(t.stmt.ColumnDeclType(i))
sep = ","
}
for i := 1; i <= t.inputs; i++ {
inputs := t.stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
name := t.stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
@@ -87,22 +83,24 @@ func (t *table) declare() error {
}
sep = ","
}
str.WriteByte(')')
return t.db.DeclareVtab(str.String())
return db.DeclareVtab(str.String())
}
func (t *table) Close() error {
return t.stmt.Close()
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.OrderByConsumed = false
idx.EstimatedCost = 1
idx.EstimatedRows = 1
idx.EstimatedCost = 1000
var argvIndex = 1
var needIndex bool
var listIndex []int
outputs := t.stmt.ColumnCount()
for i, cst := range idx.Constraint {
// Skip if this is a constraint on one of our output columns.
if cst.Column < t.outputs {
if cst.Column < outputs {
continue
}
@@ -114,7 +112,7 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
// The non-zero argvIdx values must be contiguous.
// If they're not, build a list and serialize it through IdxStr.
nextIndex := cst.Column - t.outputs + 1
nextIndex := cst.Column - outputs + 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
@@ -136,10 +134,15 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
stmt, _, err := t.db.Prepare(t.sql)
if err != nil {
return nil, err
func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
}
}
return &cursor{table: t, stmt: stmt}, nil
}
@@ -153,26 +156,29 @@ type cursor struct {
stmt *sqlite3.Stmt
arg []sqlite3.Value
rowID int64
done bool
}
func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
c.stmt.ClearBindings()
return c.stmt.Reset()
}
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.arg = arg
c.rowID = 0
if err := c.stmt.ClearBindings(); err != nil {
return err
}
c.stmt.ClearBindings()
if err := c.stmt.Reset(); err != nil {
return err
}
var list []int
if idxStr != "" {
err := json.Unmarshal([]byte(idxStr), &list)
buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr))
err := json.Unmarshal(buf, &list)
if err != nil {
return err
}
@@ -196,12 +202,11 @@ func (c *cursor) Next() error {
c.rowID++
return nil
}
c.done = true
return c.stmt.Err()
}
func (c *cursor) EOF() bool {
return c.done
return !c.stmt.Busy()
}
func (c *cursor) RowID() (int64, error) {
@@ -209,10 +214,11 @@ func (c *cursor) RowID() (int64, error) {
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
if col < c.table.outputs {
switch outputs := c.stmt.ColumnCount(); {
case col < outputs:
ctx.ResultValue(c.stmt.ColumnValue(col))
} else if col-c.table.outputs < len(c.arg) {
ctx.ResultValue(c.arg[col-c.table.outputs])
case col-outputs < len(c.arg):
ctx.ResultValue(c.arg[col-outputs])
}
return nil
}

View File

@@ -32,7 +32,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
@@ -42,6 +42,9 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
return c.error(r)
}
// ScalarFunction is the type of a scalar SQL function.
type ScalarFunction func(ctx Context, arg ...Value)
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
@@ -95,7 +98,7 @@ func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := userDataHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn := userDataHandle(db, pCtx).(ScalarFunction)
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}

50
stmt.go
View File

@@ -34,6 +34,22 @@ func (s *Stmt) Close() error {
return s.c.error(r)
}
// Conn returns the database connection to which the prepared statement belongs.
//
// https://sqlite.org/c3ref/db_handle.html
func (s *Stmt) Conn() *Conn {
return s.c
}
// ReadOnly returns true if and only if the statement
// makes no direct changes to the content of the database file.
//
// https://sqlite.org/c3ref/stmt_readonly.html
func (s *Stmt) ReadOnly() bool {
r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle))
return r != 0
}
// Reset resets the prepared statement object.
//
// https://sqlite.org/c3ref/reset.html
@@ -43,12 +59,12 @@ func (s *Stmt) Reset() error {
return s.c.error(r)
}
// ClearBindings resets all bindings on the prepared statement.
// Busy determines if a prepared statement has been reset.
//
// https://sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call("sqlite3_clear_bindings", uint64(s.handle))
return s.c.error(r)
// https://sqlite.org/c3ref/stmt_busy.html
func (s *Stmt) Busy() bool {
r := s.c.call("sqlite3_stmt_busy", uint64(s.handle))
return r != 0
}
// Step evaluates the SQL statement.
@@ -90,13 +106,25 @@ func (s *Stmt) Exec() error {
return s.Reset()
}
// ReadOnly returns true if and only if the statement
// makes no direct changes to the content of the database file.
// Status monitors the performance characteristics of prepared statements.
//
// https://sqlite.org/c3ref/stmt_readonly.html
func (s *Stmt) ReadOnly() bool {
r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle))
return r != 0
// https://sqlite.org/c3ref/stmt_status.html
func (s *Stmt) Status(op StmtStatus, reset bool) int {
var i uint64
if reset {
i = 1
}
r := s.c.call("sqlite3_stmt_status", uint64(s.handle),
uint64(op), i)
return int(r)
}
// ClearBindings resets all bindings on the prepared statement.
//
// https://sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call("sqlite3_clear_bindings", uint64(s.handle))
return s.c.error(r)
}
// BindCount returns the number of SQL parameters in the prepared statement.

View File

@@ -586,6 +586,10 @@ func TestStmt_ColumnTime(t *testing.T) {
t.Errorf("want error")
}
}
if got := stmt.Status(sqlite3.STMTSTATUS_RUN, true); got != 1 {
t.Errorf("got %d, want 1", got)
}
}
func TestStmt_Error(t *testing.T) {

View File

@@ -7,12 +7,6 @@ import (
"os"
)
const (
_S_IREAD = 0400
_S_IWRITE = 0200
_S_IEXEC = 0100
)
func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
@@ -22,12 +16,18 @@ func osAccess(path string, flags AccessFlag) error {
return nil
}
var want fs.FileMode = _S_IREAD
const (
S_IREAD = 0400
S_IWRITE = 0200
S_IEXEC = 0100
)
var want fs.FileMode = S_IREAD
if flags == ACCESS_READWRITE {
want |= _S_IWRITE
want |= S_IWRITE
}
if fi.IsDir() {
want |= _S_IEXEC
want |= S_IEXEC
}
if fi.Mode()&want != want {
return fs.ErrPermission

View File

@@ -143,7 +143,7 @@ type VTabRenamer interface {
type VTabOverloader interface {
VTab
// https://sqlite.org/vtab.html#xfindfunction
FindFunction(arg int, name string) (func(ctx Context, arg ...Value), IndexConstraintOp)
FindFunction(arg int, name string) (ScalarFunction, IndexConstraintOp)
}
// A VTabChecker allows a virtual table to report errors
@@ -161,6 +161,11 @@ type VTabChecker interface {
// A VTabTx allows a virtual table to implement
// transactions with two-phase commit.
//
// Anything that is required as part of a commit that may fail
// should be performed in the Sync() callback.
// Current versions of SQLite ignore any errors
// returned by Commit() and Rollback().
type VTabTx interface {
VTab
// https://sqlite.org/vtab.html#xBegin