Towards shared modules: refactor.

This commit is contained in:
Nuno Cruces
2023-03-06 18:28:50 +00:00
parent 1ebdc1aa93
commit c1263d4f33
3 changed files with 157 additions and 172 deletions

127
api.go
View File

@@ -1,127 +0,0 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import "github.com/tetratelabs/wazero/api"
func (module *module) loadAPI() (err error) {
getFun := func(name string) api.Function {
f := module.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := module.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return module.mem.readUint32(uint32(global.Get()))
}
module.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
interrupt: getVal("sqlite3_interrupt_offset"),
}
return err
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
interrupt uint32
}

22
conn.go
View File

@@ -18,12 +18,9 @@ import (
//
// https://www.sqlite.org/c3ref/sqlite3.html
type Conn struct {
mod *module
ctx context.Context
api *sqliteAPI
mem *memory
handle uint32
*module
handle uint32
arena arena
interrupt context.Context
waiter chan struct{}
@@ -60,12 +57,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
}
}()
c := &Conn{
mod: mod,
ctx: mod.ctx,
api: &mod.api,
mem: &mod.mem,
}
c := &Conn{module: mod}
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags)
if err != nil {
@@ -82,7 +74,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := c.mem.readUint32(connPtr)
if err := c.mod.error(r[0], handle); err != nil {
if err := c.module.error(r[0], handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -100,7 +92,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.mod.error(r[0], handle, pragmas.String()); err != nil {
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
c.closeDB(handle)
return 0, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
@@ -110,7 +102,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(c.handle))
if err := c.mod.error(r[0], handle); err != nil {
if err := c.module.error(r[0], handle); err != nil {
panic(err)
}
}
@@ -330,7 +322,7 @@ func (c *Conn) Pragma(str string) []string {
}
func (c *Conn) error(rc uint64, sql ...string) error {
return c.mod.error(rc, c.handle, sql...)
return c.module.error(rc, c.handle, sql...)
}
func (c *Conn) call(fn api.Function, params ...uint64) []uint64 {

180
module.go
View File

@@ -1,3 +1,4 @@
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
@@ -25,9 +26,7 @@ var (
Path string // Path to load the binary from.
)
var sqlite3 sqlite3Runtime
type sqlite3Runtime struct {
var sqlite3 struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
@@ -35,10 +34,10 @@ type sqlite3Runtime struct {
err error
}
func instantiateModule() (m *module, err error) {
func instantiateModule() (*module, error) {
ctx := context.Background()
sqlite3.once.Do(func() { sqlite3.compileModule(ctx) })
sqlite3.once.Do(compileModule)
if sqlite3.err != nil {
return nil, sqlite3.err
}
@@ -54,37 +53,27 @@ func instantiateModule() (m *module, err error) {
if err != nil {
return nil, err
}
module := &module{
Module: mod,
ctx: ctx,
mem: memory{mod},
}
err = module.loadAPI()
if err != nil {
return nil, err
}
return module, nil
return newModule(mod)
}
func (s *sqlite3Runtime) compileModule(ctx context.Context) {
s.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, s.runtime)
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, sqlite3.runtime)
bin := Binary
if bin == nil && Path != "" {
bin, s.err = os.ReadFile(Path)
if s.err != nil {
bin, sqlite3.err = os.ReadFile(Path)
if sqlite3.err != nil {
return
}
}
if bin == nil {
s.err = binaryErr
sqlite3.err = binaryErr
return
}
s.compiled, s.err = s.runtime.CompileModule(ctx, bin)
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
}
type module struct {
@@ -95,7 +84,87 @@ type module struct {
api sqliteAPI
}
func (c *module) error(rc uint64, handle uint32, sql ...string) error {
func newModule(mod api.Module) (m *module, err error) {
getFun := func(name string) api.Function {
f := m.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := m.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
return 0
}
return m.mem.readUint32(uint32(global.Get()))
}
m = &module{
Module: mod,
mem: memory{mod},
ctx: context.Background(),
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
erroff: getFun("sqlite3_error_offset"),
open: getFun("sqlite3_open_v2"),
close: getFun("sqlite3_close"),
closeZombie: getFun("sqlite3_close_v2"),
prepare: getFun("sqlite3_prepare_v3"),
finalize: getFun("sqlite3_finalize"),
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
bindName: getFun("sqlite3_bind_parameter_name"),
bindNull: getFun("sqlite3_bind_null"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
blobBytes: getFun("sqlite3_blob_bytes"),
blobRead: getFun("sqlite3_blob_read"),
blobWrite: getFun("sqlite3_blob_write"),
backupInit: getFun("sqlite3_backup_init"),
backupStep: getFun("sqlite3_backup_step"),
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
interrupt: getVal("sqlite3_interrupt_offset"),
}
if err != nil {
m = nil
}
return
}
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
if rc == _OK {
return nil
}
@@ -108,18 +177,18 @@ func (c *module) error(rc uint64, handle uint32, sql ...string) error {
var r []uint64
r, _ = c.api.errstr.Call(c.ctx, rc)
r, _ = m.api.errstr.Call(m.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), _MAX_STRING)
err.str = m.mem.readString(uint32(r[0]), _MAX_STRING)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(handle))
r, _ = m.api.errmsg.Call(m.ctx, uint64(handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING)
err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING)
}
if sql != nil {
r, _ = c.api.erroff.Call(c.ctx, uint64(handle))
r, _ = m.api.erroff.Call(m.ctx, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
@@ -131,3 +200,54 @@ func (c *module) error(rc uint64, handle uint32, sql ...string) error {
}
return &err
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
erroff api.Function
open api.Function
close api.Function
closeZombie api.Function
prepare api.Function
finalize api.Function
reset api.Function
step api.Function
exec api.Function
clearBindings api.Function
bindNull api.Function
bindCount api.Function
bindIndex api.Function
bindName api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
blobBytes api.Function
blobRead api.Function
blobWrite api.Function
backupInit api.Function
backupStep api.Function
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
interrupt uint32
}