Compare commits

...

18 Commits

Author SHA1 Message Date
Nuno Cruces
9f626b2f52 Fix #255. 2025-03-31 16:33:31 +01:00
Nuno Cruces
1f5d8bf7df Avoid escaping times (#256) 2025-03-31 13:02:41 +01:00
Nuno Cruces
41dc46af7e Optimize errors a bit. 2025-03-28 17:01:04 +00:00
Nuno Cruces
e5c285b783 Discussion #250. 2025-03-28 11:30:47 +00:00
Nuno Cruces
6290a14990 Fix interrupt race. 2025-03-26 19:02:14 +00:00
Nuno Cruces
948641194b Rework context cancellation. (#251) 2025-03-26 11:39:06 +00:00
Nuno Cruces
befed7cf23 binaryen-version_123. 2025-03-26 11:25:54 +00:00
Nuno Cruces
3547d9ffb0 Fix WAL flakiness on Windows (#254) 2025-03-26 10:17:19 +00:00
Nuno Cruces
a67165eb09 Fix WAL flakiness on Windows (#253) 2025-03-26 02:13:30 +00:00
Nuno Cruces
0ba393199a Mitigate #252. 2025-03-25 23:36:57 +00:00
Nuno Cruces
9e4258bc46 Better fix #249. 2025-03-25 12:45:27 +00:00
Nuno Cruces
b645721d10 IP/CIDR functions. (#246) 2025-03-24 22:38:22 +00:00
Nuno Cruces
6c296231a5 Fix #249. 2025-03-24 19:59:19 +00:00
Nuno Cruces
c067e3630b Improve SetInterrupt performance. 2025-03-24 10:41:50 +00:00
Nuno Cruces
35a2dbd847 Improve context cancellation performance. (#248) 2025-03-21 11:06:29 +00:00
Nuno Cruces
b36f73c66d Optimize. 2025-03-17 12:24:36 +00:00
Nuno Cruces
d36f19fd91 Docs. 2025-03-14 11:37:48 +00:00
Nuno Cruces
eba71b1f42 Go 1.24.1. 2025-03-14 00:54:12 +00:00
40 changed files with 771 additions and 522 deletions

View File

@@ -3,13 +3,13 @@ set -euo pipefail
if [[ "$OSTYPE" == "linux"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-linux.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_122/binaryen-version_122-x86_64-linux.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_123/binaryen-version_123-x86_64-linux.tar.gz"
elif [[ "$OSTYPE" == "darwin"* ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-arm64-macos.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_122/binaryen-version_122-arm64-macos.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_123/binaryen-version_123-arm64-macos.tar.gz"
elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then
WASI_SDK="https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-25/wasi-sdk-25.0-x86_64-windows.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_122/binaryen-version_122-x86_64-windows.tar.gz"
BINARYEN="https://github.com/WebAssembly/binaryen/releases/download/version_123/binaryen-version_123-x86_64-windows.tar.gz"
fi
# Download tools

View File

@@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
if c.interrupt.Err() != nil {
return nil, INTERRUPT
}
defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
@@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1
}
c.checkInterrupt(c.handle)
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
@@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
b.c.checkInterrupt(b.c.handle)
if b.c.interrupt.Err() != nil {
return INTERRUPT
}
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.offset = 0

View File

@@ -275,6 +275,10 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
//
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
if c.interrupt.Err() != nil {
return 0, 0, INTERRUPT
}
defer c.arena.mark()()
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
@@ -388,6 +392,6 @@ func (c *Conn) EnableChecksums(schema string) error {
}
// Checkpoint the WAL.
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_RESTART)
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_FULL)
return err
}

74
conn.go
View File

@@ -25,7 +25,6 @@ type Conn struct {
*sqlite
interrupt context.Context
pending *Stmt
stmts []*Stmt
busy func(context.Context, int) bool
log func(xErrorCode, string)
@@ -41,6 +40,7 @@ type Conn struct {
busylst time.Time
arena arena
handle ptr_t
gosched uint8
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
@@ -120,7 +120,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
return 0, err
}
c.call("sqlite3_progress_handler_go", stk_t(handle), 100)
c.call("sqlite3_progress_handler_go", stk_t(handle), 1000)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
@@ -132,7 +132,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
}
}
if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String())
rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0))
if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil {
@@ -166,9 +165,6 @@ func (c *Conn) Close() error {
return nil
}
c.pending.Close()
c.pending = nil
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
return err
@@ -183,11 +179,16 @@ func (c *Conn) Close() error {
//
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
defer c.arena.mark()()
sqlPtr := c.arena.string(sql)
if c.interrupt.Err() != nil {
return INTERRUPT
}
return c.exec(sql)
}
c.checkInterrupt(c.handle)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0))
func (c *Conn) exec(sql string) error {
defer c.arena.mark()()
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
return c.error(rc, sql)
}
@@ -206,20 +207,22 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_SQL_LENGTH {
return nil, "", TOOBIG
}
if c.interrupt.Err() != nil {
return nil, "", INTERRUPT
}
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
textPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
stmt = &Stmt{c: c}
stmt = &Stmt{c: c, sql: sql}
stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr)
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" {
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" {
tail = sql
}
@@ -340,43 +343,17 @@ func (c *Conn) GetInterrupt() context.Context {
//
// https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
if ctx == nil {
panic("nil Context")
}
old = c.interrupt
c.interrupt = ctx
if ctx == old || ctx.Done() == old.Done() {
return old
}
// A busy SQL statement prevents SQLite from ignoring an interrupt
// that comes before any other statements are started.
if c.pending == nil {
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64,
stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0)
c.pending = &Stmt{c: c}
c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr)
}
if old.Done() != nil && ctx.Err() == nil {
c.pending.Reset()
}
if ctx.Done() != nil {
c.pending.Step()
}
return old
}
func (c *Conn) checkInterrupt(handle ptr_t) {
if c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", stk_t(handle))
}
}
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt.Done() != nil {
if c.gosched++; c.gosched%16 == 0 {
runtime.Gosched()
}
if c.interrupt.Err() != nil {
@@ -432,11 +409,8 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool))
func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
interrupt := c.interrupt
if interrupt == nil {
interrupt = context.Background()
}
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
if interrupt := c.interrupt; interrupt.Err() == nil &&
c.busy(interrupt, int(count)) {
retry = 1
}
}

View File

@@ -89,20 +89,26 @@ func (ctx Context) ResultText(value string) {
}
// ResultRawText sets the text result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultRawText(value []byte) {
if len(value) == 0 {
ctx.ResultText("")
return
}
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultBlob(value []byte) {
if len(value) == 0 {
ctx.ResultZeroBlob(0)
return
}
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_blob_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))

View File

@@ -20,22 +20,45 @@
// - a [serializable] transaction is always "immediate";
// - a [read-only] transaction is always "deferred".
//
// # Datatypes In SQLite
//
// SQLite is dynamically typed.
// Columns can mostly hold any value regardless of their declared type.
// SQLite supports most [driver.Value] types out of the box,
// but bool and [time.Time] require special care.
//
// Booleans can be stored on any column type and scanned back to a *bool.
// However, if scanned to a *any, booleans may either become an
// int64, string or bool, depending on the declared type of the column.
// If you use BOOLEAN for your column type,
// 1 and 0 will always scan as true and false.
//
// # Working with time
//
// Time values can similarly be stored on any column type.
// The time encoding/decoding format can be specified using "_timefmt":
//
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
//
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// Special values are: "auto" (the default), "sqlite", "rfc3339";
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// - "rfc3339" encodes and decodes RFC 3339 only.
//
// If you encode as RFC 3339 (the default),
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
//
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
// If you encode as RFC 3339 (the default),
// consider using the TIME [collating sequence] to produce time-ordered sequences.
//
// If you encode as RFC 3339 (the default),
// time values will scan back to a *time.Time unless your column type is TEXT.
// Otherwise, if scanned to a *any, time values may either become an
// int64, float64 or string, depending on the time format and declared type of the column.
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
// "_timefmt" will be used to decode values.
//
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
//
// When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner].
@@ -48,7 +71,7 @@
// The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, or nil,
// depending on the column type and what whoever wrote the value.
// depending on the column type and whoever wrote the value.
// [sqlite3.TimeFormat.Decode] may help.
//
// # Setting PRAGMAs
@@ -358,13 +381,10 @@ func (c *conn) Commit() error {
}
func (c *conn) Rollback() error {
err := c.Conn.Exec(`ROLLBACK` + c.txReset)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(`ROLLBACK` + c.txReset)
}
return err
// ROLLBACK even if interrupted.
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
return c.Conn.Exec(`ROLLBACK` + c.txReset)
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
@@ -598,6 +618,28 @@ const (
_TIME
)
func scanFromDecl(decl string) scantype {
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch decl {
case "INT", "INTEGER":
return _INT
case "REAL":
return _REAL
case "TEXT":
return _TEXT
case "BLOB":
return _BLOB
case "BOOLEAN":
return _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
return _TIME
}
return _ANY
}
var (
// Ensure these interfaces are implemented:
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@@ -622,6 +664,18 @@ func (r *rows) Columns() []string {
return r.names
}
func (r *rows) scanType(index int) scantype {
if r.scans == nil {
count := r.Stmt.ColumnCount()
scans := make([]scantype, count)
for i := range scans {
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
}
r.scans = scans
}
return r.scans[index]
}
func (r *rows) loadColumnMetadata() {
if r.nulls == nil {
count := r.Stmt.ColumnCount()
@@ -635,24 +689,7 @@ func (r *rows) loadColumnMetadata() {
r.Stmt.ColumnTableName(i),
col)
types[i] = strings.ToUpper(types[i])
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch types[i] {
case "INT", "INTEGER":
scans[i] = _INT
case "REAL":
scans[i] = _REAL
case "TEXT":
scans[i] = _TEXT
case "BLOB":
scans[i] = _BLOB
case "BOOLEAN":
scans[i] = _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
scans[i] = _TIME
}
scans[i] = scanFromDecl(types[i])
}
}
r.nulls = nulls
@@ -661,27 +698,15 @@ func (r *rows) loadColumnMetadata() {
}
}
func (r *rows) declType(index int) string {
if r.types == nil {
count := r.Stmt.ColumnCount()
types := make([]string, count)
for i := range types {
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
}
r.types = types
}
return r.types[index]
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
r.loadColumnMetadata()
decltype := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
decltype = decltype[:i]
decl := r.types[index]
if len := len(decl); len > 0 && decl[len-1] == ')' {
if i := strings.LastIndexByte(decl, '('); i >= 0 {
decl = decl[:i]
}
}
return strings.TrimSpace(decltype)
return strings.TrimSpace(decl)
}
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
@@ -748,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error {
}
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
err := r.Stmt.Columns(data...)
if err := r.Stmt.ColumnsRaw(data...); err != nil {
return err
}
for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
}
}
return err
}
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
switch v := v.(type) {
case int64, float64:
// could be a time value
case string:
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
scan := r.scanType(i)
switch v := dest[i].(type) {
case int64:
if scan == _BOOL {
switch v {
case 1:
dest[i] = true
case 0:
dest[i] = false
}
continue
}
case []byte:
if len(v) == cap(v) { // a BLOB
continue
}
if scan != _TEXT {
switch r.tmWrite {
case "", time.RFC3339, time.RFC3339Nano:
t, ok := maybeTime(v)
if ok {
dest[i] = t
continue
}
}
}
dest[i] = string(v)
case float64:
break
default:
continue
}
t, ok := maybeTime(v)
if ok {
return t, true
if scan == _TIME {
t, err := r.tmRead.Decode(dest[i])
if err == nil {
dest[i] = t
continue
}
}
default:
return
}
switch r.declType(i) {
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// could be a time value
default:
return
}
t, err := r.tmRead.Decode(v)
return t, err == nil
return nil
}

View File

@@ -199,6 +199,62 @@ func Test_BeginTx(t *testing.T) {
}
}
func Test_nested_context(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
outer, err := tx.Query(`SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer outer.Close()
want := func(rows *sql.Rows, want int) {
t.Helper()
var got int
rows.Next()
if err := rows.Scan(&got); err != nil {
t.Fatal(err)
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
want(outer, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer inner.Close()
want(inner, 0)
cancel()
if inner.Next() || !errors.Is(inner.Err(), sqlite3.INTERRUPT) {
t.Fatal(inner.Err())
}
want(outer, 1)
}
func Test_Prepare(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
@@ -467,3 +523,29 @@ func Test_ColumnType_ScanType(t *testing.T) {
t.Fatal(err)
}
}
func Benchmark_loop(b *testing.B) {
db, err := Open(":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()
var version string
err = db.QueryRow(`SELECT sqlite_version();`).Scan(&version)
if err != nil {
b.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
b.Cleanup(cancel)
b.ResetTimer()
for range b.N {
_, err := db.ExecContext(ctx,
`WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x < 1000000) SELECT x FROM c;`)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -1,9 +1,5 @@
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
package driver_test
// Adapted from: https://go.dev/doc/tutorial/database-access
import (
"database/sql"
"database/sql/driver"
@@ -27,7 +23,7 @@ func Example_customTime() {
_, err = db.Exec(`
CREATE TABLE data (
id INTEGER PRIMARY KEY,
date_time TEXT
date_time ANY
) STRICT;
`)
if err != nil {

View File

@@ -1,12 +1,15 @@
package driver
import "time"
import (
"bytes"
"time"
)
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) (_ time.Time, _ bool) {
func maybeTime(text []byte) (_ time.Time, _ bool) {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
@@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) {
// Slow path.
var buf [len(time.RFC3339Nano)]byte
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
return date, true
}
return

View File

@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
v, ok := maybeTime(str)
v, ok := maybeTime([]byte(str))
if ok {
// Make sure times round-trip to the same string:
// https://pkg.go.dev/database/sql#Rows.Scan
@@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
checkTime := func(t testing.TB, date time.Time) {
v, ok := maybeTime(date.Format(time.RFC3339Nano))
v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano))
if ok {
// Make sure times round-trip to the same time:
if !v.Equal(date) {

View File

@@ -2,7 +2,6 @@ package sqlite3
import (
"errors"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -12,7 +11,6 @@ import (
//
// https://sqlite.org/c3ref/errcode.html
type Error struct {
str string
msg string
sql string
code res_t
@@ -29,19 +27,13 @@ func (e *Error) Code() ErrorCode {
//
// https://sqlite.org/rescode.html
func (e *Error) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e.code)
return xErrorCode(e.code)
}
// Error implements the error interface.
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("sqlite3: ")
if e.str != "" {
b.WriteString(e.str)
} else {
b.WriteString(strconv.Itoa(int(e.code)))
}
b.WriteString(util.ErrorCodeString(uint32(e.code)))
if e.msg != "" {
b.WriteString(": ")
@@ -103,12 +95,12 @@ func (e ErrorCode) Error() string {
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY
return e == BUSY || e == INTERRUPT
}
// ExtendedCode returns the extended error code for this error.
func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e)
return xErrorCode(e)
}
// Error implements the error interface.
@@ -133,7 +125,7 @@ func (e ExtendedErrorCode) As(err any) bool {
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY
return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
}
// Timeout returns true for [BUSY_TIMEOUT] errors.

View File

@@ -43,7 +43,7 @@ func TestError(t *testing.T) {
if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
if s := err.Error(); s != "sqlite3: unknown error" {
t.Errorf("got %q", s)
}
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
@@ -83,7 +83,7 @@ func TestError_Temporary(t *testing.T) {
}
}
{
err := ExtendedErrorCode(tt.code)
err := xErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
@@ -115,7 +115,7 @@ func TestError_Timeout(t *testing.T) {
}
}
{
err := ExtendedErrorCode(tt.code)
err := xErrorCode(tt.code)
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
@@ -156,12 +156,12 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
defer db.Close()
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
for i := 0; i == int(xErrorCode(i)); i++ {
want := "sqlite3: "
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME)
got := ExtendedErrorCode(i).Error()
got := xErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}

113
ext/ipaddr/ipaddr.go Normal file
View File

@@ -0,0 +1,113 @@
// Package ipaddr provides functions to manipulate IPs and CIDRs.
//
// It provides the following functions:
// - ipcontains(prefix, ip)
// - ipoverlaps(prefix1, prefix2)
// - ipfamily(ip/prefix)
// - iphost(ip/prefix)
// - ipmasklen(prefix)
// - ipnetwork(prefix)
package ipaddr
import (
"errors"
"net/netip"
"github.com/ncruces/go-sqlite3"
)
// Register IP/CIDR functions for a database connection.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("ipcontains", 2, flags, contains),
db.CreateFunction("ipoverlaps", 2, flags, overlaps),
db.CreateFunction("ipfamily", 1, flags, family),
db.CreateFunction("iphost", 1, flags, host),
db.CreateFunction("ipmasklen", 1, flags, masklen),
db.CreateFunction("ipnetwork", 1, flags, network))
}
func contains(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
addr, err := netip.ParseAddr(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultBool(prefix.Contains(addr))
}
func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix1, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
prefix2, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultBool(prefix1.Overlaps(prefix2))
}
func family(ctx sqlite3.Context, arg ...sqlite3.Value) {
addr, err := addr(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
switch {
case addr.Is4():
ctx.ResultInt(4)
case addr.Is6():
ctx.ResultInt(6)
}
}
func host(ctx sqlite3.Context, arg ...sqlite3.Value) {
addr, err := addr(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
buf, _ := addr.MarshalText()
ctx.ResultRawText(buf)
}
func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultInt(prefix.Bits())
}
func network(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
buf, _ := prefix.Masked().MarshalText()
ctx.ResultRawText(buf)
}
func addr(text string) (netip.Addr, error) {
addr, err := netip.ParseAddr(text)
if err != nil {
if prefix, err := netip.ParsePrefix(text); err == nil {
return prefix.Addr(), nil
}
if addrpt, err := netip.ParseAddrPort(text); err == nil {
return addrpt.Addr(), nil
}
}
return addr, err
}

88
ext/ipaddr/ipaddr_test.go Normal file
View File

@@ -0,0 +1,88 @@
package ipaddr_test
import (
"testing"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/ipaddr"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
"github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestRegister(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := driver.Open(tmp, ipaddr.Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var got string
err = db.QueryRow(`SELECT ipfamily('::1')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "6" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipfamily('[::1]:80')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "6" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipfamily('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "4" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT iphost('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "192.168.1.5" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipmasklen('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "24" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipnetwork('192.168.1.5/24')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "192.168.1.0/24" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipcontains('192.168.1.0/24', '192.168.1.5')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "1" {
t.Fatalf("got %s", got)
}
err = db.QueryRow(`SELECT ipoverlaps('192.168.1.0/24', '192.168.1.5/32')`).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != "1" {
t.Fatalf("got %s", got)
}
}

View File

@@ -1,21 +1,22 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions,
// - LIKE and REGEXP operators,
// - collation sequences.
// - upper() and lower() functions
// - LIKE and REGEXP operators
// - collation sequences
//
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regexp/syntax];
// - collation sequences use [collate].
// Like PostgreSQL, it also provides:
// - initcap()
// - casefold()
// - normalize()
// - unaccent()
//
// It also provides (approximately) from PostgreSQL:
// - casefold(),
// - initcap(),
// - normalize(),
// - unaccent().
// The implementations are not 100% compatible:
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]
// - normalize(), unaccent() use [transform] and [unicode.Mn]
// - the LIKE operator follows [strings.EqualFold] rules
// - the REGEXP operator uses Go [regexp/syntax]
// - collation sequences use [collate]
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
@@ -27,6 +28,7 @@ import (
"errors"
"regexp"
"strings"
"sync"
"unicode"
"unicode/utf8"
@@ -159,8 +161,16 @@ func casefold(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultRawText(cases.Fold().Bytes(arg[0].RawText()))
}
var unaccentPool = sync.Pool{
New: func() any {
return transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
},
}
func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
unaccent := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
unaccent := unaccentPool.Get().(transform.Transformer)
defer unaccentPool.Put(unaccent)
res, _, err := transform.Bytes(unaccent, arg[0].RawText())
if err != nil {
ctx.ResultError(err) // notest

View File

@@ -17,17 +17,18 @@ import (
// Register registers the SQL functions:
//
// uuid([version], [domain/namespace], [id/data])
//
// Generates a UUID as a string.
//
// uuid_str(u)
//
// Converts a UUID into a well-formed UUID string.
//
// uuid_blob(u)
//
// Converts a UUID into a 16-byte blob.
// - uuid([ version [, domain/namespace, [ id/data ]]]):
// to generate a UUID as a string
// - uuid_str(u):
// to convert a UUID into a well-formed UUID string
// - uuid_blob(u):
// to convert a UUID into a 16-byte blob
// - uuid_extract_version(u):
// to extract the version of a RFC 4122 UUID
// - uuid_extract_timestamp(u):
// to extract the timestamp of a version 1/2/6/7 UUID
// - gen_random_uuid(u):
// to generate a version 4 (random) UUID
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
@@ -38,7 +39,8 @@ func Register(db *sqlite3.Conn) error {
db.CreateFunction("uuid_str", 1, flags, toString),
db.CreateFunction("uuid_blob", 1, flags, toBlob),
db.CreateFunction("uuid_extract_version", 1, flags, version),
db.CreateFunction("uuid_extract_timestamp", 1, flags, timestamp))
db.CreateFunction("uuid_extract_timestamp", 1, flags, timestamp),
db.CreateFunction("gen_random_uuid", 0, sqlite3.INNOCUOUS, generate))
}
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -284,10 +284,10 @@ func returnArgs(p *[]Value) {
}
type aggregateFunc struct {
ctx Context
arg []Value
next func() (struct{}, bool)
stop func()
ctx Context
arg []Value
}
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {

View File

@@ -75,7 +75,7 @@ func ErrorCodeString(rc uint32) string {
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
case EMPTY:
break
case SCHEMA:
return "sqlite3: database schema has changed"
@@ -91,7 +91,7 @@ func ErrorCodeString(rc uint32) string {
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
case FORMAT:
break
case RANGE:
return "sqlite3: column index out of range"

View File

@@ -135,11 +135,10 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
if i := bytes.IndexByte(buf, 0); i >= 0 {
return string(buf[:i])
}
panic(NoNulErr)
}
func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) {

View File

@@ -120,33 +120,33 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
panic(util.OOMErr)
}
if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 {
err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME)
}
if handle != 0 {
var msg, query string
if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
switch {
case msg == "not an error":
msg = ""
case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
msg = ""
}
}
if len(sql) != 0 {
if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
err.sql = sql[0][i:]
query = sql[0][i:]
}
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
if msg != "" || query != "" {
return &Error{code: rc, msg: msg, sql: query}
}
}
return &err
return xErrorCode(rc)
}
func (sqlt *sqlite) getfn(name string) api.Function {
@@ -212,14 +212,10 @@ func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
}
func (sqlt *sqlite) newBytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
if len(b) == 0 {
return 0
}
size := len(b)
if size == 0 {
size = 1
}
ptr := sqlt.new(int64(size))
ptr := sqlt.new(int64(len(b)))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
@@ -288,7 +284,7 @@ func (a *arena) new(size int64) ptr_t {
}
func (a *arena) bytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
if len(b) == 0 {
return 0
}
ptr := a.new(int64(len(b)))

View File

@@ -16,6 +16,9 @@
// #define SQLITE_OMIT_DECLTYPE
// #define SQLITE_OMIT_PROGRESS_CALLBACK
// TODO add this:
// #define SQLITE_ENABLE_API_ARMOR
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY

View File

@@ -125,11 +125,6 @@ func Test_sqlite_newBytes(t *testing.T) {
if got := util.View(sqlite.mod, ptr, int64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
ptr = sqlite.newBytes(buf[:0])
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
}
func Test_sqlite_newString(t *testing.T) {

140
stmt.go
View File

@@ -106,7 +106,14 @@ func (s *Stmt) Busy() bool {
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt(s.c.handle)
if s.c.interrupt.Err() != nil {
s.err = INTERRUPT
return false
}
return s.step()
}
func (s *Stmt) step() bool {
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:
@@ -131,7 +138,11 @@ func (s *Stmt) Err() error {
// Exec is a convenience function that repeatedly calls [Stmt.Step] until it returns false,
// then calls [Stmt.Reset] to reset the statement and get any error that occurred.
func (s *Stmt) Exec() error {
for s.Step() {
if s.c.interrupt.Err() != nil {
return INTERRUPT
}
// TODO: implement this in C.
for s.step() {
}
return s.Reset()
}
@@ -254,13 +265,15 @@ func (s *Stmt) BindText(param int, value string) error {
// BindRawText binds a []byte to the prepared statement as text.
// The leftmost SQL parameter has an index of 1.
// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindRawText(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
if len(value) == 0 {
return s.BindText(param, "")
}
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
@@ -270,13 +283,15 @@ func (s *Stmt) BindRawText(param int, value []byte) error {
// BindBlob binds a []byte to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
if len(value) == 0 {
return s.BindZeroBlob(param, 0)
}
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_blob_go",
stk_t(s.handle), stk_t(param),
@@ -560,7 +575,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_text",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
return s.columnRawBytes(col, ptr, 1)
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -572,10 +587,10 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_blob",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
return s.columnRawBytes(col, ptr, 0)
}
func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte {
if ptr == 0 {
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
if rc != _ROW && rc != _DONE {
@@ -586,7 +601,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
n := int32(s.c.call("sqlite3_column_bytes",
stk_t(s.handle), stk_t(col)))
return util.View(s.c.mod, ptr, int64(n))
return util.View(s.c.mod, ptr, int64(n+nul))[:n]
}
// ColumnJSON parses the JSON-encoded value of the result column
@@ -633,22 +648,12 @@ func (s *Stmt) ColumnValue(col int) Value {
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] as string, and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) Columns(dest ...any) error {
defer s.c.arena.mark()()
count := int64(len(dest))
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if err := s.c.error(rc); err != nil {
types, ptr, err := s.columns(int64(len(dest)))
if err != nil {
return err
}
types := util.View(s.c.mod, typePtr, count)
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
@@ -657,26 +662,95 @@ func (s *Stmt) Columns(dest ...any) error {
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = util.Read64[int64](s.c.mod, dataPtr)
dest[i] = util.Read64[int64](s.c.mod, ptr)
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
dest[i] = util.ReadFloat64(s.c.mod, ptr)
case byte(NULL):
dest[i] = nil
default:
ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0)
if ptr == 0 {
dest[i] = []byte{}
continue
}
len := util.Read32[int32](s.c.mod, dataPtr+4)
buf := util.View(s.c.mod, ptr, int64(len))
if types[i] == byte(TEXT) {
case byte(TEXT):
len := util.Read32[int32](s.c.mod, ptr+4)
if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(len))
dest[i] = string(buf)
} else {
dest[i] = buf
dest[i] = ""
}
case byte(BLOB):
len := util.Read32[int32](s.c.mod, ptr+4)
if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(len))
tmp, _ := dest[i].([]byte)
dest[i] = append(tmp[:0], buf...)
} else {
dest[i], _ = dest[i].([]byte)
}
}
dataPtr += 8
ptr += 8
}
return nil
}
// ColumnsRaw populates result columns into the provided slice.
// The slice must have [Stmt.ColumnCount] length.
//
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) ColumnsRaw(dest ...any) error {
types, ptr, err := s.columns(int64(len(dest)))
if err != nil {
return err
}
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
}
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = util.Read64[int64](s.c.mod, ptr)
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, ptr)
case byte(NULL):
dest[i] = nil
default:
len := util.Read32[int32](s.c.mod, ptr+4)
if len == 0 && types[i] == byte(BLOB) {
dest[i] = []byte{}
} else {
cap := len
if types[i] == byte(TEXT) {
cap++
}
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(cap))[:len]
dest[i] = buf
}
}
ptr += 8
}
return nil
}
func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) {
defer s.c.arena.mark()()
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if rc == res_t(MISUSE) {
return nil, 0, MISUSE
}
if err := s.c.error(rc); err != nil {
return nil, 0, err
}
return util.View(s.c.mod, typePtr, count), dataPtr, nil
}

View File

@@ -146,10 +146,7 @@ func TestConn_SetInterrupt(t *testing.T) {
}
defer stmt.Close()
go func() {
time.Sleep(time.Millisecond)
cancel()
}()
time.AfterFunc(time.Millisecond, cancel)
// Interrupting works.
err = stmt.Exec()

View File

@@ -89,6 +89,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindRawText(1, nil); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("")); err != nil {
t.Fatal(err)
}
@@ -103,13 +110,6 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindBlob(1, nil); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindZeroBlob(1, 4); err != nil {
t.Fatal(err)
}
@@ -353,6 +353,31 @@ func TestStmt(t *testing.T) {
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
@@ -403,33 +428,6 @@ func TestStmt(t *testing.T) {
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
@@ -664,6 +662,42 @@ func TestStmt_ColumnValue(t *testing.T) {
}
}
func TestStmt_Columns(t *testing.T) {
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT 0, 0.5, 'abc', x'cafe', NULL`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
var dest [5]any
if err := stmt.Columns(dest[:]...); err != nil {
t.Fatal(err)
}
if got := dest[0]; got != int64(0) {
t.Errorf("got %d, want 0", got)
}
if got := dest[1]; got != float64(0.5) {
t.Errorf("got %f, want 0.5", got)
}
if got := dest[2]; got != "abc" {
t.Errorf("got %q, want 'abc'", got)
}
if got := dest[3]; string(got.([]byte)) != "\xCA\xFE" {
t.Errorf("got %q, want x'cafe'", got)
}
if got := dest[4]; got != nil {
t.Errorf("got %q, want nil", got)
}
}
}
func TestStmt_Error(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")

24
txn.go
View File

@@ -2,7 +2,6 @@ package sqlite3
import (
"context"
"errors"
"math/rand"
"runtime"
"strconv"
@@ -21,11 +20,13 @@ type Txn struct {
}
// Begin starts a deferred transaction.
// It panics if a transaction is in-progress.
// For nested transactions, use [Conn.Savepoint].
//
// https://sqlite.org/lang_transaction.html
func (c *Conn) Begin() Txn {
// BEGIN even if interrupted.
err := c.txnExecInterrupted(`BEGIN DEFERRED`)
err := c.exec(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
@@ -120,7 +121,8 @@ func (tx Txn) Commit() error {
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) Rollback() error {
return tx.c.txnExecInterrupted(`ROLLBACK`)
// ROLLBACK even if interrupted.
return tx.c.exec(`ROLLBACK`)
}
// Savepoint is a marker within a transaction
@@ -143,7 +145,7 @@ func (c *Conn) Savepoint() Savepoint {
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.txnExecInterrupted(`SAVEPOINT ` + name)
err := c.exec(`SAVEPOINT ` + name)
if err != nil {
panic(err)
}
@@ -199,7 +201,7 @@ func (s Savepoint) Release(errp *error) {
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
if err != nil {
panic(err)
}
@@ -212,17 +214,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name)
}
func (c *Conn) txnExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
return s.c.exec(`ROLLBACK TO ` + s.name)
}
// TxnState determines the transaction state of a database.

View File

@@ -1,16 +0,0 @@
//go:build !windows
package osutil
import (
"io/fs"
"os"
)
// OpenFile behaves the same as [os.OpenFile],
// except on Windows it sets [syscall.FILE_SHARE_DELETE].
//
// See: https://go.dev/issue/32088#issuecomment-502850674
func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}

View File

@@ -1,115 +0,0 @@
package osutil
import (
"io/fs"
"os"
. "syscall"
"unsafe"
)
// OpenFile behaves the same as [os.OpenFile],
// except on Windows it sets [syscall.FILE_SHARE_DELETE].
//
// See: https://go.dev/issue/32088#issuecomment-502850674
func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT}
}
r, e := syscallOpen(name, flag|O_CLOEXEC, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
// syscallOpen is a copy of [syscall.Open]
// that uses [syscall.FILE_SHARE_DELETE].
//
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd Handle, err error) {
if len(path) == 0 {
return InvalidHandle, ERROR_FILE_NOT_FOUND
}
pathp, err := UTF16PtrFromString(path)
if err != nil {
return InvalidHandle, err
}
var access uint32
switch mode & (O_RDONLY | O_WRONLY | O_RDWR) {
case O_RDONLY:
access = GENERIC_READ
case O_WRONLY:
access = GENERIC_WRITE
case O_RDWR:
access = GENERIC_READ | GENERIC_WRITE
}
if mode&O_CREAT != 0 {
access |= GENERIC_WRITE
}
if mode&O_APPEND != 0 {
access &^= GENERIC_WRITE
access |= FILE_APPEND_DATA
}
sharemode := uint32(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE)
var sa *SecurityAttributes
if mode&O_CLOEXEC == 0 {
sa = makeInheritSa()
}
var createmode uint32
switch {
case mode&(O_CREAT|O_EXCL) == (O_CREAT | O_EXCL):
createmode = CREATE_NEW
case mode&(O_CREAT|O_TRUNC) == (O_CREAT | O_TRUNC):
createmode = CREATE_ALWAYS
case mode&O_CREAT == O_CREAT:
createmode = OPEN_ALWAYS
case mode&O_TRUNC == O_TRUNC:
createmode = TRUNCATE_EXISTING
default:
createmode = OPEN_EXISTING
}
var attrs uint32 = FILE_ATTRIBUTE_NORMAL
if perm&S_IWRITE == 0 {
attrs = FILE_ATTRIBUTE_READONLY
if createmode == CREATE_ALWAYS {
const _ERROR_BAD_NETPATH = Errno(53)
// We have been asked to create a read-only file.
// If the file already exists, the semantics of
// the Unix open system call is to preserve the
// existing permissions. If we pass CREATE_ALWAYS
// and FILE_ATTRIBUTE_READONLY to CreateFile,
// and the file already exists, CreateFile will
// change the file permissions.
// Avoid that to preserve the Unix semantics.
h, e := CreateFile(pathp, access, sharemode, sa, TRUNCATE_EXISTING, FILE_ATTRIBUTE_NORMAL, 0)
switch e {
case ERROR_FILE_NOT_FOUND, _ERROR_BAD_NETPATH, ERROR_PATH_NOT_FOUND:
// File does not exist. These are the same
// errors as Errno.Is checks for ErrNotExist.
// Carry on to create the file.
default:
// Success or some different error.
return h, e
}
}
}
if createmode == OPEN_EXISTING && access == GENERIC_READ {
// Necessary for opening directory handles.
attrs |= FILE_FLAG_BACKUP_SEMANTICS
}
if mode&O_SYNC != 0 {
const _FILE_FLAG_WRITE_THROUGH = 0x80000000
attrs |= _FILE_FLAG_WRITE_THROUGH
}
if mode&O_NONBLOCK != 0 {
attrs |= FILE_FLAG_OVERLAPPED
}
return CreateFile(pathp, access, sharemode, sa, createmode, attrs, 0)
}
func makeInheritSa() *SecurityAttributes {
var sa SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
return &sa
}

View File

@@ -1,3 +1,4 @@
// Package osutil implements operating system utilities.
package osutil
import (
@@ -19,7 +20,7 @@ type FS struct{}
// Open implements [fs.FS].
func (FS) Open(name string) (fs.File, error) {
return OpenFile(name, os.O_RDONLY, 0)
return os.OpenFile(name, os.O_RDONLY, 0)
}
// ReadFileFS implements [fs.StatFS].
@@ -31,3 +32,10 @@ func (FS) Stat(name string) (fs.FileInfo, error) {
func (FS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}
// OpenFile behaves the same as [os.OpenFile].
//
// Deprecated: use os.OpenFile instead.
func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}

View File

@@ -1,2 +0,0 @@
// Package osutil implements operating system utilities.
package osutil

View File

@@ -5,5 +5,5 @@ package sql3util
//
// https://sqlite.org/fileformat.html#pages
func ValidPageSize(s int) bool {
return 512 <= s && s <= 65536 && s&(s-1) == 0
return s&(s-1) == 0 && 512 <= s && s <= 65536
}

View File

@@ -139,7 +139,7 @@ func (v Value) Blob(buf []byte) []byte {
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
return v.rawBytes(ptr)
return v.rawBytes(ptr, 1)
}
// RawBlob returns the value as a []byte.
@@ -149,16 +149,16 @@ func (v Value) RawText() []byte {
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
return v.rawBytes(ptr)
return v.rawBytes(ptr, 0)
}
func (v Value) rawBytes(ptr ptr_t) []byte {
func (v Value) rawBytes(ptr ptr_t, nul int32) []byte {
if ptr == 0 {
return nil
}
n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
return util.View(v.c.mod, ptr, int64(n))
return util.View(v.c.mod, ptr, int64(n+nul))[:n]
}
// Pointer gets the pointer associated with this value,

View File

@@ -49,9 +49,7 @@ func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
n, err = c.File.ReadAt(p, off)
p = p[:n]
// SQLite is reading the header of a database file.
if c.isDB && off == 0 && len(p) >= 100 &&
bytes.HasPrefix(p, []byte("SQLite format 3\000")) {
if isHeader(c.isDB, p, off) {
c.init((*[100]byte)(p))
}
@@ -67,9 +65,7 @@ func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
}
func (c cksmFile) WriteAt(p []byte, off int64) (n int, err error) {
// SQLite is writing the first page of a database file.
if c.isDB && off == 0 && len(p) >= 100 &&
bytes.HasPrefix(p, []byte("SQLite format 3\000")) {
if isHeader(c.isDB, p, off) {
c.init((*[100]byte)(p))
}
@@ -116,9 +112,11 @@ func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpco
c.inCkpt = true
case _FCNTL_CKPT_DONE:
c.inCkpt = false
}
if rc := vfsFileControlImpl(ctx, mod, c, op, pArg); rc != _NOTFOUND {
return rc
case _FCNTL_PRAGMA:
rc := vfsFileControlImpl(ctx, mod, c, op, pArg)
if rc != _NOTFOUND {
return rc
}
}
return vfsFileControlImpl(ctx, mod, c.File, op, pArg)
}
@@ -135,6 +133,14 @@ func (f *cksmFlags) init(header *[100]byte) {
}
}
func isHeader(isDB bool, p []byte, off int64) bool {
check := sql3util.ValidPageSize(len(p))
if isDB {
check = off == 0 && len(p) >= 100
}
return check && bytes.HasPrefix(p, []byte("SQLite format 3\000"))
}
func cksmCompute(a []byte) (cksm [8]byte) {
var s1, s2 uint32
for len(a) >= 8 {

View File

@@ -6,9 +6,8 @@ import (
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
"github.com/ncruces/go-sqlite3/util/osutil"
)
type vfsOS struct{}
@@ -40,7 +39,7 @@ func (vfsOS) Delete(path string, syncDir bool) error {
if err != nil {
return err
}
if canSyncDirs && syncDir {
if isUnix && syncDir {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
@@ -96,7 +95,7 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error
if name == nil {
f, err = os.CreateTemp(os.Getenv("SQLITE_TMPDIR"), "*.db")
} else {
f, err = osutil.OpenFile(name.String(), oflags, 0666)
f, err = os.OpenFile(name.String(), oflags, 0666)
}
if err != nil {
if name == nil {
@@ -118,7 +117,7 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error
return nil, flags, _IOERR_FSTAT
}
}
if flags&OPEN_DELETEONCLOSE != 0 {
if isUnix && flags&OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
@@ -127,7 +126,8 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error
psow: true,
atomic: osBatchAtomic(f),
readOnly: flags&OPEN_READONLY != 0,
syncDir: canSyncDirs && isCreate && isJournl,
syncDir: isUnix && isCreate && isJournl,
delete: !isUnix && flags&OPEN_DELETEONCLOSE != 0,
shm: NewSharedMemory(name.String()+"-shm", flags),
}
return &file, flags, nil
@@ -141,6 +141,7 @@ type vfsFile struct {
keepWAL bool
syncDir bool
atomic bool
delete bool
psow bool
}
@@ -154,6 +155,9 @@ var (
)
func (f *vfsFile) Close() error {
if f.delete {
defer os.Remove(f.Name())
}
if f.shm != nil {
f.shm.Close()
}
@@ -177,7 +181,7 @@ func (f *vfsFile) Sync(flags SyncFlag) error {
if err != nil {
return err
}
if canSyncDirs && f.syncDir {
if isUnix && f.syncDir {
f.syncDir = false
d, err := os.Open(filepath.Dir(f.File.Name()))
if err != nil {
@@ -208,6 +212,9 @@ func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
if f.psow {
ret |= IOCAP_POWERSAFE_OVERWRITE
}
if runtime.GOOS == "windows" {
ret |= IOCAP_UNDELETABLE_WHEN_OPEN
}
return ret
}
@@ -216,6 +223,9 @@ func (f *vfsFile) SizeHint(size int64) error {
}
func (f *vfsFile) HasMoved() (bool, error) {
if runtime.GOOS == "windows" {
return false, nil
}
fi, err := f.Stat()
if err != nil {
return false, err

View File

@@ -60,10 +60,10 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
}
var err error
switch {
case timeout < 0:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock)
default:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
case timeout < 0:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock)
}
return osLockErrorCode(err, def)
}

View File

@@ -8,8 +8,8 @@ import (
)
const (
isUnix = false
_O_NOFOLLOW = 0
canSyncDirs = false
)
func osAccess(path string, flags AccessFlag) error {

View File

@@ -10,8 +10,8 @@ import (
)
const (
isUnix = true
_O_NOFOLLOW = unix.O_NOFOLLOW
canSyncDirs = true
)
func osAccess(path string, flags AccessFlag) error {

View File

@@ -135,12 +135,10 @@ func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _Error
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
switch {
case timeout == 0:
default:
err = osLockEx(file, flags|windows.LOCKFILE_FAIL_IMMEDIATELY, start, len)
case timeout < 0:
err = osLockEx(file, flags, start, len)
default:
err = osLockExTimeout(file, flags, start, len, timeout)
}
return osLockErrorCode(err, def)
}
@@ -162,37 +160,6 @@ func osLockEx(file *os.File, flags, start, len uint32) error {
0, len, 0, &windows.Overlapped{Offset: start})
}
func osLockExTimeout(file *os.File, flags, start, len uint32, timeout time.Duration) error {
event, err := windows.CreateEvent(nil, 1, 0, nil)
if err != nil {
return err
}
defer windows.CloseHandle(event)
fd := windows.Handle(file.Fd())
overlapped := &windows.Overlapped{
Offset: start,
HEvent: event,
}
err = windows.LockFileEx(fd, flags, 0, len, 0, overlapped)
if err != windows.ERROR_IO_PENDING {
return err
}
ms := (timeout + time.Millisecond - 1) / time.Millisecond
rc, err := windows.WaitForSingleObject(event, uint32(ms))
if rc == windows.WAIT_OBJECT_0 {
return nil
}
defer windows.CancelIoEx(fd, overlapped)
if err != nil {
return err
}
return windows.Errno(rc)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK

View File

@@ -7,14 +7,11 @@ import (
"io"
"os"
"sync"
"syscall"
"time"
"github.com/tetratelabs/wazero/api"
"golang.org/x/sys/windows"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/osutil"
)
type vfsShm struct {
@@ -33,8 +30,6 @@ type vfsShm struct {
sync.Mutex
}
var _ blockingSharedMemory = &vfsShm{}
func (s *vfsShm) Close() error {
// Unmap regions.
for _, r := range s.regions {
@@ -48,8 +43,7 @@ func (s *vfsShm) Close() error {
func (s *vfsShm) shmOpen() _ErrorCode {
if s.File == nil {
f, err := osutil.OpenFile(s.path,
os.O_RDWR|os.O_CREATE|syscall.O_NONBLOCK, 0666)
f, err := os.OpenFile(s.path, os.O_RDWR|os.O_CREATE, 0666)
if err != nil {
return _CANTOPEN
}
@@ -67,7 +61,7 @@ func (s *vfsShm) shmOpen() _ErrorCode {
return _IOERR_SHMOPEN
}
}
rc := osReadLock(s.File, _SHM_DMS, 1, time.Millisecond)
rc := osReadLock(s.File, _SHM_DMS, 1, 0)
s.fileLock = rc == _OK
return rc
}
@@ -135,11 +129,6 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext
}
func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) (rc _ErrorCode) {
var timeout time.Duration
if s.blocking {
timeout = time.Millisecond
}
switch {
case flags&_SHM_LOCK != 0:
defer s.shmAcquire(&rc)
@@ -151,9 +140,9 @@ func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) (rc _ErrorCode) {
case flags&_SHM_UNLOCK != 0:
return osUnlock(s.File, _SHM_BASE+uint32(offset), uint32(n))
case flags&_SHM_SHARED != 0:
return osReadLock(s.File, _SHM_BASE+uint32(offset), uint32(n), timeout)
return osReadLock(s.File, _SHM_BASE+uint32(offset), uint32(n), 0)
case flags&_SHM_EXCLUSIVE != 0:
return osWriteLock(s.File, _SHM_BASE+uint32(offset), uint32(n), timeout)
return osWriteLock(s.File, _SHM_BASE+uint32(offset), uint32(n), 0)
default:
panic(util.AssertErr())
}
@@ -184,7 +173,3 @@ func (s *vfsShm) shmUnmap(delete bool) {
os.Remove(s.path)
}
}
func (s *vfsShm) shmEnableBlocking(block bool) {
s.blocking = block
}

View File

@@ -79,9 +79,12 @@ func implements[T any](typ reflect.Type) bool {
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVTab(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
defer c.arena.mark()()
sqlPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(sqlPtr)))
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr)))
return c.error(rc)
}