From fec1f8d32ae57bd39b25b608cbcef5abc2dff631 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 1 Jul 2023 00:15:28 +0100 Subject: [PATCH] Custom scalar functions. --- README.md | 2 +- const.go | 12 ++++++++++++ context.go | 41 ++++++++++++++++++++------------------- func.go | 38 +++++++++++++++++++++++++++++++----- func_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++ module.go | 9 ++++++++- value.go | 54 +++++++++++++++++++++++---------------------------- 7 files changed, 154 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 819fb9d..e291fd3 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ Performance is tested by running ### Roadmap - [ ] advanced SQLite features + - [x] custom functions - [x] nested transactions - [x] incremental BLOB I/O - [x] online backup @@ -72,7 +73,6 @@ Performance is tested by running - [x] in-memory VFS - [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt) - [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki) -- [ ] custom SQL functions ### Alternatives diff --git a/const.go b/const.go index a1d6145..778674c 100644 --- a/const.go +++ b/const.go @@ -167,6 +167,18 @@ const ( PREPARE_NO_VTAB PrepareFlag = 0x04 ) +// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags]. +// +// https://www.sqlite.org/c3ref/c_deterministic.html +type FunctionFlag uint32 + +const ( + DETERMINISTIC FunctionFlag = 0x000000800 + DIRECTONLY FunctionFlag = 0x000080000 + SUBTYPE FunctionFlag = 0x000100000 + INNOCUOUS FunctionFlag = 0x000200000 +) + // Datatype is a fundamental datatype of SQLite. // // https://www.sqlite.org/c3ref/c_blob.html diff --git a/context.go b/context.go index 7e5a974..45e1b32 100644 --- a/context.go +++ b/context.go @@ -12,7 +12,7 @@ import ( // // https://www.sqlite.org/c3ref/context.html type Context struct { - c *Conn + *module handle uint32 } @@ -40,7 +40,7 @@ func (c *Context) ResultInt(value int) { // // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultInt64(value int64) { - c.c.call(c.c.api.resultInteger, + c.call(c.api.resultInteger, uint64(c.handle), uint64(value)) } @@ -48,7 +48,7 @@ func (c *Context) ResultInt64(value int64) { // // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultFloat(value float64) { - c.c.call(c.c.api.resultFloat, + c.call(c.api.resultFloat, uint64(c.handle), math.Float64bits(value)) } @@ -56,10 +56,10 @@ func (c *Context) ResultFloat(value float64) { // // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultText(value string) { - ptr := c.c.newString(value) - c.c.call(c.c.api.resultText, + ptr := c.newString(value) + c.call(c.api.resultText, uint64(c.handle), uint64(ptr), uint64(len(value)), - uint64(c.c.api.destructor), _UTF8) + uint64(c.api.destructor), _UTF8) } // ResultBlob sets the result of the function to a []byte. @@ -67,17 +67,17 @@ func (c *Context) ResultText(value string) { // // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultBlob(value []byte) { - ptr := c.c.newBytes(value) - c.c.call(c.c.api.resultBlob, + ptr := c.newBytes(value) + c.call(c.api.resultBlob, uint64(c.handle), uint64(ptr), uint64(len(value)), - uint64(c.c.api.destructor)) + uint64(c.api.destructor)) } // BindZeroBlob sets the result of the function to a zero-filled, length n BLOB. // // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultZeroBlob(n int64) { - c.c.call(c.c.api.resultZeroBlob, + c.call(c.api.resultZeroBlob, uint64(c.handle), uint64(n)) } @@ -85,7 +85,7 @@ func (c *Context) ResultZeroBlob(n int64) { // // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultNull() { - c.c.call(c.c.api.resultNull, + c.call(c.api.resultNull, uint64(c.handle)) } @@ -112,13 +112,13 @@ func (c *Context) ResultTime(value time.Time, format TimeFormat) { func (c *Context) resultRFC3339Nano(value time.Time) { const maxlen = uint64(len(time.RFC3339Nano)) - ptr := c.c.new(maxlen) - buf := util.View(c.c.mod, ptr, maxlen) + ptr := c.new(maxlen) + buf := util.View(c.mod, ptr, maxlen) buf = value.AppendFormat(buf[:0], time.RFC3339Nano) - c.c.call(c.c.api.resultText, + c.call(c.api.resultText, uint64(c.handle), uint64(ptr), uint64(len(buf)), - uint64(c.c.api.destructor), _UTF8) + uint64(c.api.destructor), _UTF8) } // ResultError sets the result of the function an error. @@ -126,19 +126,20 @@ func (c *Context) resultRFC3339Nano(value time.Time) { // https://www.sqlite.org/c3ref/result_blob.html func (c *Context) ResultError(err error) { if errors.Is(err, NOMEM) { - c.c.call(c.c.api.resultErrorMem, uint64(c.handle)) + c.call(c.api.resultErrorMem, uint64(c.handle)) return } if errors.Is(err, TOOBIG) { - c.c.call(c.c.api.resultErrorBig, uint64(c.handle)) + c.call(c.api.resultErrorBig, uint64(c.handle)) return } str := err.Error() - ptr := c.c.arena.string(str) - c.c.call(c.c.api.resultBlob, + ptr := c.newString(str) + c.call(c.api.resultBlob, uint64(c.handle), uint64(ptr), uint64(len(str))) + c.free(ptr) var code uint64 var ecode ErrorCode @@ -150,7 +151,7 @@ func (c *Context) ResultError(err error) { code = uint64(ecode) } if code != 0 { - c.c.call(c.c.api.resultErrorCode, + c.call(c.api.resultErrorCode, uint64(c.handle), uint64(xcode)) } } diff --git a/func.go b/func.go index e7045a0..00b1369 100644 --- a/func.go +++ b/func.go @@ -23,6 +23,18 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { return nil } +// CreateFunction defines a new scalar function. +// +// https://www.sqlite.org/c3ref/create_function.html +func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error { + namePtr := c.arena.string(name) + funcPtr := util.AddHandle(c.ctx, fn) + r := c.call(c.api.createFunction, + uint64(c.handle), uint64(namePtr), uint64(nArg), + uint64(flag), uint64(funcPtr)) + return c.error(r) +} + func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncVI(env, "go_destroy", cbDestroy) util.ExportFuncIIIIII(env, "go_compare", cbCompare) @@ -34,16 +46,32 @@ func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder return env } -func cbDestroy(ctx context.Context, mod api.Module, pArg uint32) { - util.DelHandle(ctx, pArg) +func cbDestroy(ctx context.Context, mod api.Module, pApp uint32) { + util.DelHandle(ctx, pApp) } -func cbCompare(ctx context.Context, mod api.Module, pArg, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { - fn := util.GetHandle(ctx, pArg).(func(a, b []byte) int) +func cbCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { + fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) } -func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {} +func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { + module := ctx.Value(moduleKey{}).(*module) + pApp := uint32(module.call(module.api.userData, uint64(pCtx))) + args := make([]Value, nArg) + for i := range args { + args[i] = Value{ + handle: util.ReadUint32(mod, pArg+ptrlen*uint32(i)), + module: module, + } + } + context := Context{ + handle: pCtx, + module: module, + } + fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value)) + fn(context, args...) +} func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {} diff --git a/func_test.go b/func_test.go index a39fd63..16b35d5 100644 --- a/func_test.go +++ b/func_test.go @@ -1,6 +1,7 @@ package sqlite3_test import ( + "bytes" "fmt" "log" @@ -62,3 +63,57 @@ func ExampleConn_CreateCollation() { // cotée // coter } + +func ExampleConn_CreateFunction() { + db, err := sqlite3.Open(memory) + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`) + if err != nil { + log.Fatal(err) + } + + err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`) + if err != nil { + log.Fatal(err) + } + + err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) { + ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob())) + }) + if err != nil { + log.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + for stmt.Step() { + fmt.Println(stmt.ColumnText(0)) + } + if err := stmt.Err(); err != nil { + log.Fatal(err) + } + + err = stmt.Close() + if err != nil { + log.Fatal(err) + } + + err = db.Close() + if err != nil { + log.Fatal(err) + } + // Unordered output: + // COTE + // COTÉ + // CÔTE + // CÔTÉ + // COTÉE + // COTER +} diff --git a/module.go b/module.go index 45f9660..4f253b6 100644 --- a/module.go +++ b/module.go @@ -84,10 +84,13 @@ type module struct { stack [8]uint64 } +type moduleKey struct{} + func newModule(mod api.Module) (m *module, err error) { m = new(module) - m.mod = mod m.ctx, m.closer = util.NewContext(context.Background()) + m.ctx = context.WithValue(m.ctx, moduleKey{}, m) + m.mod = mod getFun := func(name string) api.Function { f := mod.ExportedFunction(name) @@ -156,6 +159,8 @@ func newModule(mod api.Module) (m *module, err error) { lastRowid: getFun("sqlite3_last_insert_rowid"), autocommit: getFun("sqlite3_get_autocommit"), createCollation: getFun("sqlite3_create_go_collation"), + createFunction: getFun("sqlite3_create_go_function"), + userData: getFun("sqlite3_user_data"), valueType: getFun("sqlite3_value_type"), valueInteger: getFun("sqlite3_value_int64"), valueFloat: getFun("sqlite3_value_double"), @@ -368,6 +373,8 @@ type sqliteAPI struct { lastRowid api.Function autocommit api.Function createCollation api.Function + createFunction api.Function + userData api.Function valueType api.Function valueInteger api.Function valueFloat api.Function diff --git a/value.go b/value.go index db61819..c98d3a0 100644 --- a/value.go +++ b/value.go @@ -11,7 +11,7 @@ import ( // // https://www.sqlite.org/c3ref/value.html type Value struct { - c *Conn + *module handle uint32 } @@ -19,7 +19,7 @@ type Value struct { // // https://www.sqlite.org/c3ref/value_blob.html func (v *Value) Type() Datatype { - r := v.c.call(v.c.api.valueType, uint64(v.handle)) + r := v.call(v.api.valueType, uint64(v.handle)) return Datatype(r) } @@ -47,7 +47,7 @@ func (v *Value) Int() int { // // https://www.sqlite.org/c3ref/value_blob.html func (v *Value) Int64() int64 { - r := v.c.call(v.c.api.valueInteger, uint64(v.handle)) + r := v.call(v.api.valueInteger, uint64(v.handle)) return int64(r) } @@ -55,49 +55,44 @@ func (v *Value) Int64() int64 { // // https://www.sqlite.org/c3ref/value_blob.html func (v *Value) Float() float64 { - r := v.c.call(v.c.api.valueFloat, uint64(v.handle)) + r := v.call(v.api.valueFloat, uint64(v.handle)) return math.Float64frombits(r) } // Time returns the value as a [time.Time]. // // https://www.sqlite.org/c3ref/value_blob.html -func (v *Value) Time(format TimeFormat) (time.Time, error) { - var t any - var err error +func (v *Value) Time(format TimeFormat) time.Time { + var a any switch v.Type() { case INTEGER: - t = v.Int64() + a = v.Int64() case FLOAT: - t = v.Float() + a = v.Float() case TEXT, BLOB: - t, err = v.Text() - if err != nil { - return time.Time{}, err - } + a = v.Text() case NULL: - return time.Time{}, nil + return time.Time{} default: panic(util.AssertErr()) } - return format.Decode(t) + t, _ := format.Decode(a) + return t } // Text returns the value as a string. // // https://www.sqlite.org/c3ref/value_blob.html -func (v *Value) Text() (string, error) { - r, err := v.RawText() - return string(r), err +func (v *Value) Text() string { + return string(v.RawText()) } // Blob appends to buf and returns // the value as a []byte. // // https://www.sqlite.org/c3ref/value_blob.html -func (v *Value) Blob(buf []byte) ([]byte, error) { - r, err := v.RawBlob() - return append(buf, r...), err +func (v *Value) Blob(buf []byte) []byte { + return append(buf, v.RawBlob()...) } // RawText returns the value as a []byte. @@ -105,8 +100,8 @@ func (v *Value) Blob(buf []byte) ([]byte, error) { // subsequent calls to [Value] methods. // // https://www.sqlite.org/c3ref/value_blob.html -func (v *Value) RawText() ([]byte, error) { - r := v.c.call(v.c.api.valueText, uint64(v.handle)) +func (v *Value) RawText() []byte { + r := v.call(v.api.valueText, uint64(v.handle)) return v.rawBytes(uint32(r)) } @@ -115,17 +110,16 @@ func (v *Value) RawText() ([]byte, error) { // subsequent calls to [Value] methods. // // https://www.sqlite.org/c3ref/value_blob.html -func (v *Value) RawBlob() ([]byte, error) { - r := v.c.call(v.c.api.valueBlob, uint64(v.handle)) +func (v *Value) RawBlob() []byte { + r := v.call(v.api.valueBlob, uint64(v.handle)) return v.rawBytes(uint32(r)) } -func (v *Value) rawBytes(ptr uint32) ([]byte, error) { +func (v *Value) rawBytes(ptr uint32) []byte { if ptr == 0 { - r := v.c.call(v.c.api.errcode, uint64(v.c.handle)) - return nil, v.c.error(r) + return nil } - r := v.c.call(v.c.api.valueBytes, uint64(v.handle)) - return util.View(v.c.mod, ptr, r), nil + r := v.call(v.api.valueBytes, uint64(v.handle)) + return util.View(v.mod, ptr, r) }