From e36b2698c96c0292912cd1391061342f6c93954b Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 29 Jan 2023 02:11:41 +0000 Subject: [PATCH] Tests. --- conn.go | 16 ++------- conn_test.go | 4 +-- mem.go | 64 ++++++++++++++++------------------- stmt.go | 4 +-- vfs.go | 10 +++--- vfs_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++---------- 6 files changed, 120 insertions(+), 74 deletions(-) diff --git a/conn.go b/conn.go index 298b2f2..d0f50f6 100644 --- a/conn.go +++ b/conn.go @@ -152,7 +152,7 @@ func (c *Conn) new(len uint32) uint32 { panic(err) } ptr := uint32(r[0]) - if ptr == 0 || ptr >= c.mem.size() { + if ptr == 0 { panic(oomErr) } return ptr @@ -165,12 +165,7 @@ func (c *Conn) newBytes(b []byte) uint32 { siz := uint32(len(b)) ptr := c.new(siz) - buf, ok := c.mem.read(ptr, siz) - if !ok { - c.api.free.Call(c.ctx, uint64(ptr)) - panic(rangeErr) - } - + buf := c.mem.view(ptr, siz) copy(buf, b) return ptr } @@ -178,12 +173,7 @@ func (c *Conn) newBytes(b []byte) uint32 { func (c *Conn) newString(s string) uint32 { siz := uint32(len(s) + 1) ptr := c.new(siz) - buf, ok := c.mem.read(ptr, siz) - if !ok { - c.api.free.Call(c.ctx, uint64(ptr)) - panic(rangeErr) - } - + buf := c.mem.view(ptr, siz) buf[len(s)] = 0 copy(buf, s) return ptr diff --git a/conn_test.go b/conn_test.go index 76b71a7..1512196 100644 --- a/conn_test.go +++ b/conn_test.go @@ -37,7 +37,7 @@ func TestConn_newBytes(t *testing.T) { } want := buf - if got := db.mem.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) { + if got := db.mem.view(ptr, uint32(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } } @@ -61,7 +61,7 @@ func TestConn_newString(t *testing.T) { } want := str + "\000" - if got := db.mem.mustRead(ptr, uint32(len(want))); string(got) != want { + if got := db.mem.view(ptr, uint32(len(want))); string(got) != want { t.Errorf("got %q, want %q", got, want) } } diff --git a/mem.go b/mem.go index 43dad72..3a7f115 100644 --- a/mem.go +++ b/mem.go @@ -11,87 +11,86 @@ type memory struct { mod api.Module } -func (m memory) size() uint32 { - return m.mod.Memory().Size() -} - -func (m memory) read(offset, byteCount uint32) ([]byte, bool) { - if offset == 0 { +func (m memory) view(ptr, size uint32) []byte { + if ptr == 0 { panic(nilErr) } - return m.mod.Memory().Read(offset, byteCount) -} - -func (m memory) mustRead(offset, byteCount uint32) []byte { - buf, ok := m.read(offset, byteCount) + buf, ok := m.mod.Memory().Read(ptr, size) if !ok { panic(rangeErr) } return buf } -func (m memory) readUint32(offset uint32) uint32 { - if offset == 0 { +func (m memory) readUint32(ptr uint32) uint32 { + if ptr == 0 { panic(nilErr) } - v, ok := m.mod.Memory().ReadUint32Le(offset) + v, ok := m.mod.Memory().ReadUint32Le(ptr) if !ok { panic(rangeErr) } return v } -func (m memory) writeUint32(offset, v uint32) { - if offset == 0 { +func (m memory) writeUint32(ptr, v uint32) { + if ptr == 0 { panic(nilErr) } - ok := m.mod.Memory().WriteUint32Le(offset, v) + ok := m.mod.Memory().WriteUint32Le(ptr, v) if !ok { panic(rangeErr) } } -func (m memory) readUint64(offset uint32) uint64 { - if offset == 0 { +func (m memory) readUint64(ptr uint32) uint64 { + if ptr == 0 { panic(nilErr) } - v, ok := m.mod.Memory().ReadUint64Le(offset) + v, ok := m.mod.Memory().ReadUint64Le(ptr) if !ok { panic(rangeErr) } return v } -func (m memory) writeUint64(offset uint32, v uint64) { - if offset == 0 { +func (m memory) writeUint64(ptr uint32, v uint64) { + if ptr == 0 { panic(nilErr) } - ok := m.mod.Memory().WriteUint64Le(offset, v) + ok := m.mod.Memory().WriteUint64Le(ptr, v) if !ok { panic(rangeErr) } } -func (m memory) readFloat64(offset uint32) float64 { - return math.Float64frombits(m.readUint64(offset)) +func (m memory) readFloat64(ptr uint32) float64 { + return math.Float64frombits(m.readUint64(ptr)) } -func (m memory) writeFloat64(offset uint32, v float64) { - m.writeUint64(offset, math.Float64bits(v)) +func (m memory) writeFloat64(ptr uint32, v float64) { + m.writeUint64(ptr, math.Float64bits(v)) } func (m memory) readString(ptr, maxlen uint32) string { + if ptr == 0 { + panic(nilErr) + } switch maxlen { case 0: return "" case math.MaxUint32: - // + // avoid overflow default: maxlen = maxlen + 1 } - buf, ok := m.read(ptr, maxlen) + mem := m.mod.Memory() + buf, ok := mem.Read(ptr, maxlen) if !ok { - buf = m.mustRead(ptr, m.size()-ptr) + buf, ok = mem.Read(ptr, mem.Size()-ptr) + if !ok { + panic(rangeErr) + } } if i := bytes.IndexByte(buf, 0); i < 0 { panic(noNulErr) @@ -102,10 +101,7 @@ func (m memory) readString(ptr, maxlen uint32) string { func (m memory) writeString(ptr uint32, s string) { siz := uint32(len(s) + 1) - buf, ok := m.read(ptr, siz) - if !ok { - panic(rangeErr) - } + buf := m.view(ptr, siz) buf[len(s)] = 0 copy(buf, s) } diff --git a/stmt.go b/stmt.go index b6c8dcc..4c9aff1 100644 --- a/stmt.go +++ b/stmt.go @@ -163,7 +163,7 @@ func (s *Stmt) ColumnText(col int) string { panic(err) } - mem := s.c.mem.mustRead(ptr, uint32(r[0])) + mem := s.c.mem.view(ptr, uint32(r[0])) return string(mem) } @@ -190,6 +190,6 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { panic(err) } - mem := s.c.mem.mustRead(ptr, uint32(r[0])) + mem := s.c.mem.view(ptr, uint32(r[0])) return append(buf[0:0], mem...) } diff --git a/vfs.go b/vfs.go index 6bed8bc..b482c54 100644 --- a/vfs.go +++ b/vfs.go @@ -78,7 +78,7 @@ func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uin } func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 { - mem := memory{mod}.mustRead(zByte, nByte) + mem := memory{mod}.view(zByte, nByte) n, _ := rand.Read(mem) return uint32(n) } @@ -116,7 +116,7 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull if siz > nFull { return uint32(CANTOPEN_FULLPATH) } - mem := memory{mod}.mustRead(zFull, siz) + mem := memory{mod}.view(zFull, siz) mem[len(abs)] = 0 copy(mem, abs) @@ -246,7 +246,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 { } func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { - buf := memory{mod}.mustRead(zBuf, iAmt) + buf := memory{mod}.view(zBuf, iAmt) file := vfsFilePtr{mod, pFile}.OSFile() n, err := file.ReadAt(buf, int64(iOfst)) @@ -257,13 +257,13 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs return uint32(IOERR_READ) } for i := range buf[n:] { - buf[i] = 0 + buf[n+i] = 0 } return uint32(IOERR_SHORT_READ) } func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { - buf := memory{mod}.mustRead(zBuf, iAmt) + buf := memory{mod}.view(zBuf, iAmt) file := vfsFilePtr{mod, pFile}.OSFile() _, err := file.WriteAt(buf, int64(iOfst)) diff --git a/vfs_test.go b/vfs_test.go index fc5ea33..1e24369 100644 --- a/vfs_test.go +++ b/vfs_test.go @@ -69,7 +69,7 @@ func Test_vfsRandomness(t *testing.T) { rand.Seed(0) rand.Read(want[:]) - if got := mem.mustRead(4, 16); !bytes.Equal(got, want[:]) { + if got := mem.view(4, 16); !bytes.Equal(got, want[:]) { t.Errorf("got %q, want %q", got, want) } } @@ -121,7 +121,7 @@ func Test_vfsCurrentTime64(t *testing.T) { } func Test_vfsFullPathname(t *testing.T) { - mem := newMemory(128) + mem := newMemory(128 + _MAX_PATHNAME) mem.writeString(4, ".") rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8) @@ -141,10 +141,6 @@ func Test_vfsFullPathname(t *testing.T) { } func Test_vfsDelete(t *testing.T) { - memory := make(mockMemory, 128+_MAX_PATHNAME) - module := &mockModule{&memory} - - os.CreateTemp("", "sqlite3") file, err := os.CreateTemp("", "sqlite3-") if err != nil { t.Fatal(err) @@ -153,9 +149,10 @@ func Test_vfsDelete(t *testing.T) { defer os.RemoveAll(name) file.Close() - memory.Write(4, []byte(name)) + mem := newMemory(128 + _MAX_PATHNAME) + mem.writeString(4, name) - rc := vfsDelete(context.TODO(), module, 0, 4, 1) + rc := vfsDelete(context.TODO(), mem.mod, 0, 4, 1) if rc != _OK { t.Fatal("returned", rc) } @@ -164,38 +161,101 @@ func Test_vfsDelete(t *testing.T) { t.Fatal("did not delete the file") } - rc = vfsDelete(context.TODO(), module, 0, 4, 1) + rc = vfsDelete(context.TODO(), mem.mod, 0, 4, 1) if rc != _OK { t.Fatal("returned", rc) } } func Test_vfsAccess(t *testing.T) { - memory := make(mockMemory, 128+_MAX_PATHNAME) - module := &mockModule{&memory} - - os.CreateTemp("", "sqlite3") dir, err := os.MkdirTemp("", "sqlite3-") if err != nil { t.Fatal(err) } defer os.RemoveAll(dir) - memory.Write(8, []byte(dir)) + mem := newMemory(128 + _MAX_PATHNAME) + mem.writeString(8, dir) - rc := vfsAccess(context.TODO(), module, 0, 8, ACCESS_EXISTS, 4) + rc := vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_EXISTS, 4) if rc != _OK { t.Fatal("returned", rc) } - if got, ok := memory.ReadByte(4); !ok && got != 1 { + if got := mem.readUint32(4); got != 1 { t.Error("directory did not exist") } - rc = vfsAccess(context.TODO(), module, 0, 8, ACCESS_READWRITE, 4) + rc = vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_READWRITE, 4) if rc != _OK { t.Fatal("returned", rc) } - if got, ok := memory.ReadByte(4); !ok && got != 1 { + if got := mem.readUint32(4); got != 1 { t.Error("can't access directory") } } + +func Test_vfsFile(t *testing.T) { + mem := newMemory(128) + + // Open a temporary file. + rc := vfsOpen(context.TODO(), mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0) + if rc != _OK { + t.Fatal("returned", rc) + } + + // Write stuff. + text := "Hello world!" + mem.writeString(16, text) + rc = vfsWrite(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 0) + if rc != _OK { + t.Fatal("returned", rc) + } + + // Check file size. + rc = vfsFileSize(context.TODO(), mem.mod, 4, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := mem.readUint32(16); got != uint32(len(text)) { + t.Errorf("got %d", got) + } + + // Partial read at offset. + rc = vfsRead(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 4) + if rc != uint32(IOERR_SHORT_READ) { + t.Fatal("returned", rc) + } + if got := mem.readString(16, 64); got != text[4:] { + t.Errorf("got %q", got) + } + + // Truncate the file. + rc = vfsTruncate(context.TODO(), mem.mod, 4, 4) + if rc != _OK { + t.Fatal("returned", rc) + } + + // Check file size. + rc = vfsFileSize(context.TODO(), mem.mod, 4, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := mem.readUint32(16); got != 4 { + t.Errorf("got %d", got) + } + + // Read at offset. + rc = vfsRead(context.TODO(), mem.mod, 4, 32, 4, 0) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := mem.readString(32, 64); got != text[:4] { + t.Errorf("got %q", got) + } + + // Close the file. + rc = vfsClose(context.TODO(), mem.mod, 4) + if rc != _OK { + t.Fatal("returned", rc) + } +}