From 09c7c7af3fe954cb6684f5dae121a30b93f00074 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 26 Jan 2023 11:12:00 +0000 Subject: [PATCH] More tests. --- .github/coverage.html | 31 +++++++---- .github/coverage.svg | 2 +- conn.go | 11 +++- conn_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 154 insertions(+), 13 deletions(-) create mode 100644 conn_test.go diff --git a/.github/coverage.html b/.github/coverage.html index 17fbf6e..d5becc1 100644 --- a/.github/coverage.html +++ b/.github/coverage.html @@ -59,7 +59,7 @@ - + @@ -233,6 +233,7 @@ func compile() { import ( "bytes" "context" + "math" "strconv" "github.com/tetratelabs/wazero" @@ -378,7 +379,7 @@ func (c *Conn) error(rc uint64) error { } func (c *Conn) free(ptr uint32) { - if ptr == 0 { + if ptr == 0 { return } _, err := c.api.free.Call(c.ctx, uint64(ptr)) @@ -393,18 +394,18 @@ func (c *Conn) new(len uint32) uint32 { panic(err) } ptr := uint32(r[0]) - if ptr == 0 || ptr >= c.memory.Size() { + if ptr == 0 || ptr >= c.memory.Size() { panic(oomErr) } return ptr } -func (c *Conn) newBytes(s []byte) uint32 { - if s == nil { +func (c *Conn) newBytes(s []byte) uint32 { + if s == nil { return 0 } - siz := uint32(len(s)) + siz := uint32(len(s)) ptr := c.new(siz) mem, ok := c.memory.Read(ptr, siz) if !ok { @@ -412,7 +413,7 @@ func (c *Conn) newBytes(s []byte) uint32 { panic(rangeErr) } - copy(mem, s) + copy(mem, s) return ptr } @@ -435,17 +436,25 @@ func (c *Conn) getString(ptr, maxlen uint32) string func getString(memory api.Memory, ptr, maxlen uint32) string { - if ptr == 0 { + if ptr == 0 { panic(nilErr) } - mem, ok := memory.Read(ptr, maxlen+1) - if !ok { + switch maxlen { + case 0: + return "" + case math.MaxUint32: + // + default: + maxlen = maxlen + 1 + } + mem, ok := memory.Read(ptr, maxlen) + if !ok { mem, ok = memory.Read(ptr, memory.Size()-ptr) if !ok { panic(rangeErr) } } - if i := bytes.IndexByte(mem, 0); i < 0 { + if i := bytes.IndexByte(mem, 0); i < 0 { panic(noNulErr) } else { return string(mem[:i]) diff --git a/.github/coverage.svg b/.github/coverage.svg index f2b64de..f37c642 100644 --- a/.github/coverage.svg +++ b/.github/coverage.svg @@ -1 +1 @@ -coverage: 56.6%coverage56.6% \ No newline at end of file +coverage: 59.0%coverage59.0% \ No newline at end of file diff --git a/conn.go b/conn.go index 56ed013..3c900de 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,7 @@ package sqlite3 import ( "bytes" "context" + "math" "strconv" "github.com/tetratelabs/wazero" @@ -208,7 +209,15 @@ func getString(memory api.Memory, ptr, maxlen uint32) string { if ptr == 0 { panic(nilErr) } - mem, ok := memory.Read(ptr, maxlen+1) + switch maxlen { + case 0: + return "" + case math.MaxUint32: + // + default: + maxlen = maxlen + 1 + } + mem, ok := memory.Read(ptr, maxlen) if !ok { mem, ok = memory.Read(ptr, memory.Size()-ptr) if !ok { diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..eb62896 --- /dev/null +++ b/conn_test.go @@ -0,0 +1,123 @@ +package sqlite3 + +import ( + "bytes" + "math" + "testing" +) + +func TestConn_new(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + defer func() { _ = recover() }() + db.new(math.MaxUint32) + t.Errorf("should have panicked") +} + +func TestConn_newBytes(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ptr := db.newBytes(nil) + if ptr != 0 { + t.Errorf("want nullptr got %x", ptr) + } + + buf := []byte("sqlite3") + ptr = db.newBytes(buf) + if ptr == 0 { + t.Errorf("want a pointer got nullptr") + } + + want := buf + if got, ok := db.memory.Read(ptr, uint32(len(want))); !ok || !bytes.Equal(want, got) { + t.Errorf("want %q got %q", want, got) + } +} + +func TestConn_newString(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ptr := db.newString("") + if ptr == 0 { + t.Errorf("want a pointer got nullptr") + } + + str := "sqlite3\000sqlite3" + ptr = db.newString(str) + if ptr == 0 { + t.Errorf("want a pointer got nullptr") + } + + want := str + "\000" + if got, ok := db.memory.Read(ptr, uint32(len(want))); !ok || want != string(got) { + t.Errorf("want %q got %q", want, got) + } +} + +func TestConn_getString(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ptr := db.newString("") + if ptr == 0 { + t.Errorf("want a pointer got nullptr") + } + + str := "sqlite3" + "\000 drop this" + ptr = db.newString(str) + if ptr == 0 { + t.Errorf("want a pointer got nullptr") + } + + want := "sqlite3" + if got := db.getString(ptr, math.MaxUint32); want != got { + t.Errorf("want %q got %q", want, got) + } + if got := db.getString(ptr, 0); got != "" { + t.Errorf("want empty got %q", got) + } + + func() { + defer func() { _ = recover() }() + db.getString(ptr, uint32(len(want)/2)) + t.Errorf("should have panicked") + }() + + func() { + defer func() { _ = recover() }() + db.getString(0, math.MaxUint32) + t.Errorf("should have panicked") + }() +} + +func TestConn_free(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + db.free(0) + + ptr := db.new(0) + if ptr == 0 { + t.Errorf("want a pointer got nullptr") + } + + db.free(ptr) +}