From f360c77a7810d9de86eba8a2e6f51899c01703b3 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 21 Apr 2023 13:31:45 +0100 Subject: [PATCH] Optimize blobs. (#10) --- blob.go | 112 +++++++++++++++++++++++++++++++++----- module.go | 9 ++++ tests/blob_test.go | 131 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 233 insertions(+), 19 deletions(-) diff --git a/blob.go b/blob.go index cf00b15..794980e 100644 --- a/blob.go +++ b/blob.go @@ -86,14 +86,14 @@ func (b *Blob) Read(p []byte) (n int, err error) { return 0, io.EOF } - want := int64(len(p)) avail := b.bytes - b.offset + want := int64(len(p)) if want > avail { want = avail } - ptr := b.c.new(uint64(want)) - defer b.c.free(ptr) + defer b.c.arena.reset() + ptr := b.c.arena.new(uint64(want)) r := b.c.call(b.c.api.blobRead, uint64(b.handle), uint64(ptr), uint64(want), uint64(b.offset)) @@ -101,30 +101,68 @@ func (b *Blob) Read(p []byte) (n int, err error) { if err != nil { return 0, err } - - mem := util.View(b.c.mod, ptr, uint64(want)) - copy(p, mem) b.offset += want if b.offset >= b.bytes { err = io.EOF } + + copy(p, util.View(b.c.mod, ptr, uint64(want))) return int(want), err } +// WriteTo implements the [io.WriterTo] interface. +// +// https://www.sqlite.org/c3ref/blob_read.html +func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { + if b.offset >= b.bytes { + return 0, nil + } + + avail := b.bytes - b.offset + want := int64(65536) + if want > avail { + want = avail + } + + ptr := b.c.new(uint64(want)) + defer b.c.free(ptr) + + for want > 0 { + r := b.c.call(b.c.api.blobRead, uint64(b.handle), + uint64(ptr), uint64(want), uint64(b.offset)) + err = b.c.error(r[0]) + if err != nil { + return n, err + } + + mem := util.View(b.c.mod, ptr, uint64(want)) + m, err := w.Write(mem[:want]) + b.offset += int64(m) + n += int64(m) + if err != nil { + return n, err + } + if int64(m) != want { + return n, io.ErrShortWrite + } + + avail = b.bytes - b.offset + if want > avail { + want = avail + } + } + return n, nil +} + // Write implements the [io.Writer] interface. // // https://www.sqlite.org/c3ref/blob_write.html func (b *Blob) Write(p []byte) (n int, err error) { - offset := b.offset - if offset > b.bytes { - offset = b.bytes - } - - ptr := b.c.newBytes(p) - defer b.c.free(ptr) + defer b.c.arena.reset() + ptr := b.c.arena.bytes(p) r := b.c.call(b.c.api.blobWrite, uint64(b.handle), - uint64(ptr), uint64(len(p)), uint64(offset)) + uint64(ptr), uint64(len(p)), uint64(b.offset)) err = b.c.error(r[0]) if err != nil { return 0, err @@ -133,6 +171,52 @@ func (b *Blob) Write(p []byte) (n int, err error) { return len(p), nil } +// ReadFrom implements the [io.ReaderFrom] interface. +// +// https://www.sqlite.org/c3ref/blob_write.html +func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) { + avail := b.bytes - b.offset + want := int64(65536) + if want > avail { + want = avail + } + if want < 1 { + want = 1 + } + + ptr := b.c.new(uint64(want)) + defer b.c.free(ptr) + + for { + mem := util.View(b.c.mod, ptr, uint64(want)) + m, err := r.Read(mem[:want]) + if m > 0 { + r := b.c.call(b.c.api.blobWrite, uint64(b.handle), + uint64(ptr), uint64(m), uint64(b.offset)) + err := b.c.error(r[0]) + if err != nil { + return n, err + } + b.offset += int64(m) + n += int64(m) + } + if err == io.EOF { + return n, nil + } + if err != nil { + return n, err + } + + avail = b.bytes - b.offset + if want > avail { + want = avail + } + if want < 1 { + want = 1 + } + } +} + // Seek implements the [io.Seeker] interface. func (b *Blob) Seek(offset int64, whence int) (int64, error) { switch whence { diff --git a/module.go b/module.go index a44b9e4..e5e1335 100644 --- a/module.go +++ b/module.go @@ -288,6 +288,15 @@ func (a *arena) new(size uint64) uint32 { return ptr } +func (a *arena) bytes(b []byte) uint32 { + if b == nil { + return 0 + } + ptr := a.new(uint64(len(b))) + util.WriteBytes(a.m.mod, ptr, b) + return ptr +} + func (a *arena) string(s string) uint32 { ptr := a.new(uint64(len(s) + 1)) util.WriteString(a.m.mod, ptr, s) diff --git a/tests/blob_test.go b/tests/blob_test.go index 391d945..1aceae8 100644 --- a/tests/blob_test.go +++ b/tests/blob_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "errors" "fmt" + "hash/adler32" "io" "testing" @@ -48,17 +49,17 @@ func TestBlob(t *testing.T) { t.Fatal(err) } - _, err = io.Copy(blob, bytes.NewReader(data[:size/2])) + _, err = blob.Write(data[:size/2]) if err != nil { t.Fatal(err) } - _, err = io.Copy(blob, bytes.NewReader(data[:])) - if !errors.Is(err, sqlite3.ERROR) { - t.Fatal("want error") + n, err := blob.Write(data[:]) + if n != 0 || !errors.Is(err, sqlite3.ERROR) { + t.Fatalf("got (%d, %v), want (0, ERROR)", n, err) } - _, err = io.Copy(blob, bytes.NewReader(data[size/2:size])) + _, err = blob.Write(data[size/2 : size]) if err != nil { t.Fatal(err) } @@ -87,6 +88,126 @@ func TestBlob(t *testing.T) { } } +func TestBlob_large(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO test VALUES (zeroblob(1000000))`) + if err != nil { + t.Fatal(err) + } + + blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true) + if err != nil { + t.Fatal(err) + } + defer blob.Close() + + size := blob.Size() + if size != 1000000 { + t.Errorf("got %d, want 1000000", size) + } + + hash := adler32.New() + _, err = io.CopyN(blob, io.TeeReader(rand.Reader, hash), 1000000) + if err != nil { + t.Fatal(err) + } + + _, err = blob.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + + want := hash.Sum32() + hash.Reset() + _, err = io.Copy(hash, blob) + if err != nil { + t.Fatal(err) + } + + if got := hash.Sum32(); got != want { + t.Fatalf("got %d, want %d", got, want) + } + + if err := blob.Close(); err != nil { + t.Fatal(err) + } + + if err := db.Close(); err != nil { + t.Fatal(err) + } +} + +func TestBlob_overflow(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`) + if err != nil { + t.Fatal(err) + } + + blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true) + if err != nil { + t.Fatal(err) + } + defer blob.Close() + + n, err := blob.ReadFrom(rand.Reader) + if n != 1024 || !errors.Is(err, sqlite3.ERROR) { + t.Fatalf("got (%d, %v), want (0, ERROR)", n, err) + } + + n, err = blob.ReadFrom(rand.Reader) + if n != 0 || !errors.Is(err, sqlite3.ERROR) { + t.Fatalf("got (%d, %v), want (0, ERROR)", n, err) + } + + _, err = blob.Seek(-128, io.SeekEnd) + if err != nil { + t.Fatal(err) + } + + n, err = blob.WriteTo(io.Discard) + if n != 128 || err != nil { + t.Fatalf("got (%d, %v), want (128, nil)", n, err) + } + + n, err = blob.WriteTo(io.Discard) + if n != 0 || err != nil { + t.Fatalf("got (%d, %v), want (0, nil)", n, err) + } + + if err := blob.Close(); err != nil { + t.Fatal(err) + } + + if err := db.Close(); err != nil { + t.Fatal(err) + } +} + func TestBlob_invalid(t *testing.T) { t.Parallel()