This commit is contained in:
Nuno Cruces
2023-01-29 02:11:41 +00:00
parent 0ace464670
commit e36b2698c9
6 changed files with 120 additions and 74 deletions

16
conn.go
View File

@@ -152,7 +152,7 @@ func (c *Conn) new(len uint32) uint32 {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 || ptr >= c.mem.size() {
if ptr == 0 {
panic(oomErr)
}
return ptr
@@ -165,12 +165,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
siz := uint32(len(b))
ptr := c.new(siz)
buf, ok := c.mem.read(ptr, siz)
if !ok {
c.api.free.Call(c.ctx, uint64(ptr))
panic(rangeErr)
}
buf := c.mem.view(ptr, siz)
copy(buf, b)
return ptr
}
@@ -178,12 +173,7 @@ func (c *Conn) newBytes(b []byte) uint32 {
func (c *Conn) newString(s string) uint32 {
siz := uint32(len(s) + 1)
ptr := c.new(siz)
buf, ok := c.mem.read(ptr, siz)
if !ok {
c.api.free.Call(c.ctx, uint64(ptr))
panic(rangeErr)
}
buf := c.mem.view(ptr, siz)
buf[len(s)] = 0
copy(buf, s)
return ptr

View File

@@ -37,7 +37,7 @@ func TestConn_newBytes(t *testing.T) {
}
want := buf
if got := db.mem.mustRead(ptr, uint32(len(want))); !bytes.Equal(got, want) {
if got := db.mem.view(ptr, uint32(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -61,7 +61,7 @@ func TestConn_newString(t *testing.T) {
}
want := str + "\000"
if got := db.mem.mustRead(ptr, uint32(len(want))); string(got) != want {
if got := db.mem.view(ptr, uint32(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}

64
mem.go
View File

@@ -11,87 +11,86 @@ type memory struct {
mod api.Module
}
func (m memory) size() uint32 {
return m.mod.Memory().Size()
}
func (m memory) read(offset, byteCount uint32) ([]byte, bool) {
if offset == 0 {
func (m memory) view(ptr, size uint32) []byte {
if ptr == 0 {
panic(nilErr)
}
return m.mod.Memory().Read(offset, byteCount)
}
func (m memory) mustRead(offset, byteCount uint32) []byte {
buf, ok := m.read(offset, byteCount)
buf, ok := m.mod.Memory().Read(ptr, size)
if !ok {
panic(rangeErr)
}
return buf
}
func (m memory) readUint32(offset uint32) uint32 {
if offset == 0 {
func (m memory) readUint32(ptr uint32) uint32 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint32Le(offset)
v, ok := m.mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint32(offset, v uint32) {
if offset == 0 {
func (m memory) writeUint32(ptr, v uint32) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint32Le(offset, v)
ok := m.mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readUint64(offset uint32) uint64 {
if offset == 0 {
func (m memory) readUint64(ptr uint32) uint64 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint64Le(offset)
v, ok := m.mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint64(offset uint32, v uint64) {
if offset == 0 {
func (m memory) writeUint64(ptr uint32, v uint64) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint64Le(offset, v)
ok := m.mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readFloat64(offset uint32) float64 {
return math.Float64frombits(m.readUint64(offset))
func (m memory) readFloat64(ptr uint32) float64 {
return math.Float64frombits(m.readUint64(ptr))
}
func (m memory) writeFloat64(offset uint32, v float64) {
m.writeUint64(offset, math.Float64bits(v))
func (m memory) writeFloat64(ptr uint32, v float64) {
m.writeUint64(ptr, math.Float64bits(v))
}
func (m memory) readString(ptr, maxlen uint32) string {
if ptr == 0 {
panic(nilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
//
// avoid overflow
default:
maxlen = maxlen + 1
}
buf, ok := m.read(ptr, maxlen)
mem := m.mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf = m.mustRead(ptr, m.size()-ptr)
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(rangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(noNulErr)
@@ -102,10 +101,7 @@ func (m memory) readString(ptr, maxlen uint32) string {
func (m memory) writeString(ptr uint32, s string) {
siz := uint32(len(s) + 1)
buf, ok := m.read(ptr, siz)
if !ok {
panic(rangeErr)
}
buf := m.view(ptr, siz)
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -163,7 +163,7 @@ func (s *Stmt) ColumnText(col int) string {
panic(err)
}
mem := s.c.mem.mustRead(ptr, uint32(r[0]))
mem := s.c.mem.view(ptr, uint32(r[0]))
return string(mem)
}
@@ -190,6 +190,6 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
panic(err)
}
mem := s.c.mem.mustRead(ptr, uint32(r[0]))
mem := s.c.mem.view(ptr, uint32(r[0]))
return append(buf[0:0], mem...)
}

10
vfs.go
View File

@@ -78,7 +78,7 @@ func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uin
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.mustRead(zByte, nByte)
mem := memory{mod}.view(zByte, nByte)
n, _ := rand.Read(mem)
return uint32(n)
}
@@ -116,7 +116,7 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull
if siz > nFull {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.mustRead(zFull, siz)
mem := memory{mod}.view(zFull, siz)
mem[len(abs)] = 0
copy(mem, abs)
@@ -246,7 +246,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.mustRead(zBuf, iAmt)
buf := memory{mod}.view(zBuf, iAmt)
file := vfsFilePtr{mod, pFile}.OSFile()
n, err := file.ReadAt(buf, int64(iOfst))
@@ -257,13 +257,13 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
return uint32(IOERR_READ)
}
for i := range buf[n:] {
buf[i] = 0
buf[n+i] = 0
}
return uint32(IOERR_SHORT_READ)
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.mustRead(zBuf, iAmt)
buf := memory{mod}.view(zBuf, iAmt)
file := vfsFilePtr{mod, pFile}.OSFile()
_, err := file.WriteAt(buf, int64(iOfst))

View File

@@ -69,7 +69,7 @@ func Test_vfsRandomness(t *testing.T) {
rand.Seed(0)
rand.Read(want[:])
if got := mem.mustRead(4, 16); !bytes.Equal(got, want[:]) {
if got := mem.view(4, 16); !bytes.Equal(got, want[:]) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -121,7 +121,7 @@ func Test_vfsCurrentTime64(t *testing.T) {
}
func Test_vfsFullPathname(t *testing.T) {
mem := newMemory(128)
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, ".")
rc := vfsFullPathname(context.TODO(), mem.mod, 0, 4, 0, 8)
@@ -141,10 +141,6 @@ func Test_vfsFullPathname(t *testing.T) {
}
func Test_vfsDelete(t *testing.T) {
memory := make(mockMemory, 128+_MAX_PATHNAME)
module := &mockModule{&memory}
os.CreateTemp("", "sqlite3")
file, err := os.CreateTemp("", "sqlite3-")
if err != nil {
t.Fatal(err)
@@ -153,9 +149,10 @@ func Test_vfsDelete(t *testing.T) {
defer os.RemoveAll(name)
file.Close()
memory.Write(4, []byte(name))
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, name)
rc := vfsDelete(context.TODO(), module, 0, 4, 1)
rc := vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -164,38 +161,101 @@ func Test_vfsDelete(t *testing.T) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(context.TODO(), module, 0, 4, 1)
rc = vfsDelete(context.TODO(), mem.mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
}
func Test_vfsAccess(t *testing.T) {
memory := make(mockMemory, 128+_MAX_PATHNAME)
module := &mockModule{&memory}
os.CreateTemp("", "sqlite3")
dir, err := os.MkdirTemp("", "sqlite3-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
memory.Write(8, []byte(dir))
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
rc := vfsAccess(context.TODO(), module, 0, 8, ACCESS_EXISTS, 4)
rc := vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got, ok := memory.ReadByte(4); !ok && got != 1 {
if got := mem.readUint32(4); got != 1 {
t.Error("directory did not exist")
}
rc = vfsAccess(context.TODO(), module, 0, 8, ACCESS_READWRITE, 4)
rc = vfsAccess(context.TODO(), mem.mod, 0, 8, ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got, ok := memory.ReadByte(4); !ok && got != 1 {
if got := mem.readUint32(4); got != 1 {
t.Error("can't access directory")
}
}
func Test_vfsFile(t *testing.T) {
mem := newMemory(128)
// Open a temporary file.
rc := vfsOpen(context.TODO(), mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Write stuff.
text := "Hello world!"
mem.writeString(16, text)
rc = vfsWrite(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != uint32(len(text)) {
t.Errorf("got %d", got)
}
// Partial read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 16, uint32(len(text)), 4)
if rc != uint32(IOERR_SHORT_READ) {
t.Fatal("returned", rc)
}
if got := mem.readString(16, 64); got != text[4:] {
t.Errorf("got %q", got)
}
// Truncate the file.
rc = vfsTruncate(context.TODO(), mem.mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(context.TODO(), mem.mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != 4 {
t.Errorf("got %d", got)
}
// Read at offset.
rc = vfsRead(context.TODO(), mem.mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readString(32, 64); got != text[:4] {
t.Errorf("got %q", got)
}
// Close the file.
rc = vfsClose(context.TODO(), mem.mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
}