diff --git a/api.go b/api.go index 987e8fa..0649ecb 100644 --- a/api.go +++ b/api.go @@ -1,13 +1,9 @@ // Package sqlite3 wraps the C SQLite API. package sqlite3 -import ( - "context" +import "github.com/tetratelabs/wazero/api" - "github.com/tetratelabs/wazero/api" -) - -func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) { +func (module *module) loadAPI() (err error) { getFun := func(name string) api.Function { f := module.ExportedFunction(name) if f == nil { @@ -23,13 +19,11 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) { err = noGlobalErr + errorString(name) return 0 } - return memory{module}.readUint32(uint32(global.Get())) + return module.mem.readUint32(uint32(global.Get())) } - c := Conn{ - ctx: ctx, - mem: memory{module}, - api: sqliteAPI{ + { + module.api = sqliteAPI{ free: getFun("free"), malloc: getFun("malloc"), destructor: uint64(getVal("malloc_destructor")), @@ -72,12 +66,9 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) { blobRead: getFun("sqlite3_blob_read"), blobWrite: getFun("sqlite3_blob_write"), interrupt: getVal("sqlite3_interrupt_offset"), - }, + } } - if err != nil { - return nil, err - } - return &c, nil + return err } type sqliteAPI struct { diff --git a/compile.go b/compile.go index e454625..1efb23f 100644 --- a/compile.go +++ b/compile.go @@ -34,18 +34,37 @@ type sqlite3Runtime struct { err error } -func (s *sqlite3Runtime) instantiateModule(ctx context.Context) (api.Module, error) { - s.once.Do(func() { s.compileModule(ctx) }) - if s.err != nil { - return nil, s.err +func instantiateModule() (m *module, err error) { + ctx := context.Background() + + sqlite3.once.Do(func() { sqlite3.compileModule(ctx) }) + if sqlite3.err != nil { + return nil, sqlite3.err } - cfg := wazero.NewModuleConfig(). - WithName("sqlite3-" + strconv.FormatUint(s.instances.Add(1), 10)). + name := "sqlite3-" + strconv.FormatUint(sqlite3.instances.Add(1), 10) + + cfg := wazero.NewModuleConfig().WithName(name). WithSysWalltime().WithSysNanotime().WithSysNanosleep(). WithOsyield(runtime.Gosched). WithRandSource(rand.Reader) - return s.runtime.InstantiateModule(ctx, s.compiled, cfg) + + mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg) + if err != nil { + return nil, err + } + + module := &module{ + Module: mod, + ctx: ctx, + mem: memory{mod}, + } + + err = module.loadAPI() + if err != nil { + return nil, err + } + return module, nil } func (s *sqlite3Runtime) compileModule(ctx context.Context) { @@ -66,3 +85,11 @@ func (s *sqlite3Runtime) compileModule(ctx context.Context) { s.compiled, s.err = s.runtime.CompileModule(ctx, bin) } + +type module struct { + api.Module + + ctx context.Context + mem memory + api sqliteAPI +} diff --git a/conn.go b/conn.go index 361313c..80fb71a 100644 --- a/conn.go +++ b/conn.go @@ -19,9 +19,10 @@ import ( // // https://www.sqlite.org/c3ref/sqlite3.html type Conn struct { + mod *module ctx context.Context - api sqliteAPI - mem memory + api *sqliteAPI + mem *memory handle uint32 arena arena @@ -48,21 +49,23 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { func openFlags(filename string, flags OpenFlag) (conn *Conn, err error) { ctx := context.Background() - module, err := sqlite3.instantiateModule(ctx) + mod, err := instantiateModule() if err != nil { return nil, err } defer func() { if conn == nil { - module.Close(ctx) + mod.Close(ctx) } else { runtime.SetFinalizer(conn, finalizer[Conn](3)) } }() - c, err := newConn(ctx, module) - if err != nil { - return nil, err + c := &Conn{ + mod: mod, + ctx: mod.ctx, + api: &mod.api, + mem: &mod.mem, } c.arena = c.newArena(1024)