diff --git a/backup.go b/backup.go index 27a71a9..17efa03 100644 --- a/backup.go +++ b/backup.go @@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string if r == 0 { defer c.closeDB(other) r = c.call(c.api.errcode, uint64(dst)) - return nil, c.module.error(r, dst) + return nil, c.sqlite.error(r, dst) } return &Backup{ diff --git a/conn.go b/conn.go index ec168aa..95a1d33 100644 --- a/conn.go +++ b/conn.go @@ -19,7 +19,7 @@ import ( // // https://www.sqlite.org/c3ref/sqlite3.html type Conn struct { - *module + *sqlite interrupt context.Context waiter chan struct{} @@ -50,7 +50,7 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { } func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { - mod, err := instantiateModule() + mod, err := instantiateSQLite() if err != nil { return nil, err } @@ -62,7 +62,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { } }() - c := &Conn{module: mod} + c := &Conn{sqlite: mod} c.arena = c.newArena(1024) c.handle, err = c.openDB(filename, flags) if err != nil { @@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0) handle := util.ReadUint32(c.mod, connPtr) - if err := c.module.error(r, handle); err != nil { + if err := c.sqlite.error(r, handle); err != nil { c.closeDB(handle) return 0, err } @@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { c.arena.reset() pragmaPtr := c.arena.string(pragmas.String()) r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0) - if err := c.module.error(r, handle, pragmas.String()); err != nil { + if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { if errors.Is(err, ERROR) { err = fmt.Errorf("sqlite3: invalid _pragma: %w", err) } @@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { func (c *Conn) closeDB(handle uint32) { r := c.call(c.api.closeZombie, uint64(handle)) - if err := c.module.error(r, handle); err != nil { + if err := c.sqlite.error(r, handle); err != nil { panic(err) } } @@ -143,7 +143,7 @@ func (c *Conn) Close() error { c.handle = 0 runtime.SetFinalizer(c, nil) - return c.module.close() + return c.close() } // Exec is a convenience function that allows an application to run @@ -319,7 +319,7 @@ func (c *Conn) Pragma(str string) ([]string, error) { } func (c *Conn) error(rc uint64, sql ...string) error { - return c.module.error(rc, c.handle, sql...) + return c.sqlite.error(rc, c.handle, sql...) } // DriverConn is implemented by the SQLite [database/sql] driver connection. diff --git a/context.go b/context.go index f281db2..1840641 100644 --- a/context.go +++ b/context.go @@ -12,7 +12,7 @@ import ( // // https://www.sqlite.org/c3ref/context.html type Context struct { - *module + *sqlite handle uint32 } diff --git a/func.go b/func.go index 92f9a1f..d58a144 100644 --- a/func.go +++ b/func.go @@ -96,57 +96,57 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK } func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { - module := ctx.Value(moduleKey{}).(*module) - fn := callbackHandle(module, pCtx).(func(ctx Context, arg ...Value)) - fn(Context{module, pCtx}, callbackArgs(module, nArg, pArg)...) + sqlite := ctx.Value(sqliteKey{}).(*sqlite) + fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value)) + fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...) } func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { - module := ctx.Value(moduleKey{}).(*module) - fn := callbackAggregate(module, pCtx, nil).(AggregateFunction) - fn.Step(Context{module, pCtx}, callbackArgs(module, nArg, pArg)...) + sqlite := ctx.Value(sqliteKey{}).(*sqlite) + fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction) + fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...) } func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) { var handle uint32 - module := ctx.Value(moduleKey{}).(*module) - fn := callbackAggregate(module, pCtx, &handle).(AggregateFunction) - fn.Value(Context{module, pCtx}) + sqlite := ctx.Value(sqliteKey{}).(*sqlite) + fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction) + fn.Value(Context{sqlite, pCtx}) if err := util.DelHandle(ctx, handle); err != nil { - Context{module, pCtx}.ResultError(err) + Context{sqlite, pCtx}.ResultError(err) } } func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) { - module := ctx.Value(moduleKey{}).(*module) - fn := callbackAggregate(module, pCtx, nil).(AggregateFunction) - fn.Value(Context{module, pCtx}) + sqlite := ctx.Value(sqliteKey{}).(*sqlite) + fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction) + fn.Value(Context{sqlite, pCtx}) } func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) { - module := ctx.Value(moduleKey{}).(*module) - fn := callbackAggregate(module, pCtx, nil).(WindowFunction) - fn.Inverse(Context{module, pCtx}, callbackArgs(module, nArg, pArg)...) + sqlite := ctx.Value(sqliteKey{}).(*sqlite) + fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction) + fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...) } -func callbackHandle(module *module, pCtx uint32) any { - pApp := uint32(module.call(module.api.userData, uint64(pCtx))) - return util.GetHandle(module.ctx, pApp) +func callbackHandle(sqlite *sqlite, pCtx uint32) any { + pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx))) + return util.GetHandle(sqlite.ctx, pApp) } -func callbackAggregate(module *module, pCtx uint32, close *uint32) any { +func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any { // On close, we're getting rid of the handle. // Don't allocate space to store it. var size uint64 if close == nil { size = ptrlen } - ptr := uint32(module.call(module.api.aggregateCtx, uint64(pCtx), size)) + ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size)) // Try loading the handle, if we already have one, or want a new one. if ptr != 0 || size != 0 { - if handle := util.ReadUint32(module.mod, ptr); handle != 0 { - fn := util.GetHandle(module.ctx, handle) + if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 { + fn := util.GetHandle(sqlite.ctx, handle) if close != nil { *close = handle } @@ -157,19 +157,19 @@ func callbackAggregate(module *module, pCtx uint32, close *uint32) any { } // Create a new aggregate and store the handle. - fn := callbackHandle(module, pCtx).(func() AggregateFunction)() + fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)() if ptr != 0 { - util.WriteUint32(module.mod, ptr, util.AddHandle(module.ctx, fn)) + util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn)) } return fn } -func callbackArgs(module *module, nArg, pArg uint32) []Value { +func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value { args := make([]Value, nArg) for i := range args { args[i] = Value{ - module: module, - handle: util.ReadUint32(module.mod, pArg+ptrlen*uint32(i)), + sqlite: sqlite, + handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)), } } return args diff --git a/internal/util/handle.go b/internal/util/handle.go index be8c97f..20444a0 100644 --- a/internal/util/handle.go +++ b/internal/util/handle.go @@ -29,7 +29,7 @@ func (s *handleState) Close() (err error) { func GetHandle(ctx context.Context, id uint32) any { if id == 0 { - return nil + panic(NilErr) } s := ctx.Value(handleKey{}).(*handleState) return s.handles[^id] @@ -50,7 +50,7 @@ func DelHandle(ctx context.Context, id uint32) error { func AddHandle(ctx context.Context, a any) (id uint32) { if a == nil { - return 0 + panic(NilErr) } s := ctx.Value(handleKey{}).(*handleState) diff --git a/module.go b/module.go index 944b107..725f6da 100644 --- a/module.go +++ b/module.go @@ -25,58 +25,58 @@ var ( Path string // Path to load the binary from. ) -var sqlite3 struct { +var instance struct { runtime wazero.Runtime compiled wazero.CompiledModule err error once sync.Once } -func instantiateModule() (*module, error) { +func instantiateSQLite() (*sqlite, error) { ctx := context.Background() - sqlite3.once.Do(compileModule) - if sqlite3.err != nil { - return nil, sqlite3.err + instance.once.Do(compileSQLite) + if instance.err != nil { + return nil, instance.err } cfg := wazero.NewModuleConfig() - mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg) + mod, err := instance.runtime.InstantiateModule(ctx, instance.compiled, cfg) if err != nil { return nil, err } - return newModule(mod) + return newSQLite(mod) } -func compileModule() { +func compileSQLite() { ctx := context.Background() - sqlite3.runtime = wazero.NewRuntime(ctx) + instance.runtime = wazero.NewRuntime(ctx) - env := sqlite3.runtime.NewHostModuleBuilder("env") + env := instance.runtime.NewHostModuleBuilder("env") env = vfs.ExportHostFunctions(env) env = exportHostFunctions(env) - _, sqlite3.err = env.Instantiate(ctx) - if sqlite3.err != nil { + _, instance.err = env.Instantiate(ctx) + if instance.err != nil { return } bin := Binary if bin == nil && Path != "" { - bin, sqlite3.err = os.ReadFile(Path) - if sqlite3.err != nil { + bin, instance.err = os.ReadFile(Path) + if instance.err != nil { return } } if bin == nil { - sqlite3.err = util.BinaryErr + instance.err = util.BinaryErr return } - sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin) + instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin) } -type module struct { +type sqlite struct { ctx context.Context mod api.Module closer io.Closer @@ -84,13 +84,13 @@ type module struct { stack [8]uint64 } -type moduleKey struct{} +type sqliteKey struct{} -func newModule(mod api.Module) (m *module, err error) { - m = new(module) - m.ctx, m.closer = util.NewContext(context.Background()) - m.ctx = context.WithValue(m.ctx, moduleKey{}, m) - m.mod = mod +func newSQLite(mod api.Module) (sqlt *sqlite, err error) { + sqlt = new(sqlite) + sqlt.ctx, sqlt.closer = util.NewContext(context.Background()) + sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt) + sqlt.mod = mod getFun := func(name string) api.Function { f := mod.ExportedFunction(name) @@ -110,7 +110,7 @@ func newModule(mod api.Module) (m *module, err error) { return util.ReadUint32(mod, uint32(g.Get())) } - m.api = sqliteAPI{ + sqlt.api = sqliteAPI{ free: getFun("free"), malloc: getFun("malloc"), destructor: getVal("malloc_destructor"), @@ -184,16 +184,16 @@ func newModule(mod api.Module) (m *module, err error) { if err != nil { return nil, err } - return m, nil + return sqlt, nil } -func (m *module) close() error { - err := m.mod.Close(m.ctx) - m.closer.Close() +func (sqlt *sqlite) close() error { + err := sqlt.mod.Close(sqlt.ctx) + sqlt.closer.Close() return err } -func (m *module) error(rc uint64, handle uint32, sql ...string) error { +func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { if rc == _OK { return nil } @@ -204,16 +204,16 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error { panic(util.OOMErr) } - if r := m.call(m.api.errstr, rc); r != 0 { - err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING) + if r := sqlt.call(sqlt.api.errstr, rc); r != 0 { + err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) } - if r := m.call(m.api.errmsg, uint64(handle)); r != 0 { - err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING) + if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 { + err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) } if sql != nil { - if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 { + if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 { err.sql = sql[0][r:] } } @@ -225,60 +225,60 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error { return &err } -func (m *module) call(fn api.Function, params ...uint64) uint64 { - copy(m.stack[:], params) - err := fn.CallWithStack(m.ctx, m.stack[:]) +func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 { + copy(sqlt.stack[:], params) + err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:]) if err != nil { // The module closed or panicked; release resources. - m.closer.Close() + sqlt.closer.Close() panic(err) } - return m.stack[0] + return sqlt.stack[0] } -func (m *module) free(ptr uint32) { +func (sqlt *sqlite) free(ptr uint32) { if ptr == 0 { return } - m.call(m.api.free, uint64(ptr)) + sqlt.call(sqlt.api.free, uint64(ptr)) } -func (m *module) new(size uint64) uint32 { +func (sqlt *sqlite) new(size uint64) uint32 { if size > _MAX_ALLOCATION_SIZE { panic(util.OOMErr) } - ptr := uint32(m.call(m.api.malloc, size)) + ptr := uint32(sqlt.call(sqlt.api.malloc, size)) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } -func (m *module) newBytes(b []byte) uint32 { +func (sqlt *sqlite) newBytes(b []byte) uint32 { if b == nil { return 0 } - ptr := m.new(uint64(len(b))) - util.WriteBytes(m.mod, ptr, b) + ptr := sqlt.new(uint64(len(b))) + util.WriteBytes(sqlt.mod, ptr, b) return ptr } -func (m *module) newString(s string) uint32 { - ptr := m.new(uint64(len(s) + 1)) - util.WriteString(m.mod, ptr, s) +func (sqlt *sqlite) newString(s string) uint32 { + ptr := sqlt.new(uint64(len(s) + 1)) + util.WriteString(sqlt.mod, ptr, s) return ptr } -func (m *module) newArena(size uint64) arena { +func (sqlt *sqlite) newArena(size uint64) arena { return arena{ - m: m, - base: m.new(size), + sqlt: sqlt, size: uint32(size), + base: sqlt.new(size), } } type arena struct { - m *module + sqlt *sqlite ptrs []uint32 base uint32 next uint32 @@ -286,17 +286,17 @@ type arena struct { } func (a *arena) free() { - if a.m == nil { + if a.sqlt == nil { return } a.reset() - a.m.free(a.base) - a.m = nil + a.sqlt.free(a.base) + a.sqlt = nil } func (a *arena) reset() { for _, ptr := range a.ptrs { - a.m.free(ptr) + a.sqlt.free(ptr) } a.ptrs = nil a.next = 0 @@ -308,7 +308,7 @@ func (a *arena) new(size uint64) uint32 { a.next += uint32(size) return ptr } - ptr := a.m.new(size) + ptr := a.sqlt.new(size) a.ptrs = append(a.ptrs, ptr) return ptr } @@ -318,13 +318,13 @@ func (a *arena) bytes(b []byte) uint32 { return 0 } ptr := a.new(uint64(len(b))) - util.WriteBytes(a.m.mod, ptr, b) + util.WriteBytes(a.sqlt.mod, ptr, b) return ptr } func (a *arena) string(s string) uint32 { ptr := a.new(uint64(len(s) + 1)) - util.WriteString(a.m.mod, ptr, s) + util.WriteString(a.sqlt.mod, ptr, s) return ptr } diff --git a/module_test.go b/module_test.go index b6eb28a..e03cd30 100644 --- a/module_test.go +++ b/module_test.go @@ -15,7 +15,7 @@ func init() { func TestConn_error_OOM(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -29,7 +29,7 @@ func TestConn_error_OOM(t *testing.T) { func TestConn_call_closed(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -43,7 +43,7 @@ func TestConn_call_closed(t *testing.T) { func TestConn_new(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -66,7 +66,7 @@ func TestConn_new(t *testing.T) { func TestConn_newArena(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -111,7 +111,7 @@ func TestConn_newArena(t *testing.T) { func TestConn_newBytes(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -137,7 +137,7 @@ func TestConn_newBytes(t *testing.T) { func TestConn_newString(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -163,7 +163,7 @@ func TestConn_newString(t *testing.T) { func TestConn_getString(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } @@ -204,7 +204,7 @@ func TestConn_getString(t *testing.T) { func TestConn_free(t *testing.T) { t.Parallel() - m, err := instantiateModule() + m, err := instantiateSQLite() if err != nil { t.Fatal(err) } diff --git a/value.go b/value.go index a542412..aed1056 100644 --- a/value.go +++ b/value.go @@ -11,7 +11,7 @@ import ( // // https://www.sqlite.org/c3ref/value.html type Value struct { - *module + *sqlite handle uint32 } diff --git a/vfs/vfs.go b/vfs/vfs.go index f671487..7aef26d 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -156,6 +156,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla file, flags, err = vfs.Open(path, flags) } + if err != nil { + return vfsErrorCode(err, _CANTOPEN) + } + if file, ok := file.(FilePowersafeOverwrite); ok { if !parsed { params = vfsURIParameters(ctx, mod, zPath, flags) @@ -165,14 +169,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla } } - if err != nil { - return vfsErrorCode(err, _CANTOPEN) - } - - vfsFileRegister(ctx, mod, pFile, file) if pOutFlags != 0 { util.WriteUint32(mod, pOutFlags, uint32(flags)) } + vfsFileRegister(ctx, mod, pFile, file) return _OK }