diff --git a/context.go b/context.go index b2e1a98..b0b5563 100644 --- a/context.go +++ b/context.go @@ -148,6 +148,16 @@ func (ctx Context) resultRFC3339Nano(value time.Time) { uint64(ctx.c.api.destructor), _UTF8) } +// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull], +// except that it also associates ptr with that NULL value such that it can be retrieved +// within an application-defined SQL function using [Value.Pointer]. +// +// https://www.sqlite.org/c3ref/result_blob.html +func (ctx Context) ResultPointer(ptr any) { + valPtr := util.AddHandle(ctx.c.ctx, ptr) + ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr)) +} + // ResultJSON sets the result of the function to the JSON encoding of value. // // https://www.sqlite.org/c3ref/result_blob.html diff --git a/driver/driver.go b/driver/driver.go index 2b733f0..33eb740 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -380,6 +380,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { err = s.Stmt.BindBlob(id, a) case sqlite3.ZeroBlob: err = s.Stmt.BindZeroBlob(id, int64(a)) + case interface{ Value() any }: + err = s.Stmt.BindPointer(id, a.Value()) case time.Time: err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault) case json.Marshaler: @@ -400,7 +402,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error { switch arg.Value.(type) { case bool, int, int64, float64, string, []byte, - sqlite3.ZeroBlob, time.Time, json.Marshaler, nil: + sqlite3.ZeroBlob, interface{ Value() any }, + time.Time, json.Marshaler, nil: return nil default: return driver.ErrSkip diff --git a/embed/exports.txt b/embed/exports.txt index b96d40a..51a4d1c 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -25,6 +25,7 @@ sqlite3_bind_double sqlite3_bind_text64 sqlite3_bind_blob64 sqlite3_bind_zeroblob64 +sqlite3_bind_pointer_go sqlite3_column_count sqlite3_column_name sqlite3_column_type @@ -64,12 +65,14 @@ sqlite3_value_double sqlite3_value_text sqlite3_value_blob sqlite3_value_bytes +sqlite3_value_pointer_go sqlite3_result_null sqlite3_result_int64 sqlite3_result_double sqlite3_result_text64 sqlite3_result_blob64 sqlite3_result_zeroblob64 +sqlite3_result_pointer_go sqlite3_result_value sqlite3_result_error sqlite3_result_error_code diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 6901953..3e3ea0e 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/ext/blob/blob.go b/ext/blob/blob.go new file mode 100644 index 0000000..e9a6e79 --- /dev/null +++ b/ext/blob/blob.go @@ -0,0 +1,59 @@ +// Package blob provides an alternative interface to incremental BLOB I/O. +package blob + +import ( + "errors" + + "github.com/ncruces/go-sqlite3" +) + +// Register registers the blob_open SQL function. +func Register(db *sqlite3.Conn) { + db.CreateFunction("blob_open", -1, + sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob) +} + +func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) { + if len(arg) < 6 { + ctx.ResultError(errors.New("wrong number of arguments to function blob_open()")) + return + } + + row := arg[3].Int64() + + var err error + blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob) + if ok { + err = blob.Reopen(row) + if errors.Is(err, sqlite3.MISUSE) { + // Blob was closed (db, table or column changed). + ok = false + } + } + + if !ok { + db := arg[0].Text() + table := arg[1].Text() + column := arg[2].Text() + write := arg[4].Bool() + blob, err = ctx.Conn().OpenBlob(db, table, column, row, write) + } + if err != nil { + ctx.ResultError(err) + return + } + + fn := arg[5].Pointer().(OpenCallback) + err = fn(blob, arg[6:]...) + if err != nil { + ctx.ResultError(err) + return + } + + // This ensures the blob is closed if db, table or column change. + ctx.SetAuxData(0, blob) + ctx.SetAuxData(1, blob) + ctx.SetAuxData(2, blob) +} + +type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error diff --git a/ext/blob/blob_test.go b/ext/blob/blob_test.go new file mode 100644 index 0000000..b2265f2 --- /dev/null +++ b/ext/blob/blob_test.go @@ -0,0 +1,61 @@ +package blob_test + +import ( + "io" + "log" + "os" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/blob" +) + +func Example() { + // Open the database, registering the extension. + db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error { + blob.Register(conn) + return nil + }) + + if err != nil { + log.Fatal(err) + } + defer os.Remove("demo.db") + defer db.Close() + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + log.Fatal(err) + } + + const message = "Hello BLOB!" + + // Create the BLOB. + _, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message))) + if err != nil { + log.Fatal(err) + } + + // Write the BLOB. + _, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`, + sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error { + _, err = io.WriteString(blob, message) + return err + })) + if err != nil { + log.Fatal(err) + } + + // Read the BLOB. + _, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`, + sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error { + _, err = io.Copy(os.Stdout, blob) + return err + })) + if err != nil { + log.Fatal(err) + } + // Output: + // Hello BLOB! +} diff --git a/pointer.go b/pointer.go new file mode 100644 index 0000000..cf5a4d8 --- /dev/null +++ b/pointer.go @@ -0,0 +1,14 @@ +package sqlite3 + +// Pointer returns a pointer to a value +// that can be used as an argument to +// [database/sql.DB.Exec] and similar methods. +// +// https://www.sqlite.org/bindptr.html +func Pointer[T any](val T) any { + return pointer[T]{val} +} + +type pointer[T any] struct{ val T } + +func (p pointer[T]) Value() any { return p.val } diff --git a/sqlite.go b/sqlite.go index bbf900f..4dc68cd 100644 --- a/sqlite.go +++ b/sqlite.go @@ -132,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) { bindText: getFun("sqlite3_bind_text64"), bindBlob: getFun("sqlite3_bind_blob64"), bindZeroBlob: getFun("sqlite3_bind_zeroblob64"), + bindPointer: getFun("sqlite3_bind_pointer_go"), columnCount: getFun("sqlite3_column_count"), columnName: getFun("sqlite3_column_name"), columnType: getFun("sqlite3_column_type"), @@ -169,12 +170,14 @@ func instantiateSQLite() (sqlt *sqlite, err error) { valueText: getFun("sqlite3_value_text"), valueBlob: getFun("sqlite3_value_blob"), valueBytes: getFun("sqlite3_value_bytes"), + valuePointer: getFun("sqlite3_value_pointer_go"), resultNull: getFun("sqlite3_result_null"), resultInteger: getFun("sqlite3_result_int64"), resultFloat: getFun("sqlite3_result_double"), resultText: getFun("sqlite3_result_text64"), resultBlob: getFun("sqlite3_result_blob64"), resultZeroBlob: getFun("sqlite3_result_zeroblob64"), + resultPointer: getFun("sqlite3_result_pointer_go"), resultValue: getFun("sqlite3_result_value"), resultError: getFun("sqlite3_result_error"), resultErrorCode: getFun("sqlite3_result_error_code"), @@ -353,6 +356,7 @@ type sqliteAPI struct { bindText api.Function bindBlob api.Function bindZeroBlob api.Function + bindPointer api.Function columnCount api.Function columnName api.Function columnType api.Function @@ -390,12 +394,14 @@ type sqliteAPI struct { valueText api.Function valueBlob api.Function valueBytes api.Function + valuePointer api.Function resultNull api.Function resultInteger api.Function resultFloat api.Function resultText api.Function resultBlob api.Function resultZeroBlob api.Function + resultPointer api.Function resultValue api.Function resultError api.Function resultErrorCode api.Function diff --git a/sqlite3/func.c b/sqlite3/func.c index 8451033..f7241fa 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -38,4 +38,18 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg, void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) { sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy); -} \ No newline at end of file +} + +#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer" + +int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) { + return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy); +} + +void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) { + sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy); +} + +void *sqlite3_value_pointer_go(sqlite3_value *val) { + return sqlite3_value_pointer(val, GO_POINTER_TYPE); +} diff --git a/stmt.go b/stmt.go index 04a91ec..67b44db 100644 --- a/stmt.go +++ b/stmt.go @@ -250,6 +250,18 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { return s.c.error(r) } +// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull], +// but it also associates ptr with that NULL value such that it can be retrieved +// within an application-defined SQL function using [Value.Pointer]. +// +// https://www.sqlite.org/c3ref/bind_blob.html +func (s *Stmt) BindPointer(param int, ptr any) error { + valPtr := util.AddHandle(s.c.ctx, ptr) + r := s.c.call(s.c.api.bindPointer, + uint64(s.handle), uint64(param), uint64(valPtr)) + return s.c.error(r) +} + // BindJSON binds the JSON encoding of value to the prepared statement. // The leftmost SQL parameter has an index of 1. // diff --git a/value.go b/value.go index 0fcc5ef..5da5ac8 100644 --- a/value.go +++ b/value.go @@ -126,6 +126,13 @@ func (v Value) rawBytes(ptr uint32) []byte { return util.View(v.mod, ptr, r) } +// Pointer gets the pointer associated with this value, +// or nil if it has no associated pointer. +func (v Value) Pointer() any { + r := v.call(v.api.valuePointer, uint64(v.handle)) + return util.GetHandle(v.ctx, uint32(r)) +} + // JSON parses a JSON-encoded value // and stores the result in the value pointed to by ptr. func (v Value) JSON(ptr any) error {