diff --git a/api.go b/api.go deleted file mode 100644 index cfd2450..0000000 --- a/api.go +++ /dev/null @@ -1,127 +0,0 @@ -// Package sqlite3 wraps the C SQLite API. -package sqlite3 - -import "github.com/tetratelabs/wazero/api" - -func (module *module) loadAPI() (err error) { - getFun := func(name string) api.Function { - f := module.ExportedFunction(name) - if f == nil { - err = noFuncErr + errorString(name) - return nil - } - return f - } - - getVal := func(name string) uint32 { - global := module.ExportedGlobal(name) - if global == nil { - err = noGlobalErr + errorString(name) - return 0 - } - return module.mem.readUint32(uint32(global.Get())) - } - - module.api = sqliteAPI{ - free: getFun("free"), - malloc: getFun("malloc"), - destructor: uint64(getVal("malloc_destructor")), - errcode: getFun("sqlite3_errcode"), - errstr: getFun("sqlite3_errstr"), - errmsg: getFun("sqlite3_errmsg"), - erroff: getFun("sqlite3_error_offset"), - open: getFun("sqlite3_open_v2"), - close: getFun("sqlite3_close"), - closeZombie: getFun("sqlite3_close_v2"), - prepare: getFun("sqlite3_prepare_v3"), - finalize: getFun("sqlite3_finalize"), - reset: getFun("sqlite3_reset"), - step: getFun("sqlite3_step"), - exec: getFun("sqlite3_exec"), - clearBindings: getFun("sqlite3_clear_bindings"), - bindCount: getFun("sqlite3_bind_parameter_count"), - bindIndex: getFun("sqlite3_bind_parameter_index"), - bindName: getFun("sqlite3_bind_parameter_name"), - bindNull: getFun("sqlite3_bind_null"), - bindInteger: getFun("sqlite3_bind_int64"), - bindFloat: getFun("sqlite3_bind_double"), - bindText: getFun("sqlite3_bind_text64"), - bindBlob: getFun("sqlite3_bind_blob64"), - bindZeroBlob: getFun("sqlite3_bind_zeroblob64"), - columnCount: getFun("sqlite3_column_count"), - columnName: getFun("sqlite3_column_name"), - columnType: getFun("sqlite3_column_type"), - columnInteger: getFun("sqlite3_column_int64"), - columnFloat: getFun("sqlite3_column_double"), - columnText: getFun("sqlite3_column_text"), - columnBlob: getFun("sqlite3_column_blob"), - columnBytes: getFun("sqlite3_column_bytes"), - autocommit: getFun("sqlite3_get_autocommit"), - lastRowid: getFun("sqlite3_last_insert_rowid"), - changes: getFun("sqlite3_changes64"), - blobOpen: getFun("sqlite3_blob_open"), - blobClose: getFun("sqlite3_blob_close"), - blobReopen: getFun("sqlite3_blob_reopen"), - blobBytes: getFun("sqlite3_blob_bytes"), - blobRead: getFun("sqlite3_blob_read"), - blobWrite: getFun("sqlite3_blob_write"), - backupInit: getFun("sqlite3_backup_init"), - backupStep: getFun("sqlite3_backup_step"), - backupFinish: getFun("sqlite3_backup_finish"), - backupRemaining: getFun("sqlite3_backup_remaining"), - backupPageCount: getFun("sqlite3_backup_pagecount"), - interrupt: getVal("sqlite3_interrupt_offset"), - } - return err -} - -type sqliteAPI struct { - free api.Function - malloc api.Function - destructor uint64 - errcode api.Function - errstr api.Function - errmsg api.Function - erroff api.Function - open api.Function - close api.Function - closeZombie api.Function - prepare api.Function - finalize api.Function - reset api.Function - step api.Function - exec api.Function - clearBindings api.Function - bindNull api.Function - bindCount api.Function - bindIndex api.Function - bindName api.Function - bindInteger api.Function - bindFloat api.Function - bindText api.Function - bindBlob api.Function - bindZeroBlob api.Function - columnCount api.Function - columnName api.Function - columnType api.Function - columnInteger api.Function - columnFloat api.Function - columnText api.Function - columnBlob api.Function - columnBytes api.Function - autocommit api.Function - lastRowid api.Function - changes api.Function - blobOpen api.Function - blobClose api.Function - blobReopen api.Function - blobBytes api.Function - blobRead api.Function - blobWrite api.Function - backupInit api.Function - backupStep api.Function - backupFinish api.Function - backupRemaining api.Function - backupPageCount api.Function - interrupt uint32 -} diff --git a/conn.go b/conn.go index dd5e8f2..85fc6cc 100644 --- a/conn.go +++ b/conn.go @@ -18,12 +18,9 @@ import ( // // https://www.sqlite.org/c3ref/sqlite3.html type Conn struct { - mod *module - ctx context.Context - api *sqliteAPI - mem *memory - handle uint32 + *module + handle uint32 arena arena interrupt context.Context waiter chan struct{} @@ -60,12 +57,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { } }() - c := &Conn{ - mod: mod, - ctx: mod.ctx, - api: &mod.api, - mem: &mod.mem, - } + c := &Conn{module: mod} c.arena = c.newArena(1024) c.handle, err = c.openDB(filename, flags) if err != nil { @@ -82,7 +74,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0) handle := c.mem.readUint32(connPtr) - if err := c.mod.error(r[0], handle); err != nil { + if err := c.module.error(r[0], handle); err != nil { c.closeDB(handle) return 0, err } @@ -100,7 +92,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { pragmaPtr := c.arena.string(pragmas.String()) r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0) - if err := c.mod.error(r[0], handle, pragmas.String()); err != nil { + if err := c.module.error(r[0], handle, pragmas.String()); err != nil { c.closeDB(handle) return 0, fmt.Errorf("sqlite3: invalid _pragma: %w", err) } @@ -110,7 +102,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { func (c *Conn) closeDB(handle uint32) { r := c.call(c.api.closeZombie, uint64(c.handle)) - if err := c.mod.error(r[0], handle); err != nil { + if err := c.module.error(r[0], handle); err != nil { panic(err) } } @@ -330,7 +322,7 @@ func (c *Conn) Pragma(str string) []string { } func (c *Conn) error(rc uint64, sql ...string) error { - return c.mod.error(rc, c.handle, sql...) + return c.module.error(rc, c.handle, sql...) } func (c *Conn) call(fn api.Function, params ...uint64) []uint64 { diff --git a/module.go b/module.go index f2713b3..efc524e 100644 --- a/module.go +++ b/module.go @@ -1,3 +1,4 @@ +// Package sqlite3 wraps the C SQLite API. package sqlite3 import ( @@ -25,9 +26,7 @@ var ( Path string // Path to load the binary from. ) -var sqlite3 sqlite3Runtime - -type sqlite3Runtime struct { +var sqlite3 struct { once sync.Once runtime wazero.Runtime compiled wazero.CompiledModule @@ -35,10 +34,10 @@ type sqlite3Runtime struct { err error } -func instantiateModule() (m *module, err error) { +func instantiateModule() (*module, error) { ctx := context.Background() - sqlite3.once.Do(func() { sqlite3.compileModule(ctx) }) + sqlite3.once.Do(compileModule) if sqlite3.err != nil { return nil, sqlite3.err } @@ -54,37 +53,27 @@ func instantiateModule() (m *module, err error) { if err != nil { return nil, err } - - module := &module{ - Module: mod, - ctx: ctx, - mem: memory{mod}, - } - - err = module.loadAPI() - if err != nil { - return nil, err - } - return module, nil + return newModule(mod) } -func (s *sqlite3Runtime) compileModule(ctx context.Context) { - s.runtime = wazero.NewRuntime(ctx) - vfsInstantiate(ctx, s.runtime) +func compileModule() { + ctx := context.Background() + sqlite3.runtime = wazero.NewRuntime(ctx) + vfsInstantiate(ctx, sqlite3.runtime) bin := Binary if bin == nil && Path != "" { - bin, s.err = os.ReadFile(Path) - if s.err != nil { + bin, sqlite3.err = os.ReadFile(Path) + if sqlite3.err != nil { return } } if bin == nil { - s.err = binaryErr + sqlite3.err = binaryErr return } - s.compiled, s.err = s.runtime.CompileModule(ctx, bin) + sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin) } type module struct { @@ -95,7 +84,87 @@ type module struct { api sqliteAPI } -func (c *module) error(rc uint64, handle uint32, sql ...string) error { +func newModule(mod api.Module) (m *module, err error) { + getFun := func(name string) api.Function { + f := m.ExportedFunction(name) + if f == nil { + err = noFuncErr + errorString(name) + return nil + } + return f + } + + getVal := func(name string) uint32 { + global := m.ExportedGlobal(name) + if global == nil { + err = noGlobalErr + errorString(name) + return 0 + } + return m.mem.readUint32(uint32(global.Get())) + } + + m = &module{ + Module: mod, + mem: memory{mod}, + ctx: context.Background(), + } + m.api = sqliteAPI{ + free: getFun("free"), + malloc: getFun("malloc"), + destructor: uint64(getVal("malloc_destructor")), + errcode: getFun("sqlite3_errcode"), + errstr: getFun("sqlite3_errstr"), + errmsg: getFun("sqlite3_errmsg"), + erroff: getFun("sqlite3_error_offset"), + open: getFun("sqlite3_open_v2"), + close: getFun("sqlite3_close"), + closeZombie: getFun("sqlite3_close_v2"), + prepare: getFun("sqlite3_prepare_v3"), + finalize: getFun("sqlite3_finalize"), + reset: getFun("sqlite3_reset"), + step: getFun("sqlite3_step"), + exec: getFun("sqlite3_exec"), + clearBindings: getFun("sqlite3_clear_bindings"), + bindCount: getFun("sqlite3_bind_parameter_count"), + bindIndex: getFun("sqlite3_bind_parameter_index"), + bindName: getFun("sqlite3_bind_parameter_name"), + bindNull: getFun("sqlite3_bind_null"), + bindInteger: getFun("sqlite3_bind_int64"), + bindFloat: getFun("sqlite3_bind_double"), + bindText: getFun("sqlite3_bind_text64"), + bindBlob: getFun("sqlite3_bind_blob64"), + bindZeroBlob: getFun("sqlite3_bind_zeroblob64"), + columnCount: getFun("sqlite3_column_count"), + columnName: getFun("sqlite3_column_name"), + columnType: getFun("sqlite3_column_type"), + columnInteger: getFun("sqlite3_column_int64"), + columnFloat: getFun("sqlite3_column_double"), + columnText: getFun("sqlite3_column_text"), + columnBlob: getFun("sqlite3_column_blob"), + columnBytes: getFun("sqlite3_column_bytes"), + autocommit: getFun("sqlite3_get_autocommit"), + lastRowid: getFun("sqlite3_last_insert_rowid"), + changes: getFun("sqlite3_changes64"), + blobOpen: getFun("sqlite3_blob_open"), + blobClose: getFun("sqlite3_blob_close"), + blobReopen: getFun("sqlite3_blob_reopen"), + blobBytes: getFun("sqlite3_blob_bytes"), + blobRead: getFun("sqlite3_blob_read"), + blobWrite: getFun("sqlite3_blob_write"), + backupInit: getFun("sqlite3_backup_init"), + backupStep: getFun("sqlite3_backup_step"), + backupFinish: getFun("sqlite3_backup_finish"), + backupRemaining: getFun("sqlite3_backup_remaining"), + backupPageCount: getFun("sqlite3_backup_pagecount"), + interrupt: getVal("sqlite3_interrupt_offset"), + } + if err != nil { + m = nil + } + return +} + +func (m *module) error(rc uint64, handle uint32, sql ...string) error { if rc == _OK { return nil } @@ -108,18 +177,18 @@ func (c *module) error(rc uint64, handle uint32, sql ...string) error { var r []uint64 - r, _ = c.api.errstr.Call(c.ctx, rc) + r, _ = m.api.errstr.Call(m.ctx, rc) if r != nil { - err.str = c.mem.readString(uint32(r[0]), _MAX_STRING) + err.str = m.mem.readString(uint32(r[0]), _MAX_STRING) } - r, _ = c.api.errmsg.Call(c.ctx, uint64(handle)) + r, _ = m.api.errmsg.Call(m.ctx, uint64(handle)) if r != nil { - err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING) + err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING) } if sql != nil { - r, _ = c.api.erroff.Call(c.ctx, uint64(handle)) + r, _ = m.api.erroff.Call(m.ctx, uint64(handle)) if r != nil && r[0] != math.MaxUint32 { err.sql = sql[0][r[0]:] } @@ -131,3 +200,54 @@ func (c *module) error(rc uint64, handle uint32, sql ...string) error { } return &err } + +type sqliteAPI struct { + free api.Function + malloc api.Function + destructor uint64 + errcode api.Function + errstr api.Function + errmsg api.Function + erroff api.Function + open api.Function + close api.Function + closeZombie api.Function + prepare api.Function + finalize api.Function + reset api.Function + step api.Function + exec api.Function + clearBindings api.Function + bindNull api.Function + bindCount api.Function + bindIndex api.Function + bindName api.Function + bindInteger api.Function + bindFloat api.Function + bindText api.Function + bindBlob api.Function + bindZeroBlob api.Function + columnCount api.Function + columnName api.Function + columnType api.Function + columnInteger api.Function + columnFloat api.Function + columnText api.Function + columnBlob api.Function + columnBytes api.Function + autocommit api.Function + lastRowid api.Function + changes api.Function + blobOpen api.Function + blobClose api.Function + blobReopen api.Function + blobBytes api.Function + blobRead api.Function + blobWrite api.Function + backupInit api.Function + backupStep api.Function + backupFinish api.Function + backupRemaining api.Function + backupPageCount api.Function + interrupt uint32 +}