mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
451 lines
13 KiB
Go
451 lines
13 KiB
Go
// Package sqlite3 wraps the C SQLite API.
|
|
package sqlite3
|
|
|
|
import (
|
|
"context"
|
|
"math"
|
|
"os"
|
|
"sync"
|
|
|
|
"github.com/ncruces/go-sqlite3/internal/util"
|
|
"github.com/ncruces/go-sqlite3/vfs"
|
|
"github.com/tetratelabs/wazero"
|
|
"github.com/tetratelabs/wazero/api"
|
|
)
|
|
|
|
// Configure SQLite WASM.
|
|
//
|
|
// Importing package embed initializes these
|
|
// with an appropriate build of SQLite:
|
|
//
|
|
// import _ "github.com/ncruces/go-sqlite3/embed"
|
|
var (
|
|
Binary []byte // WASM binary to load.
|
|
Path string // Path to load the binary from.
|
|
|
|
RuntimeConfig wazero.RuntimeConfig
|
|
)
|
|
|
|
var instance struct {
|
|
runtime wazero.Runtime
|
|
compiled wazero.CompiledModule
|
|
err error
|
|
once sync.Once
|
|
}
|
|
|
|
func compileSQLite() {
|
|
if RuntimeConfig == nil {
|
|
RuntimeConfig = wazero.NewRuntimeConfig()
|
|
}
|
|
|
|
ctx := context.Background()
|
|
instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
|
|
|
|
env := instance.runtime.NewHostModuleBuilder("env")
|
|
env = vfs.ExportHostFunctions(env)
|
|
env = exportCallbacks(env)
|
|
_, instance.err = env.Instantiate(ctx)
|
|
if instance.err != nil {
|
|
return
|
|
}
|
|
|
|
bin := Binary
|
|
if bin == nil && Path != "" {
|
|
bin, instance.err = os.ReadFile(Path)
|
|
if instance.err != nil {
|
|
return
|
|
}
|
|
}
|
|
if bin == nil {
|
|
instance.err = util.BinaryErr
|
|
return
|
|
}
|
|
|
|
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
|
|
}
|
|
|
|
type sqlite struct {
|
|
ctx context.Context
|
|
mod api.Module
|
|
api sqliteAPI
|
|
stack [8]uint64
|
|
}
|
|
|
|
func instantiateSQLite() (sqlt *sqlite, err error) {
|
|
instance.once.Do(compileSQLite)
|
|
if instance.err != nil {
|
|
return nil, instance.err
|
|
}
|
|
|
|
sqlt = new(sqlite)
|
|
sqlt.ctx = util.NewContext(context.Background())
|
|
|
|
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
|
|
instance.compiled, wazero.NewModuleConfig())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
getFun := func(name string) api.Function {
|
|
f := sqlt.mod.ExportedFunction(name)
|
|
if f == nil {
|
|
err = util.NoFuncErr + util.ErrorString(name)
|
|
return nil
|
|
}
|
|
return f
|
|
}
|
|
|
|
getVal := func(name string) uint32 {
|
|
g := sqlt.mod.ExportedGlobal(name)
|
|
if g == nil {
|
|
err = util.NoGlobalErr + util.ErrorString(name)
|
|
return 0
|
|
}
|
|
return util.ReadUint32(sqlt.mod, uint32(g.Get()))
|
|
}
|
|
|
|
sqlt.api = sqliteAPI{
|
|
free: getFun("free"),
|
|
malloc: getFun("malloc"),
|
|
destructor: getVal("malloc_destructor"),
|
|
errcode: getFun("sqlite3_errcode"),
|
|
errstr: getFun("sqlite3_errstr"),
|
|
errmsg: getFun("sqlite3_errmsg"),
|
|
erroff: getFun("sqlite3_error_offset"),
|
|
open: getFun("sqlite3_open_v2"),
|
|
close: getFun("sqlite3_close"),
|
|
closeZombie: getFun("sqlite3_close_v2"),
|
|
prepare: getFun("sqlite3_prepare_v3"),
|
|
finalize: getFun("sqlite3_finalize"),
|
|
reset: getFun("sqlite3_reset"),
|
|
step: getFun("sqlite3_step"),
|
|
exec: getFun("sqlite3_exec"),
|
|
interrupt: getFun("sqlite3_interrupt"),
|
|
progressHandler: getFun("sqlite3_progress_handler_go"),
|
|
clearBindings: getFun("sqlite3_clear_bindings"),
|
|
bindCount: getFun("sqlite3_bind_parameter_count"),
|
|
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
|
bindName: getFun("sqlite3_bind_parameter_name"),
|
|
bindNull: getFun("sqlite3_bind_null"),
|
|
bindInteger: getFun("sqlite3_bind_int64"),
|
|
bindFloat: getFun("sqlite3_bind_double"),
|
|
bindText: getFun("sqlite3_bind_text64"),
|
|
bindBlob: getFun("sqlite3_bind_blob64"),
|
|
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
|
bindPointer: getFun("sqlite3_bind_pointer_go"),
|
|
columnCount: getFun("sqlite3_column_count"),
|
|
columnName: getFun("sqlite3_column_name"),
|
|
columnType: getFun("sqlite3_column_type"),
|
|
columnInteger: getFun("sqlite3_column_int64"),
|
|
columnFloat: getFun("sqlite3_column_double"),
|
|
columnText: getFun("sqlite3_column_text"),
|
|
columnBlob: getFun("sqlite3_column_blob"),
|
|
columnBytes: getFun("sqlite3_column_bytes"),
|
|
blobOpen: getFun("sqlite3_blob_open"),
|
|
blobClose: getFun("sqlite3_blob_close"),
|
|
blobReopen: getFun("sqlite3_blob_reopen"),
|
|
blobBytes: getFun("sqlite3_blob_bytes"),
|
|
blobRead: getFun("sqlite3_blob_read"),
|
|
blobWrite: getFun("sqlite3_blob_write"),
|
|
backupInit: getFun("sqlite3_backup_init"),
|
|
backupStep: getFun("sqlite3_backup_step"),
|
|
backupFinish: getFun("sqlite3_backup_finish"),
|
|
backupRemaining: getFun("sqlite3_backup_remaining"),
|
|
backupPageCount: getFun("sqlite3_backup_pagecount"),
|
|
changes: getFun("sqlite3_changes64"),
|
|
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
|
autocommit: getFun("sqlite3_get_autocommit"),
|
|
anyCollation: getFun("sqlite3_anycollseq_init"),
|
|
createCollation: getFun("sqlite3_create_collation_go"),
|
|
createFunction: getFun("sqlite3_create_function_go"),
|
|
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
|
|
createWindow: getFun("sqlite3_create_window_function_go"),
|
|
aggregateCtx: getFun("sqlite3_aggregate_context"),
|
|
userData: getFun("sqlite3_user_data"),
|
|
setAuxData: getFun("sqlite3_set_auxdata_go"),
|
|
getAuxData: getFun("sqlite3_get_auxdata"),
|
|
valueType: getFun("sqlite3_value_type"),
|
|
valueInteger: getFun("sqlite3_value_int64"),
|
|
valueFloat: getFun("sqlite3_value_double"),
|
|
valueText: getFun("sqlite3_value_text"),
|
|
valueBlob: getFun("sqlite3_value_blob"),
|
|
valueBytes: getFun("sqlite3_value_bytes"),
|
|
valuePointer: getFun("sqlite3_value_pointer_go"),
|
|
resultNull: getFun("sqlite3_result_null"),
|
|
resultInteger: getFun("sqlite3_result_int64"),
|
|
resultFloat: getFun("sqlite3_result_double"),
|
|
resultText: getFun("sqlite3_result_text64"),
|
|
resultBlob: getFun("sqlite3_result_blob64"),
|
|
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
|
resultPointer: getFun("sqlite3_result_pointer_go"),
|
|
resultValue: getFun("sqlite3_result_value"),
|
|
resultError: getFun("sqlite3_result_error"),
|
|
resultErrorCode: getFun("sqlite3_result_error_code"),
|
|
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
|
resultErrorBig: getFun("sqlite3_result_error_toobig"),
|
|
createModule: getFun("sqlite3_create_module_go"),
|
|
declareVTab: getFun("sqlite3_declare_vtab"),
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return sqlt, nil
|
|
}
|
|
|
|
func (sqlt *sqlite) close() error {
|
|
return sqlt.mod.Close(sqlt.ctx)
|
|
}
|
|
|
|
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
|
if rc == _OK {
|
|
return nil
|
|
}
|
|
|
|
err := Error{code: rc}
|
|
|
|
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
|
|
panic(util.OOMErr)
|
|
}
|
|
|
|
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
|
|
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
|
}
|
|
|
|
if handle != 0 {
|
|
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
|
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
|
}
|
|
|
|
if sql != nil {
|
|
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
|
err.sql = sql[0][r:]
|
|
}
|
|
}
|
|
}
|
|
|
|
switch err.msg {
|
|
case err.str, "not an error":
|
|
err.msg = ""
|
|
}
|
|
return &err
|
|
}
|
|
|
|
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
|
|
copy(sqlt.stack[:], params)
|
|
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return sqlt.stack[0]
|
|
}
|
|
|
|
func (sqlt *sqlite) free(ptr uint32) {
|
|
if ptr == 0 {
|
|
return
|
|
}
|
|
sqlt.call(sqlt.api.free, uint64(ptr))
|
|
}
|
|
|
|
func (sqlt *sqlite) new(size uint64) uint32 {
|
|
if size > _MAX_ALLOCATION_SIZE {
|
|
panic(util.OOMErr)
|
|
}
|
|
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
|
|
if ptr == 0 && size != 0 {
|
|
panic(util.OOMErr)
|
|
}
|
|
return ptr
|
|
}
|
|
|
|
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
|
if (*[0]byte)(b) == nil {
|
|
return 0
|
|
}
|
|
ptr := sqlt.new(uint64(len(b)))
|
|
util.WriteBytes(sqlt.mod, ptr, b)
|
|
return ptr
|
|
}
|
|
|
|
func (sqlt *sqlite) newString(s string) uint32 {
|
|
ptr := sqlt.new(uint64(len(s) + 1))
|
|
util.WriteString(sqlt.mod, ptr, s)
|
|
return ptr
|
|
}
|
|
|
|
func (sqlt *sqlite) newArena(size uint64) arena {
|
|
return arena{
|
|
sqlt: sqlt,
|
|
size: uint32(size),
|
|
base: sqlt.new(size),
|
|
}
|
|
}
|
|
|
|
type arena struct {
|
|
sqlt *sqlite
|
|
ptrs []uint32
|
|
base uint32
|
|
next uint32
|
|
size uint32
|
|
}
|
|
|
|
func (a *arena) free() {
|
|
if a.sqlt == nil {
|
|
return
|
|
}
|
|
a.reset()
|
|
a.sqlt.free(a.base)
|
|
a.sqlt = nil
|
|
}
|
|
|
|
func (a *arena) reset() {
|
|
for _, ptr := range a.ptrs {
|
|
a.sqlt.free(ptr)
|
|
}
|
|
a.ptrs = nil
|
|
a.next = 0
|
|
}
|
|
|
|
func (a *arena) new(size uint64) uint32 {
|
|
if size <= uint64(a.size-a.next) {
|
|
ptr := a.base + a.next
|
|
a.next += uint32(size)
|
|
return ptr
|
|
}
|
|
ptr := a.sqlt.new(size)
|
|
a.ptrs = append(a.ptrs, ptr)
|
|
return ptr
|
|
}
|
|
|
|
func (a *arena) bytes(b []byte) uint32 {
|
|
if b == nil {
|
|
return 0
|
|
}
|
|
ptr := a.new(uint64(len(b)))
|
|
util.WriteBytes(a.sqlt.mod, ptr, b)
|
|
return ptr
|
|
}
|
|
|
|
func (a *arena) string(s string) uint32 {
|
|
ptr := a.new(uint64(len(s) + 1))
|
|
util.WriteString(a.sqlt.mod, ptr, s)
|
|
return ptr
|
|
}
|
|
|
|
type sqliteAPI struct {
|
|
free api.Function
|
|
malloc api.Function
|
|
errcode api.Function
|
|
errstr api.Function
|
|
errmsg api.Function
|
|
erroff api.Function
|
|
open api.Function
|
|
close api.Function
|
|
closeZombie api.Function
|
|
prepare api.Function
|
|
finalize api.Function
|
|
reset api.Function
|
|
step api.Function
|
|
exec api.Function
|
|
interrupt api.Function
|
|
progressHandler api.Function
|
|
clearBindings api.Function
|
|
bindCount api.Function
|
|
bindIndex api.Function
|
|
bindName api.Function
|
|
bindNull api.Function
|
|
bindInteger api.Function
|
|
bindFloat api.Function
|
|
bindText api.Function
|
|
bindBlob api.Function
|
|
bindZeroBlob api.Function
|
|
bindPointer api.Function
|
|
columnCount api.Function
|
|
columnName api.Function
|
|
columnType api.Function
|
|
columnInteger api.Function
|
|
columnFloat api.Function
|
|
columnText api.Function
|
|
columnBlob api.Function
|
|
columnBytes api.Function
|
|
blobOpen api.Function
|
|
blobClose api.Function
|
|
blobReopen api.Function
|
|
blobBytes api.Function
|
|
blobRead api.Function
|
|
blobWrite api.Function
|
|
backupInit api.Function
|
|
backupStep api.Function
|
|
backupFinish api.Function
|
|
backupRemaining api.Function
|
|
backupPageCount api.Function
|
|
changes api.Function
|
|
lastRowid api.Function
|
|
autocommit api.Function
|
|
anyCollation api.Function
|
|
createCollation api.Function
|
|
createFunction api.Function
|
|
createAggregate api.Function
|
|
createWindow api.Function
|
|
aggregateCtx api.Function
|
|
userData api.Function
|
|
setAuxData api.Function
|
|
getAuxData api.Function
|
|
valueType api.Function
|
|
valueInteger api.Function
|
|
valueFloat api.Function
|
|
valueText api.Function
|
|
valueBlob api.Function
|
|
valueBytes api.Function
|
|
valuePointer api.Function
|
|
resultNull api.Function
|
|
resultInteger api.Function
|
|
resultFloat api.Function
|
|
resultText api.Function
|
|
resultBlob api.Function
|
|
resultZeroBlob api.Function
|
|
resultPointer api.Function
|
|
resultValue api.Function
|
|
resultError api.Function
|
|
resultErrorCode api.Function
|
|
resultErrorMem api.Function
|
|
resultErrorBig api.Function
|
|
createModule api.Function
|
|
declareVTab api.Function
|
|
destructor uint32
|
|
}
|
|
|
|
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
|
util.ExportFuncII(env, "go_progress", progressCallback)
|
|
util.ExportFuncVI(env, "go_destroy", destroyCallback)
|
|
util.ExportFuncVIII(env, "go_func", funcCallback)
|
|
util.ExportFuncVIII(env, "go_step", stepCallback)
|
|
util.ExportFuncVI(env, "go_final", finalCallback)
|
|
util.ExportFuncVI(env, "go_value", valueCallback)
|
|
util.ExportFuncVIII(env, "go_inverse", inverseCallback)
|
|
util.ExportFuncIIIIII(env, "go_compare", compareCallback)
|
|
util.ExportFuncIIIIII(env, "go_vtab_create", vtabReflectCallback("Create"))
|
|
util.ExportFuncIIIIII(env, "go_vtab_connect", vtabReflectCallback("Connect"))
|
|
util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
|
|
util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
|
|
util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
|
|
util.ExportFuncIIIII(env, "go_vtab_update", vtabCallbackIIII)
|
|
util.ExportFuncIII(env, "go_vtab_rename", vtabCallbackII)
|
|
util.ExportFuncIIIII(env, "go_vtab_find_function", vtabCallbackIIII)
|
|
util.ExportFuncII(env, "go_vtab_begin", vtabCallbackI)
|
|
util.ExportFuncII(env, "go_vtab_sync", vtabCallbackI)
|
|
util.ExportFuncII(env, "go_vtab_commit", vtabCallbackI)
|
|
util.ExportFuncII(env, "go_vtab_rollback", vtabCallbackI)
|
|
util.ExportFuncIII(env, "go_vtab_savepoint", vtabCallbackII)
|
|
util.ExportFuncIII(env, "go_vtab_release", vtabCallbackII)
|
|
util.ExportFuncIII(env, "go_vtab_rollback_to", vtabCallbackII)
|
|
util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
|
|
util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
|
|
util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
|
|
util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
|
|
util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
|
|
util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
|
|
util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
|
|
util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
|
|
return env
|
|
}
|