diff --git a/internal/util/handle.go b/internal/util/handle.go new file mode 100644 index 0000000..c242a6f --- /dev/null +++ b/internal/util/handle.go @@ -0,0 +1,60 @@ +package util + +import ( + "context" + "io" +) + +type handleKey struct{} +type handleState struct { + handles []any +} + +func NewContext(ctx context.Context) (context.Context, io.Closer) { + state := new(handleState) + return context.WithValue(ctx, handleKey{}, state), state +} + +func (s *handleState) Close() (err error) { + for _, h := range s.handles { + if c, ok := h.(io.Closer); ok { + if e := c.Close(); err == nil { + err = e + } + } + } + s.handles = nil + return err +} + +func GetHandle(ctx context.Context, id uint32) any { + s := ctx.Value(handleKey{}).(*handleState) + return s.handles[id] +} + +func DelHandle(ctx context.Context, id uint32) error { + s := ctx.Value(handleKey{}).(*handleState) + a := s.handles[id] + s.handles[id] = nil + if c, ok := a.(io.Closer); ok { + return c.Close() + } + return nil + +} + +func AddHandle(ctx context.Context, a any) (id uint32) { + s := ctx.Value(handleKey{}).(*handleState) + + // Find an empty slot. + for id, h := range s.handles { + if h == nil { + s.handles[id] = a + return uint32(id) + } + } + + // Add a new slot. + s.handles = append(s.handles, a) + return uint32(len(s.handles) - 1) +} diff --git a/module.go b/module.go index fc6d027..33bb8a5 100644 --- a/module.go +++ b/module.go @@ -77,17 +77,17 @@ func compileModule() { } type module struct { - ctx context.Context - mod api.Module - vfs io.Closer - api sqliteAPI - arg [8]uint64 + ctx context.Context + mod api.Module + closer io.Closer + api sqliteAPI + stack [8]uint64 } func newModule(mod api.Module) (m *module, err error) { m = new(module) m.mod = mod - m.ctx, m.vfs = vfs.NewContext(context.Background()) + m.ctx, m.closer = util.NewContext(context.Background()) getFun := func(name string) api.Function { f := mod.ExportedFunction(name) @@ -164,7 +164,7 @@ func newModule(mod api.Module) (m *module, err error) { func (m *module) close() error { err := m.mod.Close(m.ctx) - m.vfs.Close() + m.closer.Close() return err } @@ -201,14 +201,14 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error { } func (m *module) call(fn api.Function, params ...uint64) uint64 { - copy(m.arg[:], params) - err := fn.CallWithStack(m.ctx, m.arg[:]) + copy(m.stack[:], params) + err := fn.CallWithStack(m.ctx, m.stack[:]) if err != nil { // The module closed or panicked; release resources. - m.vfs.Close() + m.closer.Close() panic(err) } - return m.arg[0] + return m.stack[0] } func (m *module) free(ptr uint32) { diff --git a/vfs/lock_test.go b/vfs/lock_test.go index 00bf000..cdbeaf4 100644 --- a/vfs/lock_test.go +++ b/vfs/lock_test.go @@ -41,8 +41,8 @@ func Test_vfsLock(t *testing.T) { pOutput = 32 ) mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize)) - ctx, vfs := NewContext(context.TODO()) - defer vfs.Close() + ctx, closer := util.NewContext(context.TODO()) + defer closer.Close() vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1}) vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2}) diff --git a/vfs/tests/mptest/mptest_test.go b/vfs/tests/mptest/mptest_test.go index fb538a8..8f6ab9e 100644 --- a/vfs/tests/mptest/mptest_test.go +++ b/vfs/tests/mptest/mptest_test.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "testing" + "github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/vfs" _ "github.com/ncruces/go-sqlite3/vfs/memdb" "github.com/tetratelabs/wazero" @@ -82,16 +83,16 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 { cfg := config(ctx).WithArgs(args...) go func() { - ctx, vfs := vfs.NewContext(ctx) + ctx, closer := util.NewContext(ctx) mod, _ := rt.InstantiateModule(ctx, module, cfg) mod.Close(ctx) - vfs.Close() + closer.Close() }() return 0 } func Test_config01(t *testing.T) { - ctx, vfs := vfs.NewContext(newContext(t)) + ctx, closer := util.NewContext(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "config01.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -99,7 +100,7 @@ func Test_config01(t *testing.T) { t.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } func Test_config02(t *testing.T) { @@ -110,7 +111,7 @@ func Test_config02(t *testing.T) { t.Skip("skipping in CI") } - ctx, vfs := vfs.NewContext(newContext(t)) + ctx, closer := util.NewContext(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "config02.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -118,7 +119,7 @@ func Test_config02(t *testing.T) { t.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } func Test_crash01(t *testing.T) { @@ -126,7 +127,7 @@ func Test_crash01(t *testing.T) { t.Skip("skipping in short mode") } - ctx, vfs := vfs.NewContext(newContext(t)) + ctx, closer := util.NewContext(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "crash01.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -134,7 +135,7 @@ func Test_crash01(t *testing.T) { t.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } func Test_multiwrite01(t *testing.T) { @@ -142,7 +143,7 @@ func Test_multiwrite01(t *testing.T) { t.Skip("skipping in short mode") } - ctx, vfs := vfs.NewContext(newContext(t)) + ctx, closer := util.NewContext(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -150,11 +151,11 @@ func Test_multiwrite01(t *testing.T) { t.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } func Test_config01_memory(t *testing.T) { - ctx, vfs := vfs.NewContext(newContext(t)) + ctx, closer := util.NewContext(newContext(t)) cfg := config(ctx).WithArgs("mptest", "test.db", "config01.test", "--vfs", "memdb", @@ -164,7 +165,7 @@ func Test_config01_memory(t *testing.T) { t.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } func Test_multiwrite01_memory(t *testing.T) { @@ -172,7 +173,7 @@ func Test_multiwrite01_memory(t *testing.T) { t.Skip("skipping in short mode") } - ctx, vfs := vfs.NewContext(newContext(t)) + ctx, closer := util.NewContext(newContext(t)) cfg := config(ctx).WithArgs("mptest", "/test.db", "multiwrite01.test", "--vfs", "memdb", @@ -182,7 +183,7 @@ func Test_multiwrite01_memory(t *testing.T) { t.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } func newContext(t *testing.T) context.Context { diff --git a/vfs/tests/speedtest1/speedtest1_test.go b/vfs/tests/speedtest1/speedtest1_test.go index 6c12348..c3e74ea 100644 --- a/vfs/tests/speedtest1/speedtest1_test.go +++ b/vfs/tests/speedtest1/speedtest1_test.go @@ -18,6 +18,7 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/vfs" _ "github.com/ncruces/go-sqlite3/vfs/memdb" ) @@ -74,7 +75,7 @@ func initFlags() { func Benchmark_speedtest1(b *testing.B) { output.Reset() - ctx, vfs := vfs.NewContext(context.Background()) + ctx, closer := util.NewContext(context.Background()) name := filepath.Join(b.TempDir(), "test.db") args := append(options, "--size", strconv.Itoa(b.N), name) cfg := wazero.NewModuleConfig(). @@ -88,5 +89,5 @@ func Benchmark_speedtest1(b *testing.B) { b.Error(err) } mod.Close(ctx) - vfs.Close() + closer.Close() } diff --git a/vfs/vfs.go b/vfs/vfs.go index e50bacf..f671487 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -44,33 +44,6 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder return env } -type vfsKey struct{} -type vfsState struct { - files []File -} - -// NewContext is an internal API users need not call directly. -// -// NewContext creates a new context to hold [api.Module] specific VFS data. -// The context should be passed to any [api.Function] calls that might -// generate VFS host callbacks. -// The returned [io.Closer] should be closed after the [api.Module] is closed, -// to release any associated resources. -func NewContext(ctx context.Context) (context.Context, io.Closer) { - vfs := new(vfsState) - return context.WithValue(ctx, vfsKey{}, vfs), vfs -} - -func (vfs *vfsState) Close() error { - for _, f := range vfs.files { - if f != nil { - f.Close() - } - } - vfs.files = nil - return nil -} - func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 { name := util.ReadString(mod, zVfsName, _MAX_STRING) if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) { @@ -431,40 +404,22 @@ func vfsGet(mod api.Module, pVfs uint32) VFS { panic(util.NoVFSErr + util.ErrorString(name)) } -func vfsFileNew(vfs *vfsState, file File) uint32 { - // Find an empty slot. - for id, f := range vfs.files { - if f == nil { - vfs.files[id] = file - return uint32(id) - } - } - - // Add a new slot. - vfs.files = append(vfs.files, file) - return uint32(len(vfs.files) - 1) -} - func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) { const fileHandleOffset = 4 - id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file) + id := util.AddHandle(ctx, file) util.WriteUint32(mod, pFile+fileHandleOffset, id) } func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File { const fileHandleOffset = 4 - vfs := ctx.Value(vfsKey{}).(*vfsState) id := util.ReadUint32(mod, pFile+fileHandleOffset) - return vfs.files[id] + return util.GetHandle(ctx, id).(File) } func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error { const fileHandleOffset = 4 - vfs := ctx.Value(vfsKey{}).(*vfsState) id := util.ReadUint32(mod, pFile+fileHandleOffset) - file := vfs.files[id] - vfs.files[id] = nil - return file.Close() + return util.DelHandle(ctx, id) } func vfsErrorCode(err error, def _ErrorCode) _ErrorCode { diff --git a/vfs/vfs_test.go b/vfs/vfs_test.go index 8562afe..ea1512d 100644 --- a/vfs/vfs_test.go +++ b/vfs/vfs_test.go @@ -220,8 +220,8 @@ func Test_vfsAccess(t *testing.T) { func Test_vfsFile(t *testing.T) { mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize)) - ctx, vfs := NewContext(context.TODO()) - defer vfs.Close() + ctx, closer := util.NewContext(context.TODO()) + defer closer.Close() // Open a temporary file. rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0) @@ -293,8 +293,8 @@ func Test_vfsFile(t *testing.T) { func Test_vfsFile_psow(t *testing.T) { mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize)) - ctx, vfs := NewContext(context.TODO()) - defer vfs.Close() + ctx, closer := util.NewContext(context.TODO()) + defer closer.Close() // Open a temporary file. rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)