Refactor.

This commit is contained in:
Nuno Cruces
2023-07-03 17:21:35 +01:00
parent f6d7c5e9c5
commit 6f7f776488
9 changed files with 114 additions and 114 deletions

View File

@@ -77,7 +77,7 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r, dst)
return nil, c.sqlite.error(r, dst)
}
return &Backup{

16
conn.go
View File

@@ -19,7 +19,7 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
*module
*sqlite
interrupt context.Context
waiter chan struct{}
@@ -50,7 +50,7 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
mod, err := instantiateModule()
mod, err := instantiateSQLite()
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
}
}()
c := &Conn{module: mod}
c := &Conn{sqlite: mod}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
@@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r, handle, pragmas.String()); err != nil {
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
@@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r, handle); err != nil {
if err := c.sqlite.error(r, handle); err != nil {
panic(err)
}
}
@@ -143,7 +143,7 @@ func (c *Conn) Close() error {
c.handle = 0
runtime.SetFinalizer(c, nil)
return c.module.close()
return c.close()
}
// Exec is a convenience function that allows an application to run
@@ -319,7 +319,7 @@ func (c *Conn) Pragma(str string) ([]string, error) {
}
func (c *Conn) error(rc uint64, sql ...string) error {
return c.module.error(rc, c.handle, sql...)
return c.sqlite.error(rc, c.handle, sql...)
}
// DriverConn is implemented by the SQLite [database/sql] driver connection.

View File

@@ -12,7 +12,7 @@ import (
//
// https://www.sqlite.org/c3ref/context.html
type Context struct {
*module
*sqlite
handle uint32
}

56
func.go
View File

@@ -96,57 +96,57 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
module := ctx.Value(moduleKey{}).(*module)
fn := callbackHandle(module, pCtx).(func(ctx Context, arg ...Value))
fn(Context{module, pCtx}, callbackArgs(module, nArg, pArg)...)
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
module := ctx.Value(moduleKey{}).(*module)
fn := callbackAggregate(module, pCtx, nil).(AggregateFunction)
fn.Step(Context{module, pCtx}, callbackArgs(module, nArg, pArg)...)
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
module := ctx.Value(moduleKey{}).(*module)
fn := callbackAggregate(module, pCtx, &handle).(AggregateFunction)
fn.Value(Context{module, pCtx})
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
Context{module, pCtx}.ResultError(err)
Context{sqlite, pCtx}.ResultError(err)
}
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
module := ctx.Value(moduleKey{}).(*module)
fn := callbackAggregate(module, pCtx, nil).(AggregateFunction)
fn.Value(Context{module, pCtx})
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
module := ctx.Value(moduleKey{}).(*module)
fn := callbackAggregate(module, pCtx, nil).(WindowFunction)
fn.Inverse(Context{module, pCtx}, callbackArgs(module, nArg, pArg)...)
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackHandle(module *module, pCtx uint32) any {
pApp := uint32(module.call(module.api.userData, uint64(pCtx)))
return util.GetHandle(module.ctx, pApp)
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
return util.GetHandle(sqlite.ctx, pApp)
}
func callbackAggregate(module *module, pCtx uint32, close *uint32) any {
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
// On close, we're getting rid of the handle.
// Don't allocate space to store it.
var size uint64
if close == nil {
size = ptrlen
}
ptr := uint32(module.call(module.api.aggregateCtx, uint64(pCtx), size))
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
// Try loading the handle, if we already have one, or want a new one.
if ptr != 0 || size != 0 {
if handle := util.ReadUint32(module.mod, ptr); handle != 0 {
fn := util.GetHandle(module.ctx, handle)
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
fn := util.GetHandle(sqlite.ctx, handle)
if close != nil {
*close = handle
}
@@ -157,19 +157,19 @@ func callbackAggregate(module *module, pCtx uint32, close *uint32) any {
}
// Create a new aggregate and store the handle.
fn := callbackHandle(module, pCtx).(func() AggregateFunction)()
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
if ptr != 0 {
util.WriteUint32(module.mod, ptr, util.AddHandle(module.ctx, fn))
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
}
return fn
}
func callbackArgs(module *module, nArg, pArg uint32) []Value {
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
args := make([]Value, nArg)
for i := range args {
args[i] = Value{
module: module,
handle: util.ReadUint32(module.mod, pArg+ptrlen*uint32(i)),
sqlite: sqlite,
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
}
}
return args

View File

@@ -29,7 +29,7 @@ func (s *handleState) Close() (err error) {
func GetHandle(ctx context.Context, id uint32) any {
if id == 0 {
return nil
panic(NilErr)
}
s := ctx.Value(handleKey{}).(*handleState)
return s.handles[^id]
@@ -50,7 +50,7 @@ func DelHandle(ctx context.Context, id uint32) error {
func AddHandle(ctx context.Context, a any) (id uint32) {
if a == nil {
return 0
panic(NilErr)
}
s := ctx.Value(handleKey{}).(*handleState)

120
module.go
View File

@@ -25,58 +25,58 @@ var (
Path string // Path to load the binary from.
)
var sqlite3 struct {
var instance struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
func instantiateSQLite() (*sqlite, error) {
ctx := context.Background()
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
instance.once.Do(compileSQLite)
if instance.err != nil {
return nil, instance.err
}
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
mod, err := instance.runtime.InstantiateModule(ctx, instance.compiled, cfg)
if err != nil {
return nil, err
}
return newModule(mod)
return newSQLite(mod)
}
func compileModule() {
func compileSQLite() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
instance.runtime = wazero.NewRuntime(ctx)
env := sqlite3.runtime.NewHostModuleBuilder("env")
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportHostFunctions(env)
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
bin, instance.err = os.ReadFile(Path)
if instance.err != nil {
return
}
}
if bin == nil {
sqlite3.err = util.BinaryErr
instance.err = util.BinaryErr
return
}
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
}
type module struct {
type sqlite struct {
ctx context.Context
mod api.Module
closer io.Closer
@@ -84,13 +84,13 @@ type module struct {
stack [8]uint64
}
type moduleKey struct{}
type sqliteKey struct{}
func newModule(mod api.Module) (m *module, err error) {
m = new(module)
m.ctx, m.closer = util.NewContext(context.Background())
m.ctx = context.WithValue(m.ctx, moduleKey{}, m)
m.mod = mod
func newSQLite(mod api.Module) (sqlt *sqlite, err error) {
sqlt = new(sqlite)
sqlt.ctx, sqlt.closer = util.NewContext(context.Background())
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
sqlt.mod = mod
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
@@ -110,7 +110,7 @@ func newModule(mod api.Module) (m *module, err error) {
return util.ReadUint32(mod, uint32(g.Get()))
}
m.api = sqliteAPI{
sqlt.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: getVal("malloc_destructor"),
@@ -184,16 +184,16 @@ func newModule(mod api.Module) (m *module, err error) {
if err != nil {
return nil, err
}
return m, nil
return sqlt, nil
}
func (m *module) close() error {
err := m.mod.Close(m.ctx)
m.closer.Close()
func (sqlt *sqlite) close() error {
err := sqlt.mod.Close(sqlt.ctx)
sqlt.closer.Close()
return err
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
@@ -204,16 +204,16 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
if r := m.call(m.api.errstr, rc); r != 0 {
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
@@ -225,60 +225,60 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (m *module) call(fn api.Function, params ...uint64) uint64 {
copy(m.stack[:], params)
err := fn.CallWithStack(m.ctx, m.stack[:])
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
copy(sqlt.stack[:], params)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
if err != nil {
// The module closed or panicked; release resources.
m.closer.Close()
sqlt.closer.Close()
panic(err)
}
return m.stack[0]
return sqlt.stack[0]
}
func (m *module) free(ptr uint32) {
func (sqlt *sqlite) free(ptr uint32) {
if ptr == 0 {
return
}
m.call(m.api.free, uint64(ptr))
sqlt.call(sqlt.api.free, uint64(ptr))
}
func (m *module) new(size uint64) uint32 {
func (sqlt *sqlite) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
ptr := uint32(m.call(m.api.malloc, size))
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (m *module) newBytes(b []byte) uint32 {
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := m.new(uint64(len(b)))
util.WriteBytes(m.mod, ptr, b)
ptr := sqlt.new(uint64(len(b)))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
util.WriteString(m.mod, ptr, s)
func (sqlt *sqlite) newString(s string) uint32 {
ptr := sqlt.new(uint64(len(s) + 1))
util.WriteString(sqlt.mod, ptr, s)
return ptr
}
func (m *module) newArena(size uint64) arena {
func (sqlt *sqlite) newArena(size uint64) arena {
return arena{
m: m,
base: m.new(size),
sqlt: sqlt,
size: uint32(size),
base: sqlt.new(size),
}
}
type arena struct {
m *module
sqlt *sqlite
ptrs []uint32
base uint32
next uint32
@@ -286,17 +286,17 @@ type arena struct {
}
func (a *arena) free() {
if a.m == nil {
if a.sqlt == nil {
return
}
a.reset()
a.m.free(a.base)
a.m = nil
a.sqlt.free(a.base)
a.sqlt = nil
}
func (a *arena) reset() {
for _, ptr := range a.ptrs {
a.m.free(ptr)
a.sqlt.free(ptr)
}
a.ptrs = nil
a.next = 0
@@ -308,7 +308,7 @@ func (a *arena) new(size uint64) uint32 {
a.next += uint32(size)
return ptr
}
ptr := a.m.new(size)
ptr := a.sqlt.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr
}
@@ -318,13 +318,13 @@ func (a *arena) bytes(b []byte) uint32 {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
util.WriteBytes(a.sqlt.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
util.WriteString(a.m.mod, ptr, s)
util.WriteString(a.sqlt.mod, ptr, s)
return ptr
}

View File

@@ -15,7 +15,7 @@ func init() {
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -29,7 +29,7 @@ func TestConn_error_OOM(t *testing.T) {
func TestConn_call_closed(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func TestConn_call_closed(t *testing.T) {
func TestConn_new(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -66,7 +66,7 @@ func TestConn_new(t *testing.T) {
func TestConn_newArena(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -111,7 +111,7 @@ func TestConn_newArena(t *testing.T) {
func TestConn_newBytes(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -137,7 +137,7 @@ func TestConn_newBytes(t *testing.T) {
func TestConn_newString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -163,7 +163,7 @@ func TestConn_newString(t *testing.T) {
func TestConn_getString(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}
@@ -204,7 +204,7 @@ func TestConn_getString(t *testing.T) {
func TestConn_free(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
m, err := instantiateSQLite()
if err != nil {
t.Fatal(err)
}

View File

@@ -11,7 +11,7 @@ import (
//
// https://www.sqlite.org/c3ref/value.html
type Value struct {
*module
*sqlite
handle uint32
}

View File

@@ -156,6 +156,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
file, flags, err = vfs.Open(path, flags)
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if !parsed {
params = vfsURIParameters(ctx, mod, zPath, flags)
@@ -165,14 +169,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
}
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
vfsFileRegister(ctx, mod, pFile, file)
if pOutFlags != 0 {
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
vfsFileRegister(ctx, mod, pFile, file)
return _OK
}