mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Pointer-passing interfaces.
This commit is contained in:
10
context.go
10
context.go
@@ -148,6 +148,16 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
|
||||
uint64(ctx.c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
|
||||
// except that it also associates ptr with that NULL value such that it can be retrieved
|
||||
// within an application-defined SQL function using [Value.Pointer].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (ctx Context) ResultPointer(ptr any) {
|
||||
valPtr := util.AddHandle(ctx.c.ctx, ptr)
|
||||
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
|
||||
}
|
||||
|
||||
// ResultJSON sets the result of the function to the JSON encoding of value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
|
||||
@@ -380,6 +380,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
err = s.Stmt.BindBlob(id, a)
|
||||
case sqlite3.ZeroBlob:
|
||||
err = s.Stmt.BindZeroBlob(id, int64(a))
|
||||
case interface{ Value() any }:
|
||||
err = s.Stmt.BindPointer(id, a.Value())
|
||||
case time.Time:
|
||||
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
case json.Marshaler:
|
||||
@@ -400,7 +402,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, json.Marshaler, nil:
|
||||
sqlite3.ZeroBlob, interface{ Value() any },
|
||||
time.Time, json.Marshaler, nil:
|
||||
return nil
|
||||
default:
|
||||
return driver.ErrSkip
|
||||
|
||||
@@ -25,6 +25,7 @@ sqlite3_bind_double
|
||||
sqlite3_bind_text64
|
||||
sqlite3_bind_blob64
|
||||
sqlite3_bind_zeroblob64
|
||||
sqlite3_bind_pointer_go
|
||||
sqlite3_column_count
|
||||
sqlite3_column_name
|
||||
sqlite3_column_type
|
||||
@@ -64,12 +65,14 @@ sqlite3_value_double
|
||||
sqlite3_value_text
|
||||
sqlite3_value_blob
|
||||
sqlite3_value_bytes
|
||||
sqlite3_value_pointer_go
|
||||
sqlite3_result_null
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_double
|
||||
sqlite3_result_text64
|
||||
sqlite3_result_blob64
|
||||
sqlite3_result_zeroblob64
|
||||
sqlite3_result_pointer_go
|
||||
sqlite3_result_value
|
||||
sqlite3_result_error
|
||||
sqlite3_result_error_code
|
||||
|
||||
Binary file not shown.
59
ext/blob/blob.go
Normal file
59
ext/blob/blob.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package blob provides an alternative interface to incremental BLOB I/O.
|
||||
package blob
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register registers the blob_open SQL function.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
db.CreateFunction("blob_open", -1,
|
||||
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
|
||||
}
|
||||
|
||||
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if len(arg) < 6 {
|
||||
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
|
||||
return
|
||||
}
|
||||
|
||||
row := arg[3].Int64()
|
||||
|
||||
var err error
|
||||
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
|
||||
if ok {
|
||||
err = blob.Reopen(row)
|
||||
if errors.Is(err, sqlite3.MISUSE) {
|
||||
// Blob was closed (db, table or column changed).
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
db := arg[0].Text()
|
||||
table := arg[1].Text()
|
||||
column := arg[2].Text()
|
||||
write := arg[4].Bool()
|
||||
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
|
||||
}
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
|
||||
fn := arg[5].Pointer().(OpenCallback)
|
||||
err = fn(blob, arg[6:]...)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// This ensures the blob is closed if db, table or column change.
|
||||
ctx.SetAuxData(0, blob)
|
||||
ctx.SetAuxData(1, blob)
|
||||
ctx.SetAuxData(2, blob)
|
||||
}
|
||||
|
||||
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
|
||||
61
ext/blob/blob_test.go
Normal file
61
ext/blob/blob_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package blob_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/blob"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
// Open the database, registering the extension.
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
blob.Register(conn)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove("demo.db")
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
const message = "Hello BLOB!"
|
||||
|
||||
// Create the BLOB.
|
||||
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Write the BLOB.
|
||||
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
|
||||
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
|
||||
_, err = io.WriteString(blob, message)
|
||||
return err
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the BLOB.
|
||||
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
|
||||
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
|
||||
_, err = io.Copy(os.Stdout, blob)
|
||||
return err
|
||||
}))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// Hello BLOB!
|
||||
}
|
||||
14
pointer.go
Normal file
14
pointer.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package sqlite3
|
||||
|
||||
// Pointer returns a pointer to a value
|
||||
// that can be used as an argument to
|
||||
// [database/sql.DB.Exec] and similar methods.
|
||||
//
|
||||
// https://www.sqlite.org/bindptr.html
|
||||
func Pointer[T any](val T) any {
|
||||
return pointer[T]{val}
|
||||
}
|
||||
|
||||
type pointer[T any] struct{ val T }
|
||||
|
||||
func (p pointer[T]) Value() any { return p.val }
|
||||
@@ -132,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
bindPointer: getFun("sqlite3_bind_pointer_go"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
@@ -169,12 +170,14 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
valuePointer: getFun("sqlite3_value_pointer_go"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultPointer: getFun("sqlite3_result_pointer_go"),
|
||||
resultValue: getFun("sqlite3_result_value"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
@@ -353,6 +356,7 @@ type sqliteAPI struct {
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
bindPointer api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
@@ -390,12 +394,14 @@ type sqliteAPI struct {
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
valuePointer api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultPointer api.Function
|
||||
resultValue api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
|
||||
@@ -38,4 +38,18 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
|
||||
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
|
||||
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
|
||||
}
|
||||
}
|
||||
|
||||
#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"
|
||||
|
||||
int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
|
||||
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
|
||||
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
|
||||
}
|
||||
|
||||
void *sqlite3_value_pointer_go(sqlite3_value *val) {
|
||||
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
|
||||
}
|
||||
|
||||
12
stmt.go
12
stmt.go
@@ -250,6 +250,18 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
|
||||
// but it also associates ptr with that NULL value such that it can be retrieved
|
||||
// within an application-defined SQL function using [Value.Pointer].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindPointer(param int, ptr any) error {
|
||||
valPtr := util.AddHandle(s.c.ctx, ptr)
|
||||
r := s.c.call(s.c.api.bindPointer,
|
||||
uint64(s.handle), uint64(param), uint64(valPtr))
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindJSON binds the JSON encoding of value to the prepared statement.
|
||||
// The leftmost SQL parameter has an index of 1.
|
||||
//
|
||||
|
||||
7
value.go
7
value.go
@@ -126,6 +126,13 @@ func (v Value) rawBytes(ptr uint32) []byte {
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
|
||||
// Pointer gets the pointer associated with this value,
|
||||
// or nil if it has no associated pointer.
|
||||
func (v Value) Pointer() any {
|
||||
r := v.call(v.api.valuePointer, uint64(v.handle))
|
||||
return util.GetHandle(v.ctx, uint32(r))
|
||||
}
|
||||
|
||||
// JSON parses a JSON-encoded value
|
||||
// and stores the result in the value pointed to by ptr.
|
||||
func (v Value) JSON(ptr any) error {
|
||||
|
||||
Reference in New Issue
Block a user