From 1b3823483f4d132edb3b3206d2698a4632a27315 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 27 Feb 2023 13:45:32 +0000 Subject: [PATCH] Incremental blobs. --- blob.go | 14 +++++++-- conn.go | 19 ++++++++--- driver/driver.go | 17 ++++++++-- driver/driver_test.go | 38 ---------------------- driver_test.go | 70 +++++++++++++++++++++++++++++++++++++++++ tests/blob_test.go | 73 +++++++++++++++++++++++++++++++++++++++++-- 6 files changed, 183 insertions(+), 48 deletions(-) create mode 100644 driver_test.go diff --git a/blob.go b/blob.go index f505d69..2ff84d8 100644 --- a/blob.go +++ b/blob.go @@ -22,7 +22,7 @@ var _ io.ReadWriteSeeker = &Blob{} // OpenBlob opens a BLOB for incremental I/O. // // https://www.sqlite.org/c3ref/blob_open.html -func (c *Conn) OpenBlob(db, table, column string, row uint64, write bool) (*Blob, error) { +func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { defer c.arena.reset() blobPtr := c.arena.new(ptrlen) dbPtr := c.arena.string(db) @@ -36,7 +36,7 @@ func (c *Conn) OpenBlob(db, table, column string, row uint64, write bool) (*Blob r := c.call(c.api.blobOpen, uint64(c.handle), uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), - row, flags, uint64(blobPtr)) + uint64(row), flags, uint64(blobPtr)) if err := c.error(r[0]); err != nil { return nil, err @@ -144,3 +144,13 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { b.offset = offset return offset, nil } + +// Reopen moves a BLOB handle to a new row of the same database table. +// +// https://www.sqlite.org/c3ref/blob_reopen.html +func (b *Blob) Reopen(row int64) error { + r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row)) + b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0]) + b.offset = 0 + return b.c.error(r[0]) +} diff --git a/conn.go b/conn.go index 6d5cdbf..b7193e2 100644 --- a/conn.go +++ b/conn.go @@ -2,6 +2,7 @@ package sqlite3 import ( "context" + "database/sql/driver" "math" "sync" @@ -165,9 +166,9 @@ func (c *Conn) GetAutocommit() bool { // on the database connection. // // https://www.sqlite.org/c3ref/last_insert_rowid.html -func (c *Conn) LastInsertRowID() uint64 { +func (c *Conn) LastInsertRowID() int64 { r := c.call(c.api.lastRowid, uint64(c.handle)) - return r[0] + return int64(r[0]) } // Changes returns the number of rows modified, inserted or deleted @@ -175,9 +176,9 @@ func (c *Conn) LastInsertRowID() uint64 { // on the database connection. // // https://www.sqlite.org/c3ref/changes.html -func (c *Conn) Changes() uint64 { +func (c *Conn) Changes() int64 { r := c.call(c.api.changes, uint64(c.handle)) - return r[0] + return int64(r[0]) } // SetInterrupt interrupts a long-running query when a context is done. @@ -409,3 +410,13 @@ func (a *arena) string(s string) uint32 { a.c.mem.writeString(ptr, s) return ptr } + +// DriverConn is implemented by the SQLite database/sql driver connection. +type DriverConn interface { + driver.ConnBeginTx + driver.ExecerContext + driver.ConnPrepareContext + + Savepoint() (release func(*error)) + OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) +} diff --git a/driver/driver.go b/driver/driver.go index 13ec009..bedef62 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -71,6 +71,7 @@ var ( // Ensure these interfaces are implemented: _ driver.ExecerContext = conn{} _ driver.ConnBeginTx = conn{} + _ sqlite3.DriverConn = conn{} ) func (c conn) Close() error { @@ -140,6 +141,10 @@ func (c conn) Prepare(query string) (driver.Stmt, error) { return stmt{s, c.conn}, nil } +func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { + return c.Prepare(query) +} + func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { if len(args) != 0 { // Slow path. @@ -155,11 +160,19 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named } return result{ - int64(c.conn.LastInsertRowID()), - int64(c.conn.Changes()), + c.conn.LastInsertRowID(), + c.conn.Changes(), }, nil } +func (c conn) Savepoint() (release func(*error)) { + return c.conn.Savepoint() +} + +func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) { + return c.conn.OpenBlob(db, table, column, row, write) +} + type stmt struct { stmt *sqlite3.Stmt conn *sqlite3.Conn diff --git a/driver/driver_test.go b/driver/driver_test.go index 514b70c..40901cf 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -307,41 +307,3 @@ func Test_QueryRow_blob_null(t *testing.T) { } } } - -func Test_ZeroBlob(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - conn, err := db.Conn(ctx) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - _, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`) - if err != nil { - t.Fatal(err) - } - - _, err = conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(4)) - if err != nil { - t.Fatal(err) - } - - var got []byte - err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&got) - if err != nil { - t.Fatal(err) - } - if string(got) != "\x00\x00\x00\x00" { - t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got) - } -} diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..5e569d2 --- /dev/null +++ b/driver_test.go @@ -0,0 +1,70 @@ +package sqlite3_test + +import ( + "context" + "database/sql" + "fmt" + "log" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func ExampleDriverConn() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + conn, err := db.Conn(ctx) + if err != nil { + log.Fatal(err) + } + defer conn.Close() + + _, err = conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + log.Fatal(err) + } + + r, err := conn.ExecContext(ctx, `INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(11)) + if err != nil { + log.Fatal(err) + } + + id, err := r.LastInsertId() + if err != nil { + log.Fatal(err) + } + + err = conn.Raw(func(driverConn any) error { + conn := driverConn.(sqlite3.DriverConn) + defer conn.Savepoint()(&err) + + blob, err := conn.OpenBlob("main", "test", "col", id, true) + if err != nil { + return err + } + defer blob.Close() + + _, err = fmt.Fprint(blob, "Hello BLOB!") + return err + }) + if err != nil { + log.Fatal(err) + } + + var msg string + err = conn.QueryRowContext(ctx, `SELECT col FROM test`).Scan(&msg) + if err != nil { + log.Fatal(err) + } + fmt.Println(msg) + // Output: + // Hello BLOB! +} diff --git a/tests/blob_test.go b/tests/blob_test.go index 3554bf8..391d945 100644 --- a/tests/blob_test.go +++ b/tests/blob_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "errors" + "fmt" "io" "testing" @@ -111,7 +112,7 @@ func TestBlob_invalid(t *testing.T) { } } -func TestBlob_readonly(t *testing.T) { +func TestBlob_Write_readonly(t *testing.T) { t.Parallel() db, err := sqlite3.Open(":memory:") @@ -142,7 +143,7 @@ func TestBlob_readonly(t *testing.T) { } } -func TestBlob_expired(t *testing.T) { +func TestBlob_Read_expired(t *testing.T) { t.Parallel() db, err := sqlite3.Open(":memory:") @@ -226,3 +227,71 @@ func TestBlob_Seek(t *testing.T) { t.Fatal("want error") } } + +func TestBlob_Reopen(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + var rowids []int64 + for i := 0; i < 100; i++ { + err = db.Exec(`INSERT INTO test VALUES (zeroblob(10))`) + if err != nil { + t.Fatal(err) + } + rowids = append(rowids, db.LastInsertRowID()) + } + + var blob *sqlite3.Blob + + for i, rowid := range rowids { + if i > 0 { + err = blob.Reopen(rowid) + } else { + blob, err = db.OpenBlob("main", "test", "col", rowid, true) + } + if err != nil { + t.Fatal(err) + } + + _, err = fmt.Fprintf(blob, "blob %d\n", i) + if err != nil { + t.Fatal(err) + } + } + if err := blob.Close(); err != nil { + t.Fatal(err) + } + + for i, rowid := range rowids { + if i > 0 { + err = blob.Reopen(rowid) + } else { + blob, err = db.OpenBlob("main", "test", "col", rowid, false) + } + if err != nil { + t.Fatal(err) + } + + var got int + _, err = fmt.Fscanf(blob, "blob %d\n", &got) + if err != nil { + t.Fatal(err) + } + if got != i { + t.Errorf("got %d, want %d", got, i) + } + } + if err := blob.Close(); err != nil { + t.Fatal(err) + } +}