Custom scalar functions.

This commit is contained in:
Nuno Cruces
2023-07-01 00:15:28 +01:00
parent 31572e6095
commit fec1f8d32a
7 changed files with 154 additions and 57 deletions

View File

@@ -63,6 +63,7 @@ Performance is tested by running
### Roadmap
- [ ] advanced SQLite features
- [x] custom functions
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
@@ -72,7 +73,6 @@ Performance is tested by running
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom SQL functions
### Alternatives

View File

@@ -167,6 +167,18 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
SUBTYPE FunctionFlag = 0x000100000
INNOCUOUS FunctionFlag = 0x000200000
)
// Datatype is a fundamental datatype of SQLite.
//
// https://www.sqlite.org/c3ref/c_blob.html

View File

@@ -12,7 +12,7 @@ import (
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
c *Conn
*module
handle uint32
}
@@ -40,7 +40,7 @@ func (c *Context) ResultInt(value int) {
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultInt64(value int64) {
c.c.call(c.c.api.resultInteger,
c.call(c.api.resultInteger,
uint64(c.handle), uint64(value))
}
@@ -48,7 +48,7 @@ func (c *Context) ResultInt64(value int64) {
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultFloat(value float64) {
c.c.call(c.c.api.resultFloat,
c.call(c.api.resultFloat,
uint64(c.handle), math.Float64bits(value))
}
@@ -56,10 +56,10 @@ func (c *Context) ResultFloat(value float64) {
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultText(value string) {
ptr := c.c.newString(value)
c.c.call(c.c.api.resultText,
ptr := c.newString(value)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.c.api.destructor), _UTF8)
uint64(c.api.destructor), _UTF8)
}
// ResultBlob sets the result of the function to a []byte.
@@ -67,17 +67,17 @@ func (c *Context) ResultText(value string) {
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultBlob(value []byte) {
ptr := c.c.newBytes(value)
c.c.call(c.c.api.resultBlob,
ptr := c.newBytes(value)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(value)),
uint64(c.c.api.destructor))
uint64(c.api.destructor))
}
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultZeroBlob(n int64) {
c.c.call(c.c.api.resultZeroBlob,
c.call(c.api.resultZeroBlob,
uint64(c.handle), uint64(n))
}
@@ -85,7 +85,7 @@ func (c *Context) ResultZeroBlob(n int64) {
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultNull() {
c.c.call(c.c.api.resultNull,
c.call(c.api.resultNull,
uint64(c.handle))
}
@@ -112,13 +112,13 @@ func (c *Context) ResultTime(value time.Time, format TimeFormat) {
func (c *Context) resultRFC3339Nano(value time.Time) {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := c.c.new(maxlen)
buf := util.View(c.c.mod, ptr, maxlen)
ptr := c.new(maxlen)
buf := util.View(c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
c.c.call(c.c.api.resultText,
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(buf)),
uint64(c.c.api.destructor), _UTF8)
uint64(c.api.destructor), _UTF8)
}
// ResultError sets the result of the function an error.
@@ -126,19 +126,20 @@ func (c *Context) resultRFC3339Nano(value time.Time) {
// https://www.sqlite.org/c3ref/result_blob.html
func (c *Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
c.c.call(c.c.api.resultErrorMem, uint64(c.handle))
c.call(c.api.resultErrorMem, uint64(c.handle))
return
}
if errors.Is(err, TOOBIG) {
c.c.call(c.c.api.resultErrorBig, uint64(c.handle))
c.call(c.api.resultErrorBig, uint64(c.handle))
return
}
str := err.Error()
ptr := c.c.arena.string(str)
c.c.call(c.c.api.resultBlob,
ptr := c.newString(str)
c.call(c.api.resultBlob,
uint64(c.handle), uint64(ptr), uint64(len(str)))
c.free(ptr)
var code uint64
var ecode ErrorCode
@@ -150,7 +151,7 @@ func (c *Context) ResultError(err error) {
code = uint64(ecode)
}
if code != 0 {
c.c.call(c.c.api.resultErrorCode,
c.call(c.api.resultErrorCode,
uint64(c.handle), uint64(xcode))
}
}

38
func.go
View File

@@ -23,6 +23,18 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
return nil
}
// CreateFunction defines a new scalar function.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
r := c.call(c.api.createFunction,
uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr))
return c.error(r)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", cbDestroy)
util.ExportFuncIIIIII(env, "go_compare", cbCompare)
@@ -34,16 +46,32 @@ func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
return env
}
func cbDestroy(ctx context.Context, mod api.Module, pArg uint32) {
util.DelHandle(ctx, pArg)
func cbDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
func cbCompare(ctx context.Context, mod api.Module, pArg, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pArg).(func(a, b []byte) int)
func cbCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
}
func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}
func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
module := ctx.Value(moduleKey{}).(*module)
pApp := uint32(module.call(module.api.userData, uint64(pCtx)))
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
handle: util.ReadUint32(mod, pArg+ptrlen*uint32(i)),
module: module,
}
}
context := Context{
handle: pCtx,
module: module,
}
fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value))
fn(context, args...)
}
func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}

View File

@@ -1,6 +1,7 @@
package sqlite3_test
import (
"bytes"
"fmt"
"log"
@@ -62,3 +63,57 @@ func ExampleConn_CreateCollation() {
// cotée
// coter
}
func ExampleConn_CreateFunction() {
db, err := sqlite3.Open(memory)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
if err != nil {
log.Fatal(err)
}
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
})
if err != nil {
log.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnText(0))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
err = stmt.Close()
if err != nil {
log.Fatal(err)
}
err = db.Close()
if err != nil {
log.Fatal(err)
}
// Unordered output:
// COTE
// COTÉ
// CÔTE
// CÔTÉ
// COTÉE
// COTER
}

View File

@@ -84,10 +84,13 @@ type module struct {
stack [8]uint64
}
type moduleKey struct{}
func newModule(mod api.Module) (m *module, err error) {
m = new(module)
m.mod = mod
m.ctx, m.closer = util.NewContext(context.Background())
m.ctx = context.WithValue(m.ctx, moduleKey{}, m)
m.mod = mod
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
@@ -156,6 +159,8 @@ func newModule(mod api.Module) (m *module, err error) {
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
createCollation: getFun("sqlite3_create_go_collation"),
createFunction: getFun("sqlite3_create_go_function"),
userData: getFun("sqlite3_user_data"),
valueType: getFun("sqlite3_value_type"),
valueInteger: getFun("sqlite3_value_int64"),
valueFloat: getFun("sqlite3_value_double"),
@@ -368,6 +373,8 @@ type sqliteAPI struct {
lastRowid api.Function
autocommit api.Function
createCollation api.Function
createFunction api.Function
userData api.Function
valueType api.Function
valueInteger api.Function
valueFloat api.Function

View File

@@ -11,7 +11,7 @@ import (
//
// https://www.sqlite.org/c3ref/value.html
type Value struct {
c *Conn
*module
handle uint32
}
@@ -19,7 +19,7 @@ type Value struct {
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Type() Datatype {
r := v.c.call(v.c.api.valueType, uint64(v.handle))
r := v.call(v.api.valueType, uint64(v.handle))
return Datatype(r)
}
@@ -47,7 +47,7 @@ func (v *Value) Int() int {
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Int64() int64 {
r := v.c.call(v.c.api.valueInteger, uint64(v.handle))
r := v.call(v.api.valueInteger, uint64(v.handle))
return int64(r)
}
@@ -55,49 +55,44 @@ func (v *Value) Int64() int64 {
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Float() float64 {
r := v.c.call(v.c.api.valueFloat, uint64(v.handle))
r := v.call(v.api.valueFloat, uint64(v.handle))
return math.Float64frombits(r)
}
// Time returns the value as a [time.Time].
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Time(format TimeFormat) (time.Time, error) {
var t any
var err error
func (v *Value) Time(format TimeFormat) time.Time {
var a any
switch v.Type() {
case INTEGER:
t = v.Int64()
a = v.Int64()
case FLOAT:
t = v.Float()
a = v.Float()
case TEXT, BLOB:
t, err = v.Text()
if err != nil {
return time.Time{}, err
}
a = v.Text()
case NULL:
return time.Time{}, nil
return time.Time{}
default:
panic(util.AssertErr())
}
return format.Decode(t)
t, _ := format.Decode(a)
return t
}
// Text returns the value as a string.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Text() (string, error) {
r, err := v.RawText()
return string(r), err
func (v *Value) Text() string {
return string(v.RawText())
}
// Blob appends to buf and returns
// the value as a []byte.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) Blob(buf []byte) ([]byte, error) {
r, err := v.RawBlob()
return append(buf, r...), err
func (v *Value) Blob(buf []byte) []byte {
return append(buf, v.RawBlob()...)
}
// RawText returns the value as a []byte.
@@ -105,8 +100,8 @@ func (v *Value) Blob(buf []byte) ([]byte, error) {
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) RawText() ([]byte, error) {
r := v.c.call(v.c.api.valueText, uint64(v.handle))
func (v *Value) RawText() []byte {
r := v.call(v.api.valueText, uint64(v.handle))
return v.rawBytes(uint32(r))
}
@@ -115,17 +110,16 @@ func (v *Value) RawText() ([]byte, error) {
// subsequent calls to [Value] methods.
//
// https://www.sqlite.org/c3ref/value_blob.html
func (v *Value) RawBlob() ([]byte, error) {
r := v.c.call(v.c.api.valueBlob, uint64(v.handle))
func (v *Value) RawBlob() []byte {
r := v.call(v.api.valueBlob, uint64(v.handle))
return v.rawBytes(uint32(r))
}
func (v *Value) rawBytes(ptr uint32) ([]byte, error) {
func (v *Value) rawBytes(ptr uint32) []byte {
if ptr == 0 {
r := v.c.call(v.c.api.errcode, uint64(v.c.handle))
return nil, v.c.error(r)
return nil
}
r := v.c.call(v.c.api.valueBytes, uint64(v.handle))
return util.View(v.c.mod, ptr, r), nil
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}