Custom scalar functions.

This commit is contained in:
Nuno Cruces
2023-07-01 00:15:28 +01:00
parent 31572e6095
commit fec1f8d32a
7 changed files with 154 additions and 57 deletions

View File

@@ -63,6 +63,7 @@ Performance is tested by running
### Roadmap ### Roadmap
- [ ] advanced SQLite features - [ ] advanced SQLite features
- [x] custom functions
- [x] nested transactions - [x] nested transactions
- [x] incremental BLOB I/O - [x] incremental BLOB I/O
- [x] online backup - [x] online backup
@@ -72,7 +73,6 @@ Performance is tested by running
- [x] in-memory VFS - [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt) - [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki) - [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom SQL functions
### Alternatives ### Alternatives

View File

@@ -167,6 +167,18 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04 PREPARE_NO_VTAB PrepareFlag = 0x04
) )
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
)
// Datatype is a fundamental datatype of SQLite. // Datatype is a fundamental datatype of SQLite.
// //
// https://www.sqlite.org/c3ref/c_blob.html // https://www.sqlite.org/c3ref/c_blob.html

View File

@@ -12,7 +12,7 @@ import (
// //
// https://www.sqlite.org/c3ref/context.html // https://www.sqlite.org/c3ref/context.html
type Context struct { type Context struct {
c *Conn *module
handle uint32 handle uint32
} }
@@ -40,7 +40,7 @@ func (c *Context) ResultInt(value int) {
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultInt64(value int64) { func (c *Context) ResultInt64(value int64) {
c.c.call(c.c.api.resultInteger, c.call(c.api.resultInteger,
uint64(c.handle), uint64(value)) uint64(c.handle), uint64(value))
} }
@@ -48,7 +48,7 @@ func (c *Context) ResultInt64(value int64) {
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultFloat(value float64) { func (c *Context) ResultFloat(value float64) {
c.c.call(c.c.api.resultFloat, c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value)) uint64(c.handle), math.Float64bits(value))
} }
@@ -56,10 +56,10 @@ func (c *Context) ResultFloat(value float64) {
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultText(value string) { func (c *Context) ResultText(value string) {
ptr := c.c.newString(value) ptr := c.newString(value)
c.c.call(c.c.api.resultText, c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)), uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.c.api.destructor), _UTF8) uint64(c.api.destructor), _UTF8)
} }
// ResultBlob sets the result of the function to a []byte. // ResultBlob sets the result of the function to a []byte.
@@ -67,17 +67,17 @@ func (c *Context) ResultText(value string) {
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultBlob(value []byte) { func (c *Context) ResultBlob(value []byte) {
ptr := c.c.newBytes(value) ptr := c.newBytes(value)
c.c.call(c.c.api.resultBlob, c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)), uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.c.api.destructor)) uint64(c.api.destructor))
} }
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB. // BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultZeroBlob(n int64) { func (c *Context) ResultZeroBlob(n int64) {
c.c.call(c.c.api.resultZeroBlob, c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n)) uint64(c.handle), uint64(n))
} }
@@ -85,7 +85,7 @@ func (c *Context) ResultZeroBlob(n int64) {
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultNull() { func (c *Context) ResultNull() {
c.c.call(c.c.api.resultNull, c.call(c.api.resultNull,
uint64(c.handle)) uint64(c.handle))
} }
@@ -112,13 +112,13 @@ func (c *Context) ResultTime(value time.Time, format TimeFormat) {
func (c *Context) resultRFC3339Nano(value time.Time) { func (c *Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano)) const maxlen = uint64(len(time.RFC3339Nano))
ptr := c.c.new(maxlen) ptr := c.new(maxlen)
buf := util.View(c.c.mod, ptr, maxlen) buf := util.View(c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano) buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.c.call(c.c.api.resultText, c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)), uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.c.api.destructor), _UTF8) uint64(c.api.destructor), _UTF8)
} }
// ResultError sets the result of the function an error. // ResultError sets the result of the function an error.
@@ -126,19 +126,20 @@ func (c *Context) resultRFC3339Nano(value time.Time) {
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultError(err error) { func (c *Context) ResultError(err error) {
if errors.Is(err, NOMEM) { if errors.Is(err, NOMEM) {
c.c.call(c.c.api.resultErrorMem, uint64(c.handle)) c.call(c.api.resultErrorMem, uint64(c.handle))
return return
} }
if errors.Is(err, TOOBIG) { if errors.Is(err, TOOBIG) {
c.c.call(c.c.api.resultErrorBig, uint64(c.handle)) c.call(c.api.resultErrorBig, uint64(c.handle))
return return
} }
str := err.Error() str := err.Error()
ptr := c.c.arena.string(str) ptr := c.newString(str)
c.c.call(c.c.api.resultBlob, c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(str))) uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
var code uint64 var code uint64
var ecode ErrorCode var ecode ErrorCode
@@ -150,7 +151,7 @@ func (c *Context) ResultError(err error) {
code = uint64(ecode) code = uint64(ecode)
} }
if code != 0 { if code != 0 {
c.c.call(c.c.api.resultErrorCode, c.call(c.api.resultErrorCode,
uint64(c.handle), uint64(xcode)) uint64(c.handle), uint64(xcode))
} }
} }

38
func.go
View File

@@ -23,6 +23,18 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
return nil return nil
} }
// CreateFunction defines a new scalar function.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", cbDestroy) util.ExportFuncVI(env, "go_destroy", cbDestroy)
util.ExportFuncIIIIII(env, "go_compare", cbCompare) util.ExportFuncIIIIII(env, "go_compare", cbCompare)
@@ -34,16 +46,32 @@ func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
return env return env
} }
func cbDestroy(ctx context.Context, mod api.Module, pArg uint32) { func cbDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pArg) util.DelHandle(ctx, pApp)
} }
func cbCompare(ctx context.Context, mod api.Module, pArg, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { func cbCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pArg).(func(a, b []byte) int) fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
} }
func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {} func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
module := ctx.Value(moduleKey{}).(*module)
pApp := uint32(module.call(module.api.userData, uint64(pCtx)))
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
handle: util.ReadUint32(mod, pArg+ptrlen*uint32(i)),
module: module,
}
}
context := Context{
handle: pCtx,
module: module,
}
fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value))
fn(context, args...)
}
func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {} func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}

View File

@@ -1,6 +1,7 @@
package sqlite3_test package sqlite3_test
import ( import (
"bytes"
"fmt" "fmt"
"log" "log"
@@ -62,3 +63,57 @@ func ExampleConn_CreateCollation() {
// cotée // cotée
// coter // coter
} }
func ExampleConn_CreateFunction() {
db, err := sqlite3.Open(memory)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
err = stmt.Close()
if err != nil {
log.Fatal(err)
}
err = db.Close()
if err != nil {
log.Fatal(err)
}
// Unordered output:
// COTE
// COTÉ
// CÔTE
// CÔTÉ
// COTÉE
// COTER
}

View File

@@ -84,10 +84,13 @@ type module struct {
stack [8]uint64 stack [8]uint64
} }
type moduleKey struct{}
func newModule(mod api.Module) (m *module, err error) { func newModule(mod api.Module) (m *module, err error) {
m = new(module) m = new(module)
m.mod = mod
m.ctx, m.closer = util.NewContext(context.Background()) m.ctx, m.closer = util.NewContext(context.Background())
m.ctx = context.WithValue(m.ctx, moduleKey{}, m)
m.mod = mod
getFun := func(name string) api.Function { getFun := func(name string) api.Function {
f := mod.ExportedFunction(name) f := mod.ExportedFunction(name)
@@ -156,6 +159,8 @@ func newModule(mod api.Module) (m *module, err error) {
lastRowid: getFun("sqlite3_last_insert_rowid"), lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"), autocommit: getFun("sqlite3_get_autocommit"),
createCollation: getFun("sqlite3_create_go_collation"), createCollation: getFun("sqlite3_create_go_collation"),
createFunction: getFun("sqlite3_create_go_function"),
userData: getFun("sqlite3_user_data"),
valueType: getFun("sqlite3_value_type"), valueType: getFun("sqlite3_value_type"),
valueInteger: getFun("sqlite3_value_int64"), valueInteger: getFun("sqlite3_value_int64"),
valueFloat: getFun("sqlite3_value_double"), valueFloat: getFun("sqlite3_value_double"),
@@ -368,6 +373,8 @@ type sqliteAPI struct {
lastRowid api.Function lastRowid api.Function
autocommit api.Function autocommit api.Function
createCollation api.Function createCollation api.Function
createFunction api.Function
userData api.Function
valueType api.Function valueType api.Function
valueInteger api.Function valueInteger api.Function
valueFloat api.Function valueFloat api.Function

View File

@@ -11,7 +11,7 @@ import (
// //
// https://www.sqlite.org/c3ref/value.html // https://www.sqlite.org/c3ref/value.html
type Value struct { type Value struct {
c *Conn *module
handle uint32 handle uint32
} }
@@ -19,7 +19,7 @@ type Value struct {
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Type() Datatype { func (v *Value) Type() Datatype {
r := v.c.call(v.c.api.valueType, uint64(v.handle)) r := v.call(v.api.valueType, uint64(v.handle))
return Datatype(r) return Datatype(r)
} }
@@ -47,7 +47,7 @@ func (v *Value) Int() int {
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Int64() int64 { func (v *Value) Int64() int64 {
r := v.c.call(v.c.api.valueInteger, uint64(v.handle)) r := v.call(v.api.valueInteger, uint64(v.handle))
return int64(r) return int64(r)
} }
@@ -55,49 +55,44 @@ func (v *Value) Int64() int64 {
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Float() float64 { func (v *Value) Float() float64 {
r := v.c.call(v.c.api.valueFloat, uint64(v.handle)) r := v.call(v.api.valueFloat, uint64(v.handle))
return math.Float64frombits(r) return math.Float64frombits(r)
} }
// Time returns the value as a [time.Time]. // Time returns the value as a [time.Time].
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Time(format TimeFormat) (time.Time, error) { func (v *Value) Time(format TimeFormat) time.Time {
var t any var a any
var err error
switch v.Type() { switch v.Type() {
case INTEGER: case INTEGER:
t = v.Int64() a = v.Int64()
case FLOAT: case FLOAT:
t = v.Float() a = v.Float()
case TEXT, BLOB: case TEXT, BLOB:
t, err = v.Text() a = v.Text()
if err != nil {
return time.Time{}, err
}
case NULL: case NULL:
return time.Time{}, nil return time.Time{}
default: default:
panic(util.AssertErr()) panic(util.AssertErr())
} }
return format.Decode(t) t, _ := format.Decode(a)
return t
} }
// Text returns the value as a string. // Text returns the value as a string.
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Text() (string, error) { func (v *Value) Text() string {
r, err := v.RawText() return string(v.RawText())
return string(r), err
} }
// Blob appends to buf and returns // Blob appends to buf and returns
// the value as a []byte. // the value as a []byte.
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Blob(buf []byte) ([]byte, error) { func (v *Value) Blob(buf []byte) []byte {
r, err := v.RawBlob() return append(buf, v.RawBlob()...)
return append(buf, r...), err
} }
// RawText returns the value as a []byte. // RawText returns the value as a []byte.
@@ -105,8 +100,8 @@ func (v *Value) Blob(buf []byte) ([]byte, error) {
// subsequent calls to [Value] methods. // subsequent calls to [Value] methods.
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) RawText() ([]byte, error) { func (v *Value) RawText() []byte {
r := v.c.call(v.c.api.valueText, uint64(v.handle)) r := v.call(v.api.valueText, uint64(v.handle))
return v.rawBytes(uint32(r)) return v.rawBytes(uint32(r))
} }
@@ -115,17 +110,16 @@ func (v *Value) RawText() ([]byte, error) {
// subsequent calls to [Value] methods. // subsequent calls to [Value] methods.
// //
// https://www.sqlite.org/c3ref/value_blob.html // https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) RawBlob() ([]byte, error) { func (v *Value) RawBlob() []byte {
r := v.c.call(v.c.api.valueBlob, uint64(v.handle)) r := v.call(v.api.valueBlob, uint64(v.handle))
return v.rawBytes(uint32(r)) return v.rawBytes(uint32(r))
} }
func (v *Value) rawBytes(ptr uint32) ([]byte, error) { func (v *Value) rawBytes(ptr uint32) []byte {
if ptr == 0 { if ptr == 0 {
r := v.c.call(v.c.api.errcode, uint64(v.c.handle)) return nil
return nil, v.c.error(r)
} }
r := v.c.call(v.c.api.valueBytes, uint64(v.handle)) r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.c.mod, ptr, r), nil return util.View(v.mod, ptr, r)
} }