Interrupts: avoid goroutine.

This commit is contained in:
Nuno Cruces
2023-10-25 12:56:52 +01:00
parent 6353160619
commit 2157d0f325
13 changed files with 74 additions and 80 deletions

73
conn.go
View File

@@ -7,10 +7,9 @@ import (
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// Conn is a database connection handle.
@@ -21,7 +20,6 @@ type Conn struct {
*sqlite
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
@@ -48,6 +46,8 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return newConn(filename, flags)
}
type connKey struct{}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
sqlite, err := instantiateSQLite()
if err != nil {
@@ -63,6 +63,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err != nil {
return nil, err
@@ -131,7 +132,6 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
@@ -244,65 +244,40 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
return ctx
}
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Reset the pending statement.
if c.pending != nil {
// An uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
} else {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
// Remove the handler if the context can't be canceled.
if ctx == nil || ctx.Done() == nil {
c.call(c.api.progressHandler, uint64(c.handle), 0)
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
if c.checkInterrupt() {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter: // Waiter was cancelled.
break
case <-ctx.Done(): // Done was closed.
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
}
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
c.call(c.api.progressHandler, uint64(c.handle), 100)
return old
}
func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
func callbackProgress(ctx context.Context, mod api.Module, _ uint32) uint32 {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt != nil && c.interrupt.Err() != nil {
return 1
}
}
return 0
}
func (c *Conn) checkInterrupt() {
if c.interrupt != nil && c.interrupt.Err() != nil {
c.call(c.api.interrupt, uint64(c.handle))
}
const isInterruptedOffset = 288
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
// Pragma executes a PRAGMA statement and returns any results.

View File

@@ -31,6 +31,7 @@ import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
@@ -225,7 +226,13 @@ func (c *conn) Commit() error {
}
func (c *conn) Rollback() error {
return c.Conn.Exec(c.txRollback)
err := c.Conn.Exec(c.txRollback)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(c.txRollback)
}
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {

View File

@@ -13,6 +13,8 @@ sqlite3_finalize
sqlite3_reset
sqlite3_step
sqlite3_exec
sqlite3_interrupt
sqlite3_progress_handler_go
sqlite3_clear_bindings
sqlite3_bind_parameter_count
sqlite3_bind_parameter_index

Binary file not shown.

25
func.go
View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
@@ -47,6 +46,7 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://www.sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
@@ -70,7 +70,7 @@ type AggregateFunction interface {
// The function arguments, if any, corresponding to the row being added are passed to Step.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current value of the aggregate.
// Value is invoked to return the current (or final) value of the aggregate.
Value(ctx Context)
}
@@ -85,17 +85,6 @@ type WindowFunction interface {
Inverse(ctx Context, arg ...Value)
}
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
util.DelHandle(ctx, pApp)
}
@@ -106,20 +95,20 @@ func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK
}
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
var handle uint32
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
if err := util.DelHandle(ctx, handle); err != nil {
@@ -128,13 +117,13 @@ func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
}
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
fn.Value(Context{sqlite, pCtx})
}
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
sqlite := ctx.Value(connKey{}).(*Conn).sqlite
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
}

View File

@@ -43,7 +43,7 @@ func compileSQLite() {
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportHostFunctions(env)
env = exportCallbacks(env)
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
@@ -71,8 +71,6 @@ type sqlite struct {
stack [8]uint64
}
type sqliteKey struct{}
func instantiateSQLite() (sqlt *sqlite, err error) {
instance.once.Do(compileSQLite)
if instance.err != nil {
@@ -81,7 +79,6 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig())
@@ -123,6 +120,8 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
reset: getFun("sqlite3_reset"),
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
interrupt: getFun("sqlite3_interrupt"),
progressHandler: getFun("sqlite3_progress_handler_go"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindIndex: getFun("sqlite3_bind_parameter_index"),
@@ -342,6 +341,8 @@ type sqliteAPI struct {
reset api.Function
step api.Function
exec api.Function
interrupt api.Function
progressHandler api.Function
clearBindings api.Function
bindCount api.Function
bindIndex api.Function
@@ -402,3 +403,15 @@ type sqliteAPI struct {
resultErrorBig api.Function
destructor uint32
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress", callbackProgress)
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
util.ExportFuncVIII(env, "go_func", callbackFunc)
util.ExportFuncVIII(env, "go_step", callbackStep)
util.ExportFuncVI(env, "go_final", callbackFinal)
util.ExportFuncVI(env, "go_value", callbackValue)
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
return env
}

View File

@@ -11,6 +11,7 @@
#include "ext/uint.c"
#include "ext/uuid.c"
#include "func.c"
#include "progress.c"
#include "time.c"
__attribute__((constructor)) void init() {

9
sqlite3/progress.c Normal file
View File

@@ -0,0 +1,9 @@
#include <stddef.h>
#include "sqlite3.h"
int go_progress(void *);
void sqlite3_progress_handler_go(sqlite3 *db, int n) {
sqlite3_progress_handler(db, n, go_progress, NULL);
}

View File

@@ -23,7 +23,6 @@
#define SQLITE_MAX_EXPR_DEPTH 0
#define SQLITE_OMIT_DECLTYPE
#define SQLITE_OMIT_DEPRECATED
#define SQLITE_OMIT_PROGRESS_CALLBACK
#define SQLITE_OMIT_SHARED_CACHE
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA

View File

@@ -137,6 +137,5 @@ sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
return sqlite3_vfs_find_orig(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");

View File

@@ -182,7 +182,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer stmt.Close()
db.SetInterrupt(ctx)
cancel()
go cancel()
// Interrupting works.
err = stmt.Exec()

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da493a827d5b2985ba80d7425092a891311633d6c80c559119f87609d0f0e02a
size 508796
oid sha256:c59231ce10786b45be958027d23cffc74894a00120b30c8d3accb26f4182b29a
size 509312

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5ffb1dc0f046bb7e4ea4cd992e1d625c05c7d5cf33c78edc5a2155eb3d3c097
size 523393
oid sha256:9f715bad486eeae35ecb3cf05a2e6265fbfc24a2de0836bdc8fd760510ac1d3a
size 524127