// Package sqlite3 wraps the C SQLite API. package sqlite3 import ( "context" "io" "math" "os" "sync" "github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/sqlite3vfs" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" ) // Configure SQLite WASM. // // Importing package embed initializes these // with an appropriate build of SQLite: // // import _ "github.com/ncruces/go-sqlite3/embed" var ( Binary []byte // WASM binary to load. Path string // Path to load the binary from. ) var sqlite3 struct { runtime wazero.Runtime compiled wazero.CompiledModule err error once sync.Once } func instantiateModule() (*module, error) { ctx := context.Background() sqlite3.once.Do(compileModule) if sqlite3.err != nil { return nil, sqlite3.err } cfg := wazero.NewModuleConfig() mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg) if err != nil { return nil, err } return newModule(mod) } func compileModule() { ctx := context.Background() sqlite3.runtime = wazero.NewRuntime(ctx) env := sqlite3vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env")) _, sqlite3.err = env.Instantiate(ctx) if sqlite3.err != nil { return } bin := Binary if bin == nil && Path != "" { bin, sqlite3.err = os.ReadFile(Path) if sqlite3.err != nil { return } } if bin == nil { sqlite3.err = util.BinaryErr return } sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin) } type module struct { ctx context.Context mod api.Module vfs io.Closer api sqliteAPI arg [8]uint64 } func newModule(mod api.Module) (m *module, err error) { m = &module{} m.mod = mod m.ctx, m.vfs = sqlite3vfs.NewContext(context.Background()) getFun := func(name string) api.Function { f := mod.ExportedFunction(name) if f == nil { err = util.NoFuncErr + util.ErrorString(name) return nil } return f } getVal := func(name string) uint32 { g := mod.ExportedGlobal(name) if g == nil { err = util.NoGlobalErr + util.ErrorString(name) return 0 } return util.ReadUint32(mod, uint32(g.Get())) } m.api = sqliteAPI{ free: getFun("free"), malloc: getFun("malloc"), destructor: getVal("malloc_destructor"), errcode: getFun("sqlite3_errcode"), errstr: getFun("sqlite3_errstr"), errmsg: getFun("sqlite3_errmsg"), erroff: getFun("sqlite3_error_offset"), open: getFun("sqlite3_open_v2"), close: getFun("sqlite3_close"), closeZombie: getFun("sqlite3_close_v2"), prepare: getFun("sqlite3_prepare_v3"), finalize: getFun("sqlite3_finalize"), reset: getFun("sqlite3_reset"), step: getFun("sqlite3_step"), exec: getFun("sqlite3_exec"), clearBindings: getFun("sqlite3_clear_bindings"), bindCount: getFun("sqlite3_bind_parameter_count"), bindIndex: getFun("sqlite3_bind_parameter_index"), bindName: getFun("sqlite3_bind_parameter_name"), bindNull: getFun("sqlite3_bind_null"), bindInteger: getFun("sqlite3_bind_int64"), bindFloat: getFun("sqlite3_bind_double"), bindText: getFun("sqlite3_bind_text64"), bindBlob: getFun("sqlite3_bind_blob64"), bindZeroBlob: getFun("sqlite3_bind_zeroblob64"), columnCount: getFun("sqlite3_column_count"), columnName: getFun("sqlite3_column_name"), columnType: getFun("sqlite3_column_type"), columnInteger: getFun("sqlite3_column_int64"), columnFloat: getFun("sqlite3_column_double"), columnText: getFun("sqlite3_column_text"), columnBlob: getFun("sqlite3_column_blob"), columnBytes: getFun("sqlite3_column_bytes"), autocommit: getFun("sqlite3_get_autocommit"), lastRowid: getFun("sqlite3_last_insert_rowid"), changes: getFun("sqlite3_changes64"), blobOpen: getFun("sqlite3_blob_open"), blobClose: getFun("sqlite3_blob_close"), blobReopen: getFun("sqlite3_blob_reopen"), blobBytes: getFun("sqlite3_blob_bytes"), blobRead: getFun("sqlite3_blob_read"), blobWrite: getFun("sqlite3_blob_write"), backupInit: getFun("sqlite3_backup_init"), backupStep: getFun("sqlite3_backup_step"), backupFinish: getFun("sqlite3_backup_finish"), backupRemaining: getFun("sqlite3_backup_remaining"), backupPageCount: getFun("sqlite3_backup_pagecount"), } if err != nil { return nil, err } return m, nil } func (m *module) close() error { err := m.mod.Close(m.ctx) m.vfs.Close() return err } func (m *module) error(rc uint64, handle uint32, sql ...string) error { if rc == _OK { return nil } err := Error{code: rc} if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM { panic(util.OOMErr) } var r []uint64 r = m.call(m.api.errstr, rc) if r != nil { err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING) } r = m.call(m.api.errmsg, uint64(handle)) if r != nil { err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING) } if sql != nil { r = m.call(m.api.erroff, uint64(handle)) if r != nil && r[0] != math.MaxUint32 { err.sql = sql[0][r[0]:] } } switch err.msg { case err.str, "not an error": err.msg = "" } return &err } func (m *module) call(fn api.Function, params ...uint64) []uint64 { copy(m.arg[:], params) err := fn.CallWithStack(m.ctx, m.arg[:]) if err != nil { // The module closed or panicked; release resources. m.vfs.Close() panic(err) } return m.arg[:] } func (m *module) free(ptr uint32) { if ptr == 0 { return } m.call(m.api.free, uint64(ptr)) } func (m *module) new(size uint64) uint32 { if size > _MAX_ALLOCATION_SIZE { panic(util.OOMErr) } r := m.call(m.api.malloc, size) ptr := uint32(r[0]) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } func (m *module) newBytes(b []byte) uint32 { if b == nil { return 0 } ptr := m.new(uint64(len(b))) util.WriteBytes(m.mod, ptr, b) return ptr } func (m *module) newString(s string) uint32 { ptr := m.new(uint64(len(s) + 1)) util.WriteString(m.mod, ptr, s) return ptr } func (m *module) newArena(size uint64) arena { return arena{ m: m, base: m.new(size), size: uint32(size), } } type arena struct { m *module ptrs []uint32 base uint32 next uint32 size uint32 } func (a *arena) free() { if a.m == nil { return } a.reset() a.m.free(a.base) a.m = nil } func (a *arena) reset() { for _, ptr := range a.ptrs { a.m.free(ptr) } a.ptrs = nil a.next = 0 } func (a *arena) new(size uint64) uint32 { if size <= uint64(a.size-a.next) { ptr := a.base + a.next a.next += uint32(size) return ptr } ptr := a.m.new(size) a.ptrs = append(a.ptrs, ptr) return ptr } func (a *arena) bytes(b []byte) uint32 { if b == nil { return 0 } ptr := a.new(uint64(len(b))) util.WriteBytes(a.m.mod, ptr, b) return ptr } func (a *arena) string(s string) uint32 { ptr := a.new(uint64(len(s) + 1)) util.WriteString(a.m.mod, ptr, s) return ptr } type sqliteAPI struct { free api.Function malloc api.Function errcode api.Function errstr api.Function errmsg api.Function erroff api.Function open api.Function close api.Function closeZombie api.Function prepare api.Function finalize api.Function reset api.Function step api.Function exec api.Function clearBindings api.Function bindNull api.Function bindCount api.Function bindIndex api.Function bindName api.Function bindInteger api.Function bindFloat api.Function bindText api.Function bindBlob api.Function bindZeroBlob api.Function columnCount api.Function columnName api.Function columnType api.Function columnInteger api.Function columnFloat api.Function columnText api.Function columnBlob api.Function columnBytes api.Function autocommit api.Function lastRowid api.Function changes api.Function blobOpen api.Function blobClose api.Function blobReopen api.Function blobBytes api.Function blobRead api.Function blobWrite api.Function backupInit api.Function backupStep api.Function backupFinish api.Function backupRemaining api.Function backupPageCount api.Function destructor uint32 }