mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Custom scalar functions.
This commit is contained in:
@@ -63,6 +63,7 @@ Performance is tested by running
|
||||
### Roadmap
|
||||
|
||||
- [ ] advanced SQLite features
|
||||
- [x] custom functions
|
||||
- [x] nested transactions
|
||||
- [x] incremental BLOB I/O
|
||||
- [x] online backup
|
||||
@@ -72,7 +73,6 @@ Performance is tested by running
|
||||
- [x] in-memory VFS
|
||||
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
|
||||
- [ ] custom SQL functions
|
||||
|
||||
### Alternatives
|
||||
|
||||
|
||||
12
const.go
12
const.go
@@ -167,6 +167,18 @@ const (
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
)
|
||||
|
||||
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_deterministic.html
|
||||
type FunctionFlag uint32
|
||||
|
||||
const (
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
SUBTYPE FunctionFlag = 0x000100000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
)
|
||||
|
||||
// Datatype is a fundamental datatype of SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_blob.html
|
||||
|
||||
41
context.go
41
context.go
@@ -12,7 +12,7 @@ import (
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
c *Conn
|
||||
*module
|
||||
handle uint32
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func (c *Context) ResultInt(value int) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultInt64(value int64) {
|
||||
c.c.call(c.c.api.resultInteger,
|
||||
c.call(c.api.resultInteger,
|
||||
uint64(c.handle), uint64(value))
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (c *Context) ResultInt64(value int64) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultFloat(value float64) {
|
||||
c.c.call(c.c.api.resultFloat,
|
||||
c.call(c.api.resultFloat,
|
||||
uint64(c.handle), math.Float64bits(value))
|
||||
}
|
||||
|
||||
@@ -56,10 +56,10 @@ func (c *Context) ResultFloat(value float64) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultText(value string) {
|
||||
ptr := c.c.newString(value)
|
||||
c.c.call(c.c.api.resultText,
|
||||
ptr := c.newString(value)
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.c.api.destructor), _UTF8)
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
@@ -67,17 +67,17 @@ func (c *Context) ResultText(value string) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultBlob(value []byte) {
|
||||
ptr := c.c.newBytes(value)
|
||||
c.c.call(c.c.api.resultBlob,
|
||||
ptr := c.newBytes(value)
|
||||
c.call(c.api.resultBlob,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.c.api.destructor))
|
||||
uint64(c.api.destructor))
|
||||
}
|
||||
|
||||
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultZeroBlob(n int64) {
|
||||
c.c.call(c.c.api.resultZeroBlob,
|
||||
c.call(c.api.resultZeroBlob,
|
||||
uint64(c.handle), uint64(n))
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func (c *Context) ResultZeroBlob(n int64) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultNull() {
|
||||
c.c.call(c.c.api.resultNull,
|
||||
c.call(c.api.resultNull,
|
||||
uint64(c.handle))
|
||||
}
|
||||
|
||||
@@ -112,13 +112,13 @@ func (c *Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
func (c *Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
|
||||
ptr := c.c.new(maxlen)
|
||||
buf := util.View(c.c.mod, ptr, maxlen)
|
||||
ptr := c.new(maxlen)
|
||||
buf := util.View(c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
c.c.call(c.c.api.resultText,
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(c.c.api.destructor), _UTF8)
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
@@ -126,19 +126,20 @@ func (c *Context) resultRFC3339Nano(value time.Time) {
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c *Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
c.c.call(c.c.api.resultErrorMem, uint64(c.handle))
|
||||
c.call(c.api.resultErrorMem, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
c.c.call(c.c.api.resultErrorBig, uint64(c.handle))
|
||||
c.call(c.api.resultErrorBig, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
str := err.Error()
|
||||
ptr := c.c.arena.string(str)
|
||||
c.c.call(c.c.api.resultBlob,
|
||||
ptr := c.newString(str)
|
||||
c.call(c.api.resultBlob,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(str)))
|
||||
c.free(ptr)
|
||||
|
||||
var code uint64
|
||||
var ecode ErrorCode
|
||||
@@ -150,7 +151,7 @@ func (c *Context) ResultError(err error) {
|
||||
code = uint64(ecode)
|
||||
}
|
||||
if code != 0 {
|
||||
c.c.call(c.c.api.resultErrorCode,
|
||||
c.call(c.api.resultErrorCode,
|
||||
uint64(c.handle), uint64(xcode))
|
||||
}
|
||||
}
|
||||
|
||||
38
func.go
38
func.go
@@ -23,6 +23,18 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFunction defines a new scalar function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createFunction,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncVI(env, "go_destroy", cbDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", cbCompare)
|
||||
@@ -34,16 +46,32 @@ func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
|
||||
return env
|
||||
}
|
||||
|
||||
func cbDestroy(ctx context.Context, mod api.Module, pArg uint32) {
|
||||
util.DelHandle(ctx, pArg)
|
||||
func cbDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
|
||||
func cbCompare(ctx context.Context, mod api.Module, pArg, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pArg).(func(a, b []byte) int)
|
||||
func cbCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
||||
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
|
||||
}
|
||||
|
||||
func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}
|
||||
func cbFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
module := ctx.Value(moduleKey{}).(*module)
|
||||
pApp := uint32(module.call(module.api.userData, uint64(pCtx)))
|
||||
args := make([]Value, nArg)
|
||||
for i := range args {
|
||||
args[i] = Value{
|
||||
handle: util.ReadUint32(mod, pArg+ptrlen*uint32(i)),
|
||||
module: module,
|
||||
}
|
||||
}
|
||||
context := Context{
|
||||
handle: pCtx,
|
||||
module: module,
|
||||
}
|
||||
fn := util.GetHandle(ctx, pApp).(func(ctx Context, arg ...Value))
|
||||
fn(context, args...)
|
||||
}
|
||||
|
||||
func cbStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {}
|
||||
|
||||
|
||||
55
func_test.go
55
func_test.go
@@ -1,6 +1,7 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
@@ -62,3 +63,57 @@ func ExampleConn_CreateCollation() {
|
||||
// cotée
|
||||
// coter
|
||||
}
|
||||
|
||||
func ExampleConn_CreateFunction() {
|
||||
db, err := sqlite3.Open(memory)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// COTE
|
||||
// COTÉ
|
||||
// CÔTE
|
||||
// CÔTÉ
|
||||
// COTÉE
|
||||
// COTER
|
||||
}
|
||||
|
||||
@@ -84,10 +84,13 @@ type module struct {
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
type moduleKey struct{}
|
||||
|
||||
func newModule(mod api.Module) (m *module, err error) {
|
||||
m = new(module)
|
||||
m.mod = mod
|
||||
m.ctx, m.closer = util.NewContext(context.Background())
|
||||
m.ctx = context.WithValue(m.ctx, moduleKey{}, m)
|
||||
m.mod = mod
|
||||
|
||||
getFun := func(name string) api.Function {
|
||||
f := mod.ExportedFunction(name)
|
||||
@@ -156,6 +159,8 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
createCollation: getFun("sqlite3_create_go_collation"),
|
||||
createFunction: getFun("sqlite3_create_go_function"),
|
||||
userData: getFun("sqlite3_user_data"),
|
||||
valueType: getFun("sqlite3_value_type"),
|
||||
valueInteger: getFun("sqlite3_value_int64"),
|
||||
valueFloat: getFun("sqlite3_value_double"),
|
||||
@@ -368,6 +373,8 @@ type sqliteAPI struct {
|
||||
lastRowid api.Function
|
||||
autocommit api.Function
|
||||
createCollation api.Function
|
||||
createFunction api.Function
|
||||
userData api.Function
|
||||
valueType api.Function
|
||||
valueInteger api.Function
|
||||
valueFloat api.Function
|
||||
|
||||
54
value.go
54
value.go
@@ -11,7 +11,7 @@ import (
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value.html
|
||||
type Value struct {
|
||||
c *Conn
|
||||
*module
|
||||
handle uint32
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Value struct {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) Type() Datatype {
|
||||
r := v.c.call(v.c.api.valueType, uint64(v.handle))
|
||||
r := v.call(v.api.valueType, uint64(v.handle))
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (v *Value) Int() int {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) Int64() int64 {
|
||||
r := v.c.call(v.c.api.valueInteger, uint64(v.handle))
|
||||
r := v.call(v.api.valueInteger, uint64(v.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
@@ -55,49 +55,44 @@ func (v *Value) Int64() int64 {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) Float() float64 {
|
||||
r := v.c.call(v.c.api.valueFloat, uint64(v.handle))
|
||||
r := v.call(v.api.valueFloat, uint64(v.handle))
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// Time returns the value as a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) Time(format TimeFormat) (time.Time, error) {
|
||||
var t any
|
||||
var err error
|
||||
func (v *Value) Time(format TimeFormat) time.Time {
|
||||
var a any
|
||||
switch v.Type() {
|
||||
case INTEGER:
|
||||
t = v.Int64()
|
||||
a = v.Int64()
|
||||
case FLOAT:
|
||||
t = v.Float()
|
||||
a = v.Float()
|
||||
case TEXT, BLOB:
|
||||
t, err = v.Text()
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
a = v.Text()
|
||||
case NULL:
|
||||
return time.Time{}, nil
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
return format.Decode(t)
|
||||
t, _ := format.Decode(a)
|
||||
return t
|
||||
}
|
||||
|
||||
// Text returns the value as a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) Text() (string, error) {
|
||||
r, err := v.RawText()
|
||||
return string(r), err
|
||||
func (v *Value) Text() string {
|
||||
return string(v.RawText())
|
||||
}
|
||||
|
||||
// Blob appends to buf and returns
|
||||
// the value as a []byte.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) Blob(buf []byte) ([]byte, error) {
|
||||
r, err := v.RawBlob()
|
||||
return append(buf, r...), err
|
||||
func (v *Value) Blob(buf []byte) []byte {
|
||||
return append(buf, v.RawBlob()...)
|
||||
}
|
||||
|
||||
// RawText returns the value as a []byte.
|
||||
@@ -105,8 +100,8 @@ func (v *Value) Blob(buf []byte) ([]byte, error) {
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) RawText() ([]byte, error) {
|
||||
r := v.c.call(v.c.api.valueText, uint64(v.handle))
|
||||
func (v *Value) RawText() []byte {
|
||||
r := v.call(v.api.valueText, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
@@ -115,17 +110,16 @@ func (v *Value) RawText() ([]byte, error) {
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v *Value) RawBlob() ([]byte, error) {
|
||||
r := v.c.call(v.c.api.valueBlob, uint64(v.handle))
|
||||
func (v *Value) RawBlob() []byte {
|
||||
r := v.call(v.api.valueBlob, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
func (v *Value) rawBytes(ptr uint32) ([]byte, error) {
|
||||
func (v *Value) rawBytes(ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
r := v.c.call(v.c.api.errcode, uint64(v.c.handle))
|
||||
return nil, v.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r := v.c.call(v.c.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.c.mod, ptr, r), nil
|
||||
r := v.call(v.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user