From 14dc5cabd443fa599d3027b233bfaa367ffef502 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 18 Jan 2023 12:44:14 +0000 Subject: [PATCH] Context. --- conn.go | 29 ++++++++++++++++------------- stmt.go | 29 ++++++++++++++--------------- vfs.go | 5 ++++- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/conn.go b/conn.go index a14f4c6..d04a518 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,7 @@ import ( ) type Conn struct { + ctx context.Context handle uint32 module api.Module memory api.Memory @@ -36,8 +37,7 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) { } } - ctx := context.TODO() - + ctx := context.Background() cfg := wazero.NewModuleConfig(). WithName("sqlite3-" + strconv.FormatUint(counter.Add(1), 10)) if fs != nil { @@ -54,12 +54,13 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) { }() c := newConn(module) + c.ctx = context.WithValue(ctx, connContext{}, c) namePtr := c.newString(name) connPtr := c.new(ptrSize) defer c.free(namePtr) defer c.free(connPtr) - r, err := c.api.open.Call(ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0) + r, err := c.api.open.Call(c.ctx, uint64(namePtr), uint64(connPtr), uint64(flags), 0) if err != nil { return nil, err } @@ -73,7 +74,7 @@ func OpenFlags(name string, flags OpenFlag) (conn *Conn, err error) { } func (c *Conn) Close() error { - r, err := c.api.close.Call(context.TODO(), uint64(c.handle)) + r, err := c.api.close.Call(c.ctx, uint64(c.handle)) if err != nil { return err } @@ -81,14 +82,14 @@ func (c *Conn) Close() error { if err := c.error(r[0]); err != nil { return err } - return c.module.Close(context.TODO()) + return c.module.Close(c.ctx) } func (c *Conn) Exec(sql string) error { sqlPtr := c.newString(sql) defer c.free(sqlPtr) - r, err := c.api.exec.Call(context.TODO(), uint64(c.handle), uint64(sqlPtr), 0, 0, 0) + r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) if err != nil { return err } @@ -107,7 +108,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str defer c.free(stmtPtr) defer c.free(tailPtr) - r, err := c.api.prepare.Call(context.TODO(), uint64(c.handle), + r, err := c.api.prepare.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), uint64(stmtPtr), uint64(tailPtr)) if err != nil { @@ -141,13 +142,13 @@ func (c *Conn) error(rc uint64) error { var r []uint64 // string - r, _ = c.api.errstr.Call(context.TODO(), rc) + r, _ = c.api.errstr.Call(c.ctx, rc) if r != nil { serr.str = c.getString(uint32(r[0]), 512) } // message - r, _ = c.api.errmsg.Call(context.TODO(), uint64(c.handle)) + r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) if r != nil { serr.msg = c.getString(uint32(r[0]), 512) } @@ -164,14 +165,14 @@ func (c *Conn) free(ptr uint32) { if ptr == 0 { return } - _, err := c.api.free.Call(context.TODO(), uint64(ptr)) + _, err := c.api.free.Call(c.ctx, uint64(ptr)) if err != nil { panic(err) } } func (c *Conn) new(len uint32) uint32 { - r, err := c.api.malloc.Call(context.TODO(), uint64(len)) + r, err := c.api.malloc.Call(c.ctx, uint64(len)) if err != nil { panic(err) } @@ -190,7 +191,7 @@ func (c *Conn) newBytes(s []byte) uint32 { ptr := c.new(siz) mem, ok := c.memory.Read(ptr, siz) if !ok { - c.api.free.Call(context.TODO(), uint64(ptr)) + c.api.free.Call(c.ctx, uint64(ptr)) panic("sqlite3: out of range") } @@ -203,7 +204,7 @@ func (c *Conn) newString(s string) uint32 { ptr := c.new(siz) mem, ok := c.memory.Read(ptr, siz) if !ok { - c.api.free.Call(context.TODO(), uint64(ptr)) + c.api.free.Call(c.ctx, uint64(ptr)) panic("sqlite3: out of range") } @@ -233,4 +234,6 @@ func getString(memory api.Memory, ptr, maxlen uint32) string { } } +type connContext struct{} + const ptrSize = 4 diff --git a/stmt.go b/stmt.go index 108a957..13df004 100644 --- a/stmt.go +++ b/stmt.go @@ -1,7 +1,6 @@ package sqlite3 import ( - "context" "math" ) @@ -11,7 +10,7 @@ type Stmt struct { } func (s *Stmt) Close() error { - r, err := s.c.api.finalize.Call(context.TODO(), uint64(s.handle)) + r, err := s.c.api.finalize.Call(s.c.ctx, uint64(s.handle)) if err != nil { return err } @@ -21,7 +20,7 @@ func (s *Stmt) Close() error { } func (s *Stmt) Reset() error { - r, err := s.c.api.reset.Call(context.TODO(), uint64(s.handle)) + r, err := s.c.api.reset.Call(s.c.ctx, uint64(s.handle)) if err != nil { return err } @@ -29,7 +28,7 @@ func (s *Stmt) Reset() error { } func (s *Stmt) Step() (row bool, err error) { - r, err := s.c.api.step.Call(context.TODO(), uint64(s.handle)) + r, err := s.c.api.step.Call(s.c.ctx, uint64(s.handle)) if err != nil { return false, err } @@ -54,7 +53,7 @@ func (s *Stmt) BindInt(param int, value int) error { } func (s *Stmt) BindInt64(param int, value int64) error { - r, err := s.c.api.bindInteger.Call(context.TODO(), + r, err := s.c.api.bindInteger.Call(s.c.ctx, uint64(s.handle), uint64(param), uint64(value)) if err != nil { return err @@ -63,7 +62,7 @@ func (s *Stmt) BindInt64(param int, value int64) error { } func (s *Stmt) BindFloat(param int, value float64) error { - r, err := s.c.api.bindFloat.Call(context.TODO(), + r, err := s.c.api.bindFloat.Call(s.c.ctx, uint64(s.handle), uint64(param), math.Float64bits(value)) if err != nil { return err @@ -73,7 +72,7 @@ func (s *Stmt) BindFloat(param int, value float64) error { func (s *Stmt) BindText(param int, value string) error { ptr := s.c.newString(value) - r, err := s.c.api.bindText.Call(context.TODO(), + r, err := s.c.api.bindText.Call(s.c.ctx, uint64(s.handle), uint64(param), uint64(ptr), uint64(len(value)), s.c.api.destructor, _UTF8) @@ -85,7 +84,7 @@ func (s *Stmt) BindText(param int, value string) error { func (s *Stmt) BindBlob(param int, value []byte) error { ptr := s.c.newBytes(value) - r, err := s.c.api.bindBlob.Call(context.TODO(), + r, err := s.c.api.bindBlob.Call(s.c.ctx, uint64(s.handle), uint64(param), uint64(ptr), uint64(len(value)), s.c.api.destructor) @@ -96,7 +95,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error { } func (s *Stmt) BindNull(param int) error { - r, err := s.c.api.bindNull.Call(context.TODO(), + r, err := s.c.api.bindNull.Call(s.c.ctx, uint64(s.handle), uint64(param)) if err != nil { return err @@ -116,7 +115,7 @@ func (s *Stmt) ColumnInt(col int) int { } func (s *Stmt) ColumnInt64(col int) int64 { - r, err := s.c.api.columnInteger.Call(context.TODO(), + r, err := s.c.api.columnInteger.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) @@ -125,7 +124,7 @@ func (s *Stmt) ColumnInt64(col int) int64 { } func (s *Stmt) ColumnFloat(col int) float64 { - r, err := s.c.api.columnInteger.Call(context.TODO(), + r, err := s.c.api.columnInteger.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) @@ -134,7 +133,7 @@ func (s *Stmt) ColumnFloat(col int) float64 { } func (s *Stmt) ColumnText(col int) string { - r, err := s.c.api.columnText.Call(context.TODO(), + r, err := s.c.api.columnText.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) @@ -146,7 +145,7 @@ func (s *Stmt) ColumnText(col int) string { return "" } - r, err = s.c.api.columnBytes.Call(context.TODO(), + r, err = s.c.api.columnBytes.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) @@ -160,7 +159,7 @@ func (s *Stmt) ColumnText(col int) string { } func (s *Stmt) ColumnBlob(col int, buf []byte) int { - r, err := s.c.api.columnBlob.Call(context.TODO(), + r, err := s.c.api.columnBlob.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) @@ -172,7 +171,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) int { return 0 } - r, err = s.c.api.columnBytes.Call(context.TODO(), + r, err = s.c.api.columnBytes.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) diff --git a/vfs.go b/vfs.go index ed033eb..428f284 100644 --- a/vfs.go +++ b/vfs.go @@ -74,7 +74,10 @@ func vfsCurrentTime64(ctx context.Context, mod api.Module, vfs, out uint32) uint func vfsOpen(ctx context.Context, mod api.Module, vfs, zName, file, flags, pOutFlags uint32) uint32 { name := getString(mod.Memory(), zName, _MAX_PATHNAME) - log.Println("vfsOpen", name) + c, ok := ctx.Value(connContext{}).(*Conn) + if ok && mod == c.module { + log.Println("vfsOpen", name) + } return uint32(IOERR) }