diff --git a/func.go b/func.go index 16b4305..1cd384b 100644 --- a/func.go +++ b/func.go @@ -132,7 +132,7 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn if win, ok := agg.(WindowFunction); ok { return win } - return windowFunc{agg, name} + return agg })) } rc := res_t(c.call("sqlite3_create_window_function_go", @@ -307,13 +307,3 @@ func (a *aggregateFunc) Close() error { a.stop() return nil } - -type windowFunc struct { - AggregateFunction - name string -} - -func (w windowFunc) Inverse(ctx Context, arg ...Value) { - // Implementing inverse allows certain queries that don't really need it to succeed. - ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function")) -} diff --git a/internal/util/func.go b/internal/util/func.go index e705f31..8e89b11 100644 --- a/internal/util/func.go +++ b/internal/util/func.go @@ -20,20 +20,6 @@ func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(con Export(name) } -type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1) - -func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { - _ = stack[1] // prevent bounds check on every slice access - fn(ctx, mod, T0(stack[0]), T1(stack[1])) -} - -func ExportFuncVII[T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1)) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(funcVII[T0, T1](fn), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, nil). - Export(name) -} - type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) { diff --git a/tests/conn_test.go b/tests/conn_test.go index 53a8922..aedcda7 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -12,7 +12,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/internal/testcfg" - _ "github.com/ncruces/go-sqlite3/vfs/memdb" + "github.com/ncruces/go-sqlite3/vfs/memdb" ) func TestConn_Open_dir(t *testing.T) { @@ -112,6 +112,48 @@ func TestConn_Close_BUSY(t *testing.T) { } } +func TestConn_BusyHandler(t *testing.T) { + t.Parallel() + + dsn := memdb.TestDB(t) + + db1, err := sqlite3.Open(dsn) + if err != nil { + t.Fatal(err) + } + defer db1.Close() + + db2, err := sqlite3.Open(dsn) + if err != nil { + t.Fatal(err) + } + defer db2.Close() + + var called bool + err = db2.BusyHandler(func(ctx context.Context, count int) (retry bool) { + called = true + return count < 1 + }) + if err != nil { + t.Fatal(err) + } + + tx, err := db1.BeginExclusive() + if err != nil { + t.Fatal(err) + } + defer tx.End(&err) + + _, err = db2.BeginExclusive() + if !errors.Is(err, sqlite3.BUSY) { + t.Errorf("got %v, want sqlite3.BUSY", err) + } + + if !called { + t.Error("busy handler not called") + } +} + func TestConn_SetInterrupt(t *testing.T) { t.Parallel() diff --git a/vfs/memdb/memdb.go b/vfs/memdb/memdb.go index b990704..5eec5a5 100644 --- a/vfs/memdb/memdb.go +++ b/vfs/memdb/memdb.go @@ -76,18 +76,16 @@ func (memVFS) FullPathname(name string) (string, error) { type memDB struct { name string + // +checklocks:lockMtx + waiter *sync.Cond // +checklocks:dataMtx data []*[sectorSize]byte - // +checklocks:dataMtx - size int64 - // +checklocks:memoryMtx - refs int32 - - shared int32 // +checklocks:lockMtx - pending bool // +checklocks:lockMtx - reserved bool // +checklocks:lockMtx - waiter *sync.Cond // +checklocks:lockMtx + size int64 // +checklocks:dataMtx + refs int32 // +checklocks:memoryMtx + shared int32 // +checklocks:lockMtx + pending bool // +checklocks:lockMtx + reserved bool // +checklocks:lockMtx lockMtx sync.Mutex dataMtx sync.RWMutex @@ -129,7 +127,7 @@ func (m *memFile) ReadAt(b []byte, off int64) (n int, err error) { base := off / sectorSize rest := off % sectorSize have := int64(sectorSize) - if base == int64(len(m.data))-1 { + if m.size < off+int64(len(b)) { have = modRoundUp(m.size, sectorSize) } n = copy(b, (*m.data[base])[rest:have]) @@ -150,13 +148,13 @@ func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) { m.data = append(m.data, new([sectorSize]byte)) } n = copy((*m.data[base])[rest:], b) + if size := off + int64(n); size > m.size { + m.size = size + } if n < len(b) { // notest // assume writes are page aligned return n, io.ErrShortWrite } - if size := off + int64(len(b)); size > m.size { - m.size = size - } return n, nil }