mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Refactor.
This commit is contained in:
405
sqlite.go
Normal file
405
sqlite.go
Normal file
@@ -0,0 +1,405 @@
|
||||
// Package sqlite3 wraps the C SQLite API.
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// Configure SQLite WASM.
|
||||
//
|
||||
// Importing package embed initializes these
|
||||
// with an appropriate build of SQLite:
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/embed"
|
||||
var (
|
||||
Binary []byte // WASM binary to load.
|
||||
Path string // Path to load the binary from.
|
||||
)
|
||||
|
||||
var instance struct {
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func instantiateSQLite() (*sqlite, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
return nil, instance.err
|
||||
}
|
||||
|
||||
cfg := wazero.NewModuleConfig()
|
||||
|
||||
mod, err := instance.runtime.InstantiateModule(ctx, instance.compiled, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSQLite(mod)
|
||||
}
|
||||
|
||||
func compileSQLite() {
|
||||
ctx := context.Background()
|
||||
instance.runtime = wazero.NewRuntime(ctx)
|
||||
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportHostFunctions(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
bin, instance.err = os.ReadFile(Path)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if bin == nil {
|
||||
instance.err = util.BinaryErr
|
||||
return
|
||||
}
|
||||
|
||||
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
type sqlite struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
closer io.Closer
|
||||
api sqliteAPI
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
type sqliteKey struct{}
|
||||
|
||||
func newSQLite(mod api.Module) (sqlt *sqlite, err error) {
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx, sqlt.closer = util.NewContext(context.Background())
|
||||
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
|
||||
sqlt.mod = mod
|
||||
|
||||
getFun := func(name string) api.Function {
|
||||
f := mod.ExportedFunction(name)
|
||||
if f == nil {
|
||||
err = util.NoFuncErr + util.ErrorString(name)
|
||||
return nil
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
getVal := func(name string) uint32 {
|
||||
g := mod.ExportedGlobal(name)
|
||||
if g == nil {
|
||||
err = util.NoGlobalErr + util.ErrorString(name)
|
||||
return 0
|
||||
}
|
||||
return util.ReadUint32(mod, uint32(g.Get()))
|
||||
}
|
||||
|
||||
sqlt.api = sqliteAPI{
|
||||
free: getFun("free"),
|
||||
malloc: getFun("malloc"),
|
||||
destructor: getVal("malloc_destructor"),
|
||||
errcode: getFun("sqlite3_errcode"),
|
||||
errstr: getFun("sqlite3_errstr"),
|
||||
errmsg: getFun("sqlite3_errmsg"),
|
||||
erroff: getFun("sqlite3_error_offset"),
|
||||
open: getFun("sqlite3_open_v2"),
|
||||
close: getFun("sqlite3_close"),
|
||||
closeZombie: getFun("sqlite3_close_v2"),
|
||||
prepare: getFun("sqlite3_prepare_v3"),
|
||||
finalize: getFun("sqlite3_finalize"),
|
||||
reset: getFun("sqlite3_reset"),
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
bindName: getFun("sqlite3_bind_parameter_name"),
|
||||
bindNull: getFun("sqlite3_bind_null"),
|
||||
bindInteger: getFun("sqlite3_bind_int64"),
|
||||
bindFloat: getFun("sqlite3_bind_double"),
|
||||
bindText: getFun("sqlite3_bind_text64"),
|
||||
bindBlob: getFun("sqlite3_bind_blob64"),
|
||||
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
|
||||
columnCount: getFun("sqlite3_column_count"),
|
||||
columnName: getFun("sqlite3_column_name"),
|
||||
columnType: getFun("sqlite3_column_type"),
|
||||
columnInteger: getFun("sqlite3_column_int64"),
|
||||
columnFloat: getFun("sqlite3_column_double"),
|
||||
columnText: getFun("sqlite3_column_text"),
|
||||
columnBlob: getFun("sqlite3_column_blob"),
|
||||
columnBytes: getFun("sqlite3_column_bytes"),
|
||||
blobOpen: getFun("sqlite3_blob_open"),
|
||||
blobClose: getFun("sqlite3_blob_close"),
|
||||
blobReopen: getFun("sqlite3_blob_reopen"),
|
||||
blobBytes: getFun("sqlite3_blob_bytes"),
|
||||
blobRead: getFun("sqlite3_blob_read"),
|
||||
blobWrite: getFun("sqlite3_blob_write"),
|
||||
backupInit: getFun("sqlite3_backup_init"),
|
||||
backupStep: getFun("sqlite3_backup_step"),
|
||||
backupFinish: getFun("sqlite3_backup_finish"),
|
||||
backupRemaining: getFun("sqlite3_backup_remaining"),
|
||||
backupPageCount: getFun("sqlite3_backup_pagecount"),
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
createCollation: getFun("sqlite3_create_collation_go"),
|
||||
createFunction: getFun("sqlite3_create_function_go"),
|
||||
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
|
||||
createWindow: getFun("sqlite3_create_window_function_go"),
|
||||
aggregateCtx: getFun("sqlite3_aggregate_context"),
|
||||
userData: getFun("sqlite3_user_data"),
|
||||
setAuxData: getFun("sqlite3_set_auxdata_go"),
|
||||
getAuxData: getFun("sqlite3_get_auxdata"),
|
||||
valueType: getFun("sqlite3_value_type"),
|
||||
valueInteger: getFun("sqlite3_value_int64"),
|
||||
valueFloat: getFun("sqlite3_value_double"),
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
||||
resultErrorBig: getFun("sqlite3_result_error_toobig"),
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sqlt, nil
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) close() error {
|
||||
err := sqlt.mod.Close(sqlt.ctx)
|
||||
sqlt.closer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := Error{code: rc}
|
||||
|
||||
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
|
||||
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
}
|
||||
}
|
||||
|
||||
switch err.msg {
|
||||
case err.str, "not an error":
|
||||
err.msg = ""
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(sqlt.stack[:], params)
|
||||
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
|
||||
if err != nil {
|
||||
// The module closed or panicked; release resources.
|
||||
sqlt.closer.Close()
|
||||
panic(err)
|
||||
}
|
||||
return sqlt.stack[0]
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) free(ptr uint32) {
|
||||
if ptr == 0 {
|
||||
return
|
||||
}
|
||||
sqlt.call(sqlt.api.free, uint64(ptr))
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) new(size uint64) uint32 {
|
||||
if size > _MAX_ALLOCATION_SIZE {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := sqlt.new(uint64(len(b)))
|
||||
util.WriteBytes(sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newString(s string) uint32 {
|
||||
ptr := sqlt.new(uint64(len(s) + 1))
|
||||
util.WriteString(sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (sqlt *sqlite) newArena(size uint64) arena {
|
||||
return arena{
|
||||
sqlt: sqlt,
|
||||
size: uint32(size),
|
||||
base: sqlt.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
sqlt *sqlite
|
||||
ptrs []uint32
|
||||
base uint32
|
||||
next uint32
|
||||
size uint32
|
||||
}
|
||||
|
||||
func (a *arena) free() {
|
||||
if a.sqlt == nil {
|
||||
return
|
||||
}
|
||||
a.reset()
|
||||
a.sqlt.free(a.base)
|
||||
a.sqlt = nil
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.sqlt.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
}
|
||||
|
||||
func (a *arena) new(size uint64) uint32 {
|
||||
if size <= uint64(a.size-a.next) {
|
||||
ptr := a.base + a.next
|
||||
a.next += uint32(size)
|
||||
return ptr
|
||||
}
|
||||
ptr := a.sqlt.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) bytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := a.new(uint64(len(b)))
|
||||
util.WriteBytes(a.sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint64(len(s) + 1))
|
||||
util.WriteString(a.sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
type sqliteAPI struct {
|
||||
free api.Function
|
||||
malloc api.Function
|
||||
errcode api.Function
|
||||
errstr api.Function
|
||||
errmsg api.Function
|
||||
erroff api.Function
|
||||
open api.Function
|
||||
close api.Function
|
||||
closeZombie api.Function
|
||||
prepare api.Function
|
||||
finalize api.Function
|
||||
reset api.Function
|
||||
step api.Function
|
||||
exec api.Function
|
||||
clearBindings api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindNull api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
bindBlob api.Function
|
||||
bindZeroBlob api.Function
|
||||
columnCount api.Function
|
||||
columnName api.Function
|
||||
columnType api.Function
|
||||
columnInteger api.Function
|
||||
columnFloat api.Function
|
||||
columnText api.Function
|
||||
columnBlob api.Function
|
||||
columnBytes api.Function
|
||||
blobOpen api.Function
|
||||
blobClose api.Function
|
||||
blobReopen api.Function
|
||||
blobBytes api.Function
|
||||
blobRead api.Function
|
||||
blobWrite api.Function
|
||||
backupInit api.Function
|
||||
backupStep api.Function
|
||||
backupFinish api.Function
|
||||
backupRemaining api.Function
|
||||
backupPageCount api.Function
|
||||
changes api.Function
|
||||
lastRowid api.Function
|
||||
autocommit api.Function
|
||||
createCollation api.Function
|
||||
createFunction api.Function
|
||||
createAggregate api.Function
|
||||
createWindow api.Function
|
||||
aggregateCtx api.Function
|
||||
userData api.Function
|
||||
setAuxData api.Function
|
||||
getAuxData api.Function
|
||||
valueType api.Function
|
||||
valueInteger api.Function
|
||||
valueFloat api.Function
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
resultErrorMem api.Function
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
}
|
||||
Reference in New Issue
Block a user