// Package sqlite3 wraps the C SQLite API. package sqlite3 import ( "context" "math" "os" "sync" "github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/vfs" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" ) // Configure SQLite WASM. // // Importing package embed initializes these // with an appropriate build of SQLite: // // import _ "github.com/ncruces/go-sqlite3/embed" var ( Binary []byte // WASM binary to load. Path string // Path to load the binary from. RuntimeConfig wazero.RuntimeConfig ) var instance struct { runtime wazero.Runtime compiled wazero.CompiledModule err error once sync.Once } func compileSQLite() { if RuntimeConfig == nil { RuntimeConfig = wazero.NewRuntimeConfig() } ctx := context.Background() instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig) env := instance.runtime.NewHostModuleBuilder("env") env = vfs.ExportHostFunctions(env) env = exportCallbacks(env) _, instance.err = env.Instantiate(ctx) if instance.err != nil { return } bin := Binary if bin == nil && Path != "" { bin, instance.err = os.ReadFile(Path) if instance.err != nil { return } } if bin == nil { instance.err = util.NoBinaryErr return } instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin) } type sqlite struct { ctx context.Context mod api.Module funcs [8]api.Function stack [8]uint64 freer uint32 } func instantiateSQLite() (sqlt *sqlite, err error) { instance.once.Do(compileSQLite) if instance.err != nil { return nil, instance.err } sqlt = new(sqlite) sqlt.ctx = util.NewContext(context.Background()) sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx, instance.compiled, wazero.NewModuleConfig()) if err != nil { return nil, err } global := sqlt.mod.ExportedGlobal("malloc_destructor") if global == nil { return nil, util.BadBinaryErr } sqlt.freer = util.ReadUint32(sqlt.mod, uint32(global.Get())) if err != nil { return nil, err } return sqlt, nil } func (sqlt *sqlite) close() error { return sqlt.mod.Close(sqlt.ctx) } func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { if rc == _OK { return nil } err := Error{code: rc} if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM { panic(util.OOMErr) } if r := sqlt.call("sqlite3_errstr", rc); r != 0 { err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) } if handle != 0 { if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 { err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) } if sql != nil { if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 { err.sql = sql[0][r:] } } } switch err.msg { case err.str, "not an error": err.msg = "" } return &err } func (sqlt *sqlite) getfn(name string) (api.Function, uint32) { // https://cr.yp.to/cdb/cdb.txt hash := func(s string) uint32 { var hash uint32 = 5381 for _, b := range []byte(s) { hash = (hash<<5 + hash) ^ uint32(b) } return hash }(name) % uint32(len(sqlt.funcs)) fn := sqlt.funcs[hash] if fn == nil || name != fn.Definition().Name() { fn = sqlt.mod.ExportedFunction(name) } else { sqlt.funcs[hash] = nil } return fn, hash } func (sqlt *sqlite) call(name string, params ...uint64) uint64 { copy(sqlt.stack[:], params) fn, hash := sqlt.getfn(name) err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:]) if err != nil { panic(err) } sqlt.funcs[hash] = fn return sqlt.stack[0] } func (sqlt *sqlite) free(ptr uint32) { if ptr == 0 { return } sqlt.call("free", uint64(ptr)) } func (sqlt *sqlite) new(size uint64) uint32 { if size > _MAX_ALLOCATION_SIZE { panic(util.OOMErr) } ptr := uint32(sqlt.call("malloc", size)) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } func (sqlt *sqlite) newBytes(b []byte) uint32 { if (*[0]byte)(b) == nil { return 0 } ptr := sqlt.new(uint64(len(b))) util.WriteBytes(sqlt.mod, ptr, b) return ptr } func (sqlt *sqlite) newString(s string) uint32 { ptr := sqlt.new(uint64(len(s) + 1)) util.WriteString(sqlt.mod, ptr, s) return ptr } func (sqlt *sqlite) newArena(size uint64) arena { return arena{ sqlt: sqlt, size: uint32(size), base: sqlt.new(size), } } type arena struct { sqlt *sqlite ptrs []uint32 base uint32 next uint32 size uint32 } func (a *arena) free() { if a.sqlt == nil { return } for _, ptr := range a.ptrs { a.sqlt.free(ptr) } a.sqlt.free(a.base) a.sqlt = nil } func (a *arena) mark() (reset func()) { ptrs := len(a.ptrs) next := a.next return func() { for _, ptr := range a.ptrs[ptrs:] { a.sqlt.free(ptr) } a.ptrs = a.ptrs[:ptrs] a.next = next } } func (a *arena) new(size uint64) uint32 { if size <= uint64(a.size-a.next) { ptr := a.base + a.next a.next += uint32(size) return ptr } ptr := a.sqlt.new(size) a.ptrs = append(a.ptrs, ptr) return ptr } func (a *arena) bytes(b []byte) uint32 { if b == nil { return 0 } ptr := a.new(uint64(len(b))) util.WriteBytes(a.sqlt.mod, ptr, b) return ptr } func (a *arena) string(s string) uint32 { ptr := a.new(uint64(len(s) + 1)) util.WriteString(a.sqlt.mod, ptr, s) return ptr } func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncII(env, "go_progress", progressCallback) util.ExportFuncVI(env, "go_destroy", destroyCallback) util.ExportFuncVIII(env, "go_func", funcCallback) util.ExportFuncVIII(env, "go_step", stepCallback) util.ExportFuncVI(env, "go_final", finalCallback) util.ExportFuncVI(env, "go_value", valueCallback) util.ExportFuncVIII(env, "go_inverse", inverseCallback) util.ExportFuncIIIIII(env, "go_compare", compareCallback) util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(0)) util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(1)) util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback) util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback) util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback) util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback) util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback) util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback) util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback) util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback) util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback) util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback) util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback) util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback) util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback) util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback) util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback) util.ExportFuncII(env, "go_cur_close", cursorCloseCallback) util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback) util.ExportFuncII(env, "go_cur_next", cursorNextCallback) util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback) util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback) util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback) return env }