Towards database/sql.

This commit is contained in:
Nuno Cruces
2023-02-14 18:21:18 +00:00
parent 0d59065719
commit 78fd0cbee5
10 changed files with 205 additions and 8 deletions

16
api.go
View File

@@ -44,18 +44,24 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) {
step: getFun("sqlite3_step"),
exec: getFun("sqlite3_exec"),
clearBindings: getFun("sqlite3_clear_bindings"),
bindCount: getFun("sqlite3_bind_parameter_count"),
bindInteger: getFun("sqlite3_bind_int64"),
bindFloat: getFun("sqlite3_bind_double"),
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindNull: getFun("sqlite3_bind_null"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
columnInteger: getFun("sqlite3_column_int64"),
columnFloat: getFun("sqlite3_column_double"),
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
columnType: getFun("sqlite3_column_type"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
interrupt: getFun("sqlite3_interrupt"),
},
}
if err != nil {
@@ -80,16 +86,22 @@ type sqliteAPI struct {
step api.Function
exec api.Function
clearBindings api.Function
bindCount api.Function
bindInteger api.Function
bindFloat api.Function
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindNull api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
columnInteger api.Function
columnFloat api.Function
columnText api.Function
columnBlob api.Function
columnBytes api.Function
columnType api.Function
lastRowid api.Function
changes api.Function
interrupt api.Function
}

58
conn.go
View File

@@ -14,6 +14,9 @@ type Conn struct {
mem memory
arena arena
handle uint32
waiter chan struct{}
done <-chan struct{}
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
@@ -70,6 +73,8 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(nil)
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
return err
@@ -83,6 +88,57 @@ func (c *Conn) Close() error {
return c.mem.mod.Close(c.ctx)
}
// SetInterrupt interrupts a long-running query when done is closed.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until done is reset by another call to SetInterrupt.
//
// Typically, done is provided by [context.Context.Done]:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx.Done())
// defer cancel()
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
old = c.done
c.done = done
if done == nil {
return old
}
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter:
// Waiter was cancelled.
case <-done:
// Done was closed.
// Because it doesn't touch the C stack,
// sqlite3_interrupt is safe to call from a goroutine.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
// Wait for the next call to SetInterrupt.
<-waiter // Waiter was cancelled.
}
// Signal that the waiter is finished.
waiter <- struct{}{}
}()
return old
}
// Exec is a convenience function that allows an application to run
// multiple statements of SQL without having to use a lot of code.
//
@@ -111,9 +167,9 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),

View File

@@ -2,9 +2,11 @@ package sqlite3
import (
"bytes"
"context"
"errors"
"math"
"testing"
"time"
)
func TestConn_Close(t *testing.T) {
@@ -19,7 +21,7 @@ func TestConn_Close_BUSY(t *testing.T) {
}
defer db.Close()
stmt, _, err := db.Prepare("BEGIN")
stmt, _, err := db.Prepare(`BEGIN`)
if err != nil {
t.Fatal(err)
}
@@ -41,6 +43,54 @@ func TestConn_Close_BUSY(t *testing.T) {
}
}
func TestConn_Interrupt(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`
WITH RECURSIVE
fibonacci (curr, next)
AS (
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 10e6
)
SELECT min(curr) FROM fibonacci
`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
db.SetInterrupt(ctx.Done())
defer cancel()
for stmt.Step() {
}
err = stmt.Err()
if err == nil {
t.Fatal("want error")
}
var serr *Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
db.SetInterrupt(nil)
}
func TestConn_Prepare_Empty(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
@@ -48,7 +98,7 @@ func TestConn_Prepare_Empty(t *testing.T) {
}
defer db.Close()
stmt, _, err := db.Prepare("")
stmt, _, err := db.Prepare(``)
if err != nil {
t.Fatal(err)
}
@@ -68,7 +118,7 @@ func TestConn_Prepare_Invalid(t *testing.T) {
var serr *Error
_, _, err = db.Prepare("SELECT")
_, _, err = db.Prepare(`SELECT`)
if err == nil {
t.Fatal("want error")
}
@@ -82,7 +132,7 @@ func TestConn_Prepare_Invalid(t *testing.T) {
t.Error("got message: ", got)
}
_, _, err = db.Prepare("SELECT * FRM sqlite_schema")
_, _, err = db.Prepare(`SELECT * FRM sqlite_schema`)
if err == nil {
t.Fatal("want error")
}

44
driver/driver.go Normal file
View File

@@ -0,0 +1,44 @@
//go:build todo
// Package driver provides a database/sql driver for SQLite.
package driver
import (
"database/sql"
"database/sql/driver"
"github.com/ncruces/go-sqlite3"
)
func init() {
sql.Register("sqlite3", sqlite{})
}
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI)
if err != nil {
return nil, err
}
return conn{c}, nil
}
type conn struct{ *sqlite3.Conn }
type stmt struct{ *sqlite3.Stmt }
func (c conn) Begin() (driver.Tx, error)
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, _, err := c.Conn.Prepare(query)
if err != nil {
return nil, err
}
return stmt{s}, nil
}
func (s stmt) NumInput() int
func (s stmt) Exec(args []driver.Value) (driver.Result, error)
func (s stmt) Query(args []driver.Value) (driver.Rows, error)

View File

@@ -28,15 +28,21 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \
-Wl,--export=sqlite3_step \
-Wl,--export=sqlite3_exec \
-Wl,--export=sqlite3_clear_bindings \
-Wl,--export=sqlite3_bind_parameter_count \
-Wl,--export=sqlite3_bind_int64 \
-Wl,--export=sqlite3_bind_double \
-Wl,--export=sqlite3_bind_text64 \
-Wl,--export=sqlite3_bind_blob64 \
-Wl,--export=sqlite3_bind_zeroblob64 \
-Wl,--export=sqlite3_bind_null \
-Wl,--export=sqlite3_column_count \
-Wl,--export=sqlite3_column_name \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_column_int64 \
-Wl,--export=sqlite3_column_double \
-Wl,--export=sqlite3_column_text \
-Wl,--export=sqlite3_column_blob \
-Wl,--export=sqlite3_column_bytes \
-Wl,--export=sqlite3_column_type \
-Wl,--export=sqlite3_last_insert_rowid \
-Wl,--export=sqlite3_changes64 \
-Wl,--export=sqlite3_interrupt \

Binary file not shown.

View File

@@ -5,6 +5,9 @@
#define SQLITE_OS_OTHER 1
#define SQLITE_BYTEORDER 1234
#define HAVE_STDINT_H 1
#define HAVE_INTTYPES_H 1
#define HAVE_ISNAN 1
#define HAVE_USLEEP 1
#define HAVE_LOCALTIME_S 1
@@ -25,6 +28,16 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Recommended Extensions
// #define SQLITE_ENABLE_MATH_FUNCTIONS 1
// #define SQLITE_ENABLE_FTS3 1
// #define SQLITE_ENABLE_FTS3_PARENTHESIS 1
// #define SQLITE_ENABLE_FTS4 1
// #define SQLITE_ENABLE_FTS5 1
// #define SQLITE_ENABLE_RTREE 1
// #define SQLITE_ENABLE_GEOPOLY 1
// Need this to access WAL databases without the use of shared memory.
#define SQLITE_DEFAULT_LOCKING_MODE 1

12
stmt.go
View File

@@ -95,6 +95,18 @@ func (s *Stmt) Exec() error {
return s.Reset()
}
// BindCount gets the number of SQL parameters in a prepared statement.
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
r, err := s.c.api.bindCount.Call(s.c.ctx,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
// BindBool binds a bool to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// SQLite does not have a separate boolean storage class.

View File

@@ -23,6 +23,10 @@ func TestStmt(t *testing.T) {
}
defer stmt.Close()
if got := stmt.BindCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
err = stmt.BindBool(1, false)
if err != nil {
t.Fatal(err)