diff --git a/.github/coverage.html b/.github/coverage.html
index 17fbf6e..d5becc1 100644
--- a/.github/coverage.html
+++ b/.github/coverage.html
@@ -59,7 +59,7 @@
-
+
@@ -233,6 +233,7 @@ func compile() {
import (
"bytes"
"context"
+ "math"
"strconv"
"github.com/tetratelabs/wazero"
@@ -378,7 +379,7 @@ func (c *Conn) error(rc uint64) error {
}
func (c *Conn) free(ptr uint32) {
- if ptr == 0 {
+ if ptr == 0 {
return
}
_, err := c.api.free.Call(c.ctx, uint64(ptr))
@@ -393,18 +394,18 @@ func (c *Conn) new(len uint32) uint32 {
panic(err)
}
ptr := uint32(r[0])
- if ptr == 0 || ptr >= c.memory.Size() {
+ if ptr == 0 || ptr >= c.memory.Size() {
panic(oomErr)
}
return ptr
}
-func (c *Conn) newBytes(s []byte) uint32 {
- if s == nil {
+func (c *Conn) newBytes(s []byte) uint32 {
+ if s == nil {
return 0
}
- siz := uint32(len(s))
+ siz := uint32(len(s))
ptr := c.new(siz)
mem, ok := c.memory.Read(ptr, siz)
if !ok {
@@ -412,7 +413,7 @@ func (c *Conn) newBytes(s []byte) uint32 {
panic(rangeErr)
}
- copy(mem, s)
+ copy(mem, s)
return ptr
}
@@ -435,17 +436,25 @@ func (c *Conn) getString(ptr, maxlen uint32) string
func getString(memory api.Memory, ptr, maxlen uint32) string {
- if ptr == 0 {
+ if ptr == 0 {
panic(nilErr)
}
- mem, ok := memory.Read(ptr, maxlen+1)
- if !ok {
+ switch maxlen {
+ case 0:
+ return ""
+ case math.MaxUint32:
+ //
+ default:
+ maxlen = maxlen + 1
+ }
+ mem, ok := memory.Read(ptr, maxlen)
+ if !ok {
mem, ok = memory.Read(ptr, memory.Size()-ptr)
if !ok {
panic(rangeErr)
}
}
- if i := bytes.IndexByte(mem, 0); i < 0 {
+ if i := bytes.IndexByte(mem, 0); i < 0 {
panic(noNulErr)
} else {
return string(mem[:i])
diff --git a/.github/coverage.svg b/.github/coverage.svg
index f2b64de..f37c642 100644
--- a/.github/coverage.svg
+++ b/.github/coverage.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/conn.go b/conn.go
index 56ed013..3c900de 100644
--- a/conn.go
+++ b/conn.go
@@ -3,6 +3,7 @@ package sqlite3
import (
"bytes"
"context"
+ "math"
"strconv"
"github.com/tetratelabs/wazero"
@@ -208,7 +209,15 @@ func getString(memory api.Memory, ptr, maxlen uint32) string {
if ptr == 0 {
panic(nilErr)
}
- mem, ok := memory.Read(ptr, maxlen+1)
+ switch maxlen {
+ case 0:
+ return ""
+ case math.MaxUint32:
+ //
+ default:
+ maxlen = maxlen + 1
+ }
+ mem, ok := memory.Read(ptr, maxlen)
if !ok {
mem, ok = memory.Read(ptr, memory.Size()-ptr)
if !ok {
diff --git a/conn_test.go b/conn_test.go
new file mode 100644
index 0000000..eb62896
--- /dev/null
+++ b/conn_test.go
@@ -0,0 +1,123 @@
+package sqlite3
+
+import (
+ "bytes"
+ "math"
+ "testing"
+)
+
+func TestConn_new(t *testing.T) {
+ db, err := Open(":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ defer func() { _ = recover() }()
+ db.new(math.MaxUint32)
+ t.Errorf("should have panicked")
+}
+
+func TestConn_newBytes(t *testing.T) {
+ db, err := Open(":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ptr := db.newBytes(nil)
+ if ptr != 0 {
+ t.Errorf("want nullptr got %x", ptr)
+ }
+
+ buf := []byte("sqlite3")
+ ptr = db.newBytes(buf)
+ if ptr == 0 {
+ t.Errorf("want a pointer got nullptr")
+ }
+
+ want := buf
+ if got, ok := db.memory.Read(ptr, uint32(len(want))); !ok || !bytes.Equal(want, got) {
+ t.Errorf("want %q got %q", want, got)
+ }
+}
+
+func TestConn_newString(t *testing.T) {
+ db, err := Open(":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ptr := db.newString("")
+ if ptr == 0 {
+ t.Errorf("want a pointer got nullptr")
+ }
+
+ str := "sqlite3\000sqlite3"
+ ptr = db.newString(str)
+ if ptr == 0 {
+ t.Errorf("want a pointer got nullptr")
+ }
+
+ want := str + "\000"
+ if got, ok := db.memory.Read(ptr, uint32(len(want))); !ok || want != string(got) {
+ t.Errorf("want %q got %q", want, got)
+ }
+}
+
+func TestConn_getString(t *testing.T) {
+ db, err := Open(":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ptr := db.newString("")
+ if ptr == 0 {
+ t.Errorf("want a pointer got nullptr")
+ }
+
+ str := "sqlite3" + "\000 drop this"
+ ptr = db.newString(str)
+ if ptr == 0 {
+ t.Errorf("want a pointer got nullptr")
+ }
+
+ want := "sqlite3"
+ if got := db.getString(ptr, math.MaxUint32); want != got {
+ t.Errorf("want %q got %q", want, got)
+ }
+ if got := db.getString(ptr, 0); got != "" {
+ t.Errorf("want empty got %q", got)
+ }
+
+ func() {
+ defer func() { _ = recover() }()
+ db.getString(ptr, uint32(len(want)/2))
+ t.Errorf("should have panicked")
+ }()
+
+ func() {
+ defer func() { _ = recover() }()
+ db.getString(0, math.MaxUint32)
+ t.Errorf("should have panicked")
+ }()
+}
+
+func TestConn_free(t *testing.T) {
+ db, err := Open(":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ db.free(0)
+
+ ptr := db.new(0)
+ if ptr == 0 {
+ t.Errorf("want a pointer got nullptr")
+ }
+
+ db.free(ptr)
+}