mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Interrupts: avoid goroutine.
This commit is contained in:
73
conn.go
73
conn.go
@@ -7,10 +7,9 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// Conn is a database connection handle.
|
||||
@@ -21,7 +20,6 @@ type Conn struct {
|
||||
*sqlite
|
||||
|
||||
interrupt context.Context
|
||||
waiter chan struct{}
|
||||
pending *Stmt
|
||||
arena arena
|
||||
|
||||
@@ -48,6 +46,8 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
return newConn(filename, flags)
|
||||
}
|
||||
|
||||
type connKey struct{}
|
||||
|
||||
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
@@ -63,6 +63,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
|
||||
c := &Conn{sqlite: sqlite}
|
||||
c.arena = c.newArena(1024)
|
||||
c.ctx = context.WithValue(c.ctx, connKey{}, c)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -131,7 +132,6 @@ func (c *Conn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.SetInterrupt(context.Background())
|
||||
c.pending.Close()
|
||||
c.pending = nil
|
||||
|
||||
@@ -244,65 +244,40 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
<-c.waiter // Wait for it to finish.
|
||||
c.waiter = nil
|
||||
}
|
||||
// Reset the pending statement.
|
||||
if c.pending != nil {
|
||||
// An uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
} else {
|
||||
c.pending.Reset()
|
||||
}
|
||||
|
||||
old = c.interrupt
|
||||
c.interrupt = ctx
|
||||
// Remove the handler if the context can't be canceled.
|
||||
if ctx == nil || ctx.Done() == nil {
|
||||
c.call(c.api.progressHandler, uint64(c.handle), 0)
|
||||
return old
|
||||
}
|
||||
|
||||
// Creating an uncompleted SQL statement prevents SQLite from ignoring
|
||||
// an interrupt that comes before any other statements are started.
|
||||
if c.pending == nil {
|
||||
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
|
||||
}
|
||||
c.pending.Step()
|
||||
|
||||
// Don't create the goroutine if we're already interrupted.
|
||||
// This happens frequently while restoring to a previously interrupted state.
|
||||
if c.checkInterrupt() {
|
||||
return old
|
||||
}
|
||||
|
||||
waiter := make(chan struct{})
|
||||
c.waiter = waiter
|
||||
go func() {
|
||||
select {
|
||||
case <-waiter: // Waiter was cancelled.
|
||||
break
|
||||
|
||||
case <-ctx.Done(): // Done was closed.
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
}
|
||||
|
||||
// Signal that the waiter has finished.
|
||||
waiter <- struct{}{}
|
||||
}()
|
||||
c.call(c.api.progressHandler, uint64(c.handle), 100)
|
||||
return old
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() bool {
|
||||
if c.interrupt == nil || c.interrupt.Err() == nil {
|
||||
return false
|
||||
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
|
||||
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
|
||||
if c.interrupt != nil && c.interrupt.Err() != nil {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *Conn) checkInterrupt() {
|
||||
if c.interrupt != nil && c.interrupt.Err() != nil {
|
||||
c.call(c.api.interrupt, uint64(c.handle))
|
||||
}
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pragma executes a PRAGMA statement and returns any results.
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
@@ -225,7 +226,13 @@ func (c *conn) Commit() error {
|
||||
}
|
||||
|
||||
func (c *conn) Rollback() error {
|
||||
return c.Conn.Exec(c.txRollback)
|
||||
err := c.Conn.Exec(c.txRollback)
|
||||
if errors.Is(err, sqlite3.INTERRUPT) {
|
||||
old := c.Conn.SetInterrupt(context.Background())
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
err = c.Conn.Exec(c.txRollback)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
|
||||
@@ -13,6 +13,8 @@ sqlite3_finalize
|
||||
sqlite3_reset
|
||||
sqlite3_step
|
||||
sqlite3_exec
|
||||
sqlite3_interrupt
|
||||
sqlite3_progress_handler_go
|
||||
sqlite3_clear_bindings
|
||||
sqlite3_bind_parameter_count
|
||||
sqlite3_bind_parameter_index
|
||||
|
||||
Binary file not shown.
25
func.go
25
func.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
@@ -47,6 +46,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
// If fn returns an [io.Closer], it will be called to free resources.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
@@ -70,7 +70,7 @@ type AggregateFunction interface {
|
||||
// The function arguments, if any, corresponding to the row being added are passed to Step.
|
||||
Step(ctx Context, arg ...Value)
|
||||
|
||||
// Value is invoked to return the current value of the aggregate.
|
||||
// Value is invoked to return the current (or final) value of the aggregate.
|
||||
Value(ctx Context)
|
||||
}
|
||||
|
||||
@@ -85,17 +85,6 @@ type WindowFunction interface {
|
||||
Inverse(ctx Context, arg ...Value)
|
||||
}
|
||||
|
||||
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
@@ -106,20 +95,20 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
|
||||
}
|
||||
|
||||
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
|
||||
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
var handle uint32
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
|
||||
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
if err := util.DelHandle(ctx, handle); err != nil {
|
||||
@@ -128,13 +117,13 @@ func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
}
|
||||
|
||||
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
}
|
||||
|
||||
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
21
sqlite.go
21
sqlite.go
@@ -43,7 +43,7 @@ func compileSQLite() {
|
||||
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportHostFunctions(env)
|
||||
env = exportCallbacks(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
@@ -71,8 +71,6 @@ type sqlite struct {
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
type sqliteKey struct{}
|
||||
|
||||
func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
@@ -81,7 +79,6 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx = util.NewContext(context.Background())
|
||||
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
|
||||
|
||||
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
|
||||
instance.compiled, wazero.NewModuleConfig())
|
||||
@@ -123,6 +120,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
|
||||
reset: getFun("sqlite3_reset"),
|
||||
step: getFun("sqlite3_step"),
|
||||
exec: getFun("sqlite3_exec"),
|
||||
interrupt: getFun("sqlite3_interrupt"),
|
||||
progressHandler: getFun("sqlite3_progress_handler_go"),
|
||||
clearBindings: getFun("sqlite3_clear_bindings"),
|
||||
bindCount: getFun("sqlite3_bind_parameter_count"),
|
||||
bindIndex: getFun("sqlite3_bind_parameter_index"),
|
||||
@@ -342,6 +341,8 @@ type sqliteAPI struct {
|
||||
reset api.Function
|
||||
step api.Function
|
||||
exec api.Function
|
||||
interrupt api.Function
|
||||
progressHandler api.Function
|
||||
clearBindings api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
@@ -402,3 +403,15 @@ type sqliteAPI struct {
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
}
|
||||
|
||||
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncII(env, "go_progress", callbackProgress)
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ext/uint.c"
|
||||
#include "ext/uuid.c"
|
||||
#include "func.c"
|
||||
#include "progress.c"
|
||||
#include "time.c"
|
||||
|
||||
__attribute__((constructor)) void init() {
|
||||
|
||||
9
sqlite3/progress.c
Normal file
9
sqlite3/progress.c
Normal file
@@ -0,0 +1,9 @@
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_progress(void *);
|
||||
|
||||
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
|
||||
sqlite3_progress_handler(db, n, go_progress, NULL);
|
||||
}
|
||||
@@ -23,7 +23,6 @@
|
||||
#define SQLITE_MAX_EXPR_DEPTH 0
|
||||
#define SQLITE_OMIT_DECLTYPE
|
||||
#define SQLITE_OMIT_DEPRECATED
|
||||
#define SQLITE_OMIT_PROGRESS_CALLBACK
|
||||
#define SQLITE_OMIT_SHARED_CACHE
|
||||
#define SQLITE_OMIT_AUTOINIT
|
||||
#define SQLITE_USE_ALLOCA
|
||||
|
||||
@@ -137,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
return sqlite3_vfs_find_orig(zVfsName);
|
||||
}
|
||||
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
|
||||
defer stmt.Close()
|
||||
|
||||
db.SetInterrupt(ctx)
|
||||
cancel()
|
||||
go cancel()
|
||||
|
||||
// Interrupting works.
|
||||
err = stmt.Exec()
|
||||
|
||||
4
vfs/tests/mptest/testdata/mptest.wasm.bz2
vendored
4
vfs/tests/mptest/testdata/mptest.wasm.bz2
vendored
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:da493a827d5b2985ba80d7425092a891311633d6c80c559119f87609d0f0e02a
|
||||
size 508796
|
||||
oid sha256:c59231ce10786b45be958027d23cffc74894a00120b30c8d3accb26f4182b29a
|
||||
size 509312
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c5ffb1dc0f046bb7e4ea4cd992e1d625c05c7d5cf33c78edc5a2155eb3d3c097
|
||||
size 523393
|
||||
oid sha256:9f715bad486eeae35ecb3cf05a2e6265fbfc24a2de0836bdc8fd760510ac1d3a
|
||||
size 524127
|
||||
|
||||
Reference in New Issue
Block a user