From 921e1eafc2954b13d65da03ecdaf11e663b832d5 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 25 Jan 2023 14:59:02 +0000 Subject: [PATCH] Error handling. --- api.go | 3 ++ conn.go | 8 +++-- const.go | 4 +-- error.go | 1 + sqlite3/main.c | 2 +- vfs.go | 92 +++++++++++++++++++++++++++++++++++++------------- vfs_files.go | 33 ++++++++++-------- vfs_lock.go | 3 ++ vfs_unix.go | 6 ++++ vfs_windows.go | 2 ++ 10 files changed, 112 insertions(+), 42 deletions(-) diff --git a/api.go b/api.go index 31ff1cf..a3b51db 100644 --- a/api.go +++ b/api.go @@ -16,6 +16,9 @@ func newConn(module api.Module) *Conn { panic(noGlobalErr + "malloc_destructor") } destructor := uint32(global.Get()) + if destructor == 0 { + panic(noGlobalErr + "malloc_destructor") + } destructor, ok := module.Memory().ReadUint32Le(destructor) if !ok { panic(noGlobalErr + "malloc_destructor") diff --git a/conn.go b/conn.go index ed45ade..d61e0c1 100644 --- a/conn.go +++ b/conn.go @@ -162,10 +162,11 @@ func (c *Conn) new(len uint32) uint32 { if err != nil { panic(err) } - if r[0] == 0 { + ptr := uint32(r[0]) + if ptr == 0 || ptr >= c.memory.Size() { panic(oomErr) } - return uint32(r[0]) + return ptr } func (c *Conn) newBytes(s []byte) uint32 { @@ -204,6 +205,9 @@ func (c *Conn) getString(ptr, maxlen uint32) string { } func getString(memory api.Memory, ptr, maxlen uint32) string { + if ptr == 0 { + panic(nilErr) + } mem, ok := memory.Read(ptr, maxlen+1) if !ok { mem, ok = memory.Read(ptr, memory.Size()-ptr) diff --git a/const.go b/const.go index 92ef051..e16fbfe 100644 --- a/const.go +++ b/const.go @@ -121,7 +121,7 @@ const ( AUTH_USER = ExtendedErrorCode(AUTH) | (1 << 8) ) -type OpenFlag uint +type OpenFlag uint32 const ( OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */ @@ -148,7 +148,7 @@ const ( OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */ ) -type AccessFlag uint +type AccessFlag uint32 const ( ACCESS_EXISTS AccessFlag = 0 diff --git a/error.go b/error.go index accb8d0..24463b9 100644 --- a/error.go +++ b/error.go @@ -36,6 +36,7 @@ type errorString string func (e errorString) Error() string { return string(e) } const ( + nilErr = errorString("sqlite3: invalid memory address or null pointer dereference") oomErr = errorString("sqlite3: out of memory") rangeErr = errorString("sqlite3: index out of range") noNulErr = errorString("sqlite3: missing NUL terminator") diff --git a/sqlite3/main.c b/sqlite3/main.c index b2cf92f..9830701 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -73,7 +73,7 @@ static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName, .xDeviceCharacteristics = no_device_characteristics, }; int rc = go_open(vfs, zName, file, flags, pOutFlags); - file->pMethods = rc == SQLITE_OK ? &go_io : NULL; + file->pMethods = (char)rc == SQLITE_OK ? &go_io : NULL; return rc; } diff --git a/vfs.go b/vfs.go index a3a6582..3de1d10 100644 --- a/vfs.go +++ b/vfs.go @@ -62,6 +62,9 @@ func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uin isdst = 1 } + if pTm == 0 { + panic(nilErr) + } // https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html if mem := mod.Memory(); true && mem.WriteUint32Le(pTm+0*wordSize, uint32(tm.Second())) && @@ -79,6 +82,9 @@ func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uin } func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 { + if zByte == 0 { + panic(nilErr) + } mem, ok := mod.Memory().Read(zByte, nByte) if !ok { panic(rangeErr) @@ -94,6 +100,9 @@ func vfsSleep(ctx context.Context, pVfs, nMicro uint32) uint32 { func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uint32 { day := julianday.Float(time.Now()) + if prNow == 0 { + panic(nilErr) + } if ok := mod.Memory().WriteFloat64Le(prNow, day); !ok { panic(rangeErr) } @@ -103,6 +112,9 @@ func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uin func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) uint32 { day, nsec := julianday.Date(time.Now()) msec := day*86_400_000 + nsec/1_000_000 + if piNow == 0 { + panic(nilErr) + } if ok := mod.Memory().WriteUint64Le(piNow, uint64(msec)); !ok { panic(rangeErr) } @@ -122,7 +134,10 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull siz := uint32(len(abs) + 1) if siz > nFull { - return uint32(IOERR) + return uint32(CANTOPEN_FULLPATH) + } + if zFull == 0 { + panic(nilErr) } mem, ok := mod.Memory().Read(zFull, siz) if !ok { @@ -156,7 +171,7 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) return _OK } -func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath, flags, pResOut uint32) uint32 { +func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) uint32 { // Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ] // (as the Unix VFS does). @@ -164,7 +179,8 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath, flags, pResOut fi, err := os.Stat(path) var res uint32 - if flags == uint32(ACCESS_EXISTS) { + switch { + case flags == ACCESS_EXISTS: switch { case err == nil: res = 1 @@ -173,9 +189,10 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath, flags, pResOut default: return uint32(IOERR_ACCESS) } - } else if err == nil { + + case err == nil: var want fs.FileMode = syscall.S_IRUSR - if flags == uint32(ACCESS_READWRITE) { + if flags == ACCESS_READWRITE { want |= syscall.S_IWUSR } if fi.IsDir() { @@ -186,49 +203,68 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath, flags, pResOut } else { res = 0 } - } else if errors.Is(err, fs.ErrPermission) { + + case errors.Is(err, fs.ErrPermission): res = 0 - } else { + + default: return uint32(IOERR_ACCESS) } + if pResOut == 0 { + panic(nilErr) + } if ok := mod.Memory().WriteUint32Le(pResOut, res); !ok { panic(rangeErr) } return _OK } -func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile, flags, pOutFlags uint32) uint32 { - name := getString(mod.Memory(), zName, _MAX_PATHNAME) - +func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 { var oflags int - if OpenFlag(flags)&OPEN_EXCLUSIVE != 0 { + if flags&OPEN_EXCLUSIVE != 0 { oflags |= os.O_EXCL } - if OpenFlag(flags)&OPEN_CREATE != 0 { + if flags&OPEN_CREATE != 0 { oflags |= os.O_CREATE } - if OpenFlag(flags)&OPEN_READONLY != 0 { + if flags&OPEN_READONLY != 0 { oflags |= os.O_RDONLY } - if OpenFlag(flags)&OPEN_READWRITE != 0 { + if flags&OPEN_READWRITE != 0 { oflags |= os.O_RDWR } - file, err := os.OpenFile(name, oflags, 0600) + + var err error + var file *os.File + if zName == 0 { + file, err = os.CreateTemp("", "*.db") + } else { + name := getString(mod.Memory(), zName, _MAX_PATHNAME) + file, err = os.OpenFile(name, oflags, 0600) + } if err != nil { return uint32(CANTOPEN) } - id, err := vfsGetOpenFileID(file) + if flags&OPEN_DELETEONCLOSE != 0 { + deleteOnClose(file) + } + + info, err := file.Stat() if err != nil { return uint32(CANTOPEN) } + if info.IsDir() { + return uint32(CANTOPEN_ISDIR) + } + id := vfsGetOpenFileID(file, info) vfsFilePtr{mod, pFile}.SetID(id).SetLock(_NO_LOCK) if pOutFlags == 0 { return _OK } - if ok := mod.Memory().WriteUint32Le(pOutFlags, flags); !ok { + if ok := mod.Memory().WriteUint32Le(pOutFlags, uint32(flags)); !ok { panic(rangeErr) } return _OK @@ -244,33 +280,39 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 { } func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { - mem, ok := mod.Memory().Read(zBuf, iAmt) + if zBuf == 0 { + panic(nilErr) + } + buf, ok := mod.Memory().Read(zBuf, iAmt) if !ok { panic(rangeErr) } file := vfsFilePtr{mod, pFile}.OSFile() - n, err := file.ReadAt(mem, int64(iOfst)) + n, err := file.ReadAt(buf, int64(iOfst)) if n == int(iAmt) { return _OK } if n == 0 && err != io.EOF { return uint32(IOERR_READ) } - for i := range mem[n:] { - mem[i] = 0 + for i := range buf[n:] { + buf[i] = 0 } return uint32(IOERR_SHORT_READ) } func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { - mem, ok := mod.Memory().Read(zBuf, iAmt) + if zBuf == 0 { + panic(nilErr) + } + buf, ok := mod.Memory().Read(zBuf, iAmt) if !ok { panic(rangeErr) } file := vfsFilePtr{mod, pFile}.OSFile() - _, err := file.WriteAt(mem, int64(iOfst)) + _, err := file.WriteAt(buf, int64(iOfst)) if err != nil { return uint32(IOERR_WRITE) } @@ -298,12 +340,16 @@ func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 { func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 { // This uses [file.Seek] because we don't care about the offset for reading/writing. // But consider using [file.Stat] instead (as other VFSes do). + file := vfsFilePtr{mod, pFile}.OSFile() off, err := file.Seek(0, io.SeekEnd) if err != nil { return uint32(IOERR_SEEK) } + if pSize == 0 { + panic(nilErr) + } if ok := mod.Memory().WriteUint64Le(pSize, uint64(off)); !ok { panic(rangeErr) } diff --git a/vfs_files.go b/vfs_files.go index 05891fa..3d3b748 100644 --- a/vfs_files.go +++ b/vfs_files.go @@ -21,12 +21,7 @@ var ( vfsOpenFilesMtx sync.Mutex ) -func vfsGetOpenFileID(file *os.File) (uint32, error) { - fi, err := file.Stat() - if err != nil { - return 0, err - } - +func vfsGetOpenFileID(file *os.File, info os.FileInfo) uint32 { vfsOpenFilesMtx.Lock() defer vfsOpenFilesMtx.Unlock() @@ -35,18 +30,16 @@ func vfsGetOpenFileID(file *os.File) (uint32, error) { if of == nil { continue } - if os.SameFile(fi, of.info) { - if err := file.Close(); err != nil { - return 0, err - } + if os.SameFile(info, of.info) { of.nref++ - return uint32(id), nil + _ = file.Close() + return uint32(id) } } of := &vfsOpenFile{ file: file, - info: fi, + info: info, nref: 1, vfsLocker: &vfsNoopLocker{}, @@ -56,14 +49,14 @@ func vfsGetOpenFileID(file *os.File) (uint32, error) { for id, ptr := range vfsOpenFiles { if ptr == nil { vfsOpenFiles[id] = of - return uint32(id), nil + return uint32(id) } } // Add a new slot. id := len(vfsOpenFiles) vfsOpenFiles = append(vfsOpenFiles, of) - return uint32(id), nil + return uint32(id) } func vfsReleaseOpenFile(id uint32) error { @@ -92,6 +85,9 @@ func (p vfsFilePtr) OSFile() *os.File { } func (p vfsFilePtr) ID() uint32 { + if p.ptr == 0 { + panic(nilErr) + } id, ok := p.Memory().ReadUint32Le(p.ptr + wordSize) if !ok { panic(rangeErr) @@ -100,6 +96,9 @@ func (p vfsFilePtr) ID() uint32 { } func (p vfsFilePtr) Lock() vfsLockState { + if p.ptr == 0 { + panic(nilErr) + } lk, ok := p.Memory().ReadUint32Le(p.ptr + 2*wordSize) if !ok { panic(rangeErr) @@ -108,6 +107,9 @@ func (p vfsFilePtr) Lock() vfsLockState { } func (p vfsFilePtr) SetID(id uint32) vfsFilePtr { + if p.ptr == 0 { + panic(nilErr) + } if ok := p.Memory().WriteUint32Le(p.ptr+wordSize, id); !ok { panic(rangeErr) } @@ -115,6 +117,9 @@ func (p vfsFilePtr) SetID(id uint32) vfsFilePtr { } func (p vfsFilePtr) SetLock(lock vfsLockState) vfsFilePtr { + if p.ptr == 0 { + panic(nilErr) + } if ok := p.Memory().WriteUint32Le(p.ptr+2*wordSize, uint32(lock)); !ok { panic(rangeErr) } diff --git a/vfs_lock.go b/vfs_lock.go index 6954ae8..3fb7641 100644 --- a/vfs_lock.go +++ b/vfs_lock.go @@ -241,6 +241,9 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ui if locked { res = 1 } + if pResOut == 0 { + panic(nilErr) + } if ok := mod.Memory().WriteUint32Le(pResOut, res); !ok { panic(rangeErr) } diff --git a/vfs_unix.go b/vfs_unix.go index 5348d13..7cff772 100644 --- a/vfs_unix.go +++ b/vfs_unix.go @@ -1,3 +1,9 @@ //go:build unix package sqlite3 + +import "os" + +func deleteOnClose(f *os.File) { + _ = os.Remove(f.Name()) +} diff --git a/vfs_windows.go b/vfs_windows.go index 4f24317..ca5d3f4 100644 --- a/vfs_windows.go +++ b/vfs_windows.go @@ -1 +1,3 @@ package sqlite3 + +func deleteOnClose(f *os.File) {}