Pointer-passing interfaces.

This commit is contained in:
Nuno Cruces
2023-11-07 00:50:43 +00:00
parent 24b965ac7e
commit 24c9b57c56
11 changed files with 191 additions and 2 deletions

View File

@@ -148,6 +148,16 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
uint64(ctx.c.api.destructor), _UTF8)
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
// except that it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html

View File

@@ -380,6 +380,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case interface{ Value() any }:
err = s.Stmt.BindPointer(id, a.Value())
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case json.Marshaler:
@@ -400,7 +402,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, json.Marshaler, nil:
sqlite3.ZeroBlob, interface{ Value() any },
time.Time, json.Marshaler, nil:
return nil
default:
return driver.ErrSkip

View File

@@ -25,6 +25,7 @@ sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_bind_pointer_go
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
@@ -64,12 +65,14 @@ sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_value_pointer_go
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_pointer_go
sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code

Binary file not shown.

59
ext/blob/blob.go Normal file
View File

@@ -0,0 +1,59 @@
// Package blob provides an alternative interface to incremental BLOB I/O.
package blob
import (
"errors"
"github.com/ncruces/go-sqlite3"
)
// Register registers the blob_open SQL function.
func Register(db *sqlite3.Conn) {
db.CreateFunction("blob_open", -1,
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
}
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
return
}
row := arg[3].Int64()
var err error
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
if ok {
err = blob.Reopen(row)
if errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table or column changed).
ok = false
}
}
if !ok {
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
write := arg[4].Bool()
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
}
if err != nil {
ctx.ResultError(err)
return
}
fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return
}
// This ensures the blob is closed if db, table or column change.
ctx.SetAuxData(0, blob)
ctx.SetAuxData(1, blob)
ctx.SetAuxData(2, blob)
}
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error

61
ext/blob/blob_test.go Normal file
View File

@@ -0,0 +1,61 @@
package blob_test
import (
"io"
"log"
"os"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/blob"
)
func Example() {
// Open the database, registering the extension.
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
blob.Register(conn)
return nil
})
if err != nil {
log.Fatal(err)
}
defer os.Remove("demo.db")
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}
const message = "Hello BLOB!"
// Create the BLOB.
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}
// Write the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.WriteString(blob, message)
return err
}))
if err != nil {
log.Fatal(err)
}
// Read the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.Copy(os.Stdout, blob)
return err
}))
if err != nil {
log.Fatal(err)
}
// Output:
// Hello BLOB!
}

14
pointer.go Normal file
View File

@@ -0,0 +1,14 @@
package sqlite3
// Pointer returns a pointer to a value
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
//
// https://www.sqlite.org/bindptr.html
func Pointer[T any](val T) any {
return pointer[T]{val}
}
type pointer[T any] struct{ val T }
func (p pointer[T]) Value() any { return p.val }

View File

@@ -132,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindPointer: getFun("sqlite3_bind_pointer_go"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
@@ -169,12 +170,14 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
valuePointer: getFun("sqlite3_value_pointer_go"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultPointer: getFun("sqlite3_result_pointer_go"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
@@ -353,6 +356,7 @@ type sqliteAPI struct {
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindPointer api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
@@ -390,12 +394,14 @@ type sqliteAPI struct {
valueText api.Function
valueBlob api.Function
valueBytes api.Function
valuePointer api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultPointer api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function

View File

@@ -38,4 +38,18 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}
}
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"
int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
}
void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
}
void *sqlite3_value_pointer_go(sqlite3_value *val) {
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
}

12
stmt.go
View File

@@ -250,6 +250,18 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r)
}
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
// but it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
r := s.c.call(s.c.api.bindPointer,
uint64(s.handle), uint64(param), uint64(valPtr))
return s.c.error(r)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//

View File

@@ -126,6 +126,13 @@ func (v Value) rawBytes(ptr uint32) []byte {
return util.View(v.mod, ptr, r)
}
// Pointer gets the pointer associated with this value,
// or nil if it has no associated pointer.
func (v Value) Pointer() any {
r := v.call(v.api.valuePointer, uint64(v.handle))
return util.GetHandle(v.ctx, uint32(r))
}
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {