From 5653efa70efb2f843c56ef9aaf3217f2e6df43d4 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 24 Nov 2023 17:25:02 +0000 Subject: [PATCH] Limits. --- conn.go | 3 +++ const.go | 5 ++-- error_test.go | 4 ++-- sqlite.go | 4 ++-- stmt.go | 13 ++++++++-- tests/stmt_test.go | 53 +++++++++++++++++++++++++++++++++++++++++ vfs/const.go | 2 +- vfs/readervfs/reader.go | 4 ---- vfs/vfs.go | 8 +++---- vtab.go | 16 ++++++------- 10 files changed, 87 insertions(+), 25 deletions(-) diff --git a/conn.go b/conn.go index fe8295c..12e9e94 100644 --- a/conn.go +++ b/conn.go @@ -170,6 +170,9 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { // // https://sqlite.org/c3ref/prepare.html func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) { + if len(sql) > _MAX_LENGTH { + return nil, "", TOOBIG + } if emptyStatement(sql) { return nil, "", nil } diff --git a/const.go b/const.go index 510e25e..5697b31 100644 --- a/const.go +++ b/const.go @@ -9,8 +9,9 @@ const ( _UTF8 = 1 - _MAX_STRING = 512 // Used for short strings: names, error messages… - + _MAX_NAME = 512 // Used for short strings: names, error messages… + _MAX_LENGTH = 1e9 + _MAX_SQL_LENGTH = 1e9 _MAX_ALLOCATION_SIZE = 0x7ffffeff ptrlen = 4 diff --git a/error_test.go b/error_test.go index d4b9037..f4df1e7 100644 --- a/error_test.go +++ b/error_test.go @@ -136,7 +136,7 @@ func Test_ErrorCode_Error(t *testing.T) { for i := 0; i == int(ErrorCode(i)); i++ { want := "sqlite3: " r := db.call(db.api.errstr, uint64(i)) - want += util.ReadString(db.mod, uint32(r), _MAX_STRING) + want += util.ReadString(db.mod, uint32(r), _MAX_NAME) got := ErrorCode(i).Error() if got != want { @@ -158,7 +158,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { for i := 0; i == int(ExtendedErrorCode(i)); i++ { want := "sqlite3: " r := db.call(db.api.errstr, uint64(i)) - want += util.ReadString(db.mod, uint32(r), _MAX_STRING) + want += util.ReadString(db.mod, uint32(r), _MAX_NAME) got := ExtendedErrorCode(i).Error() if got != want { diff --git a/sqlite.go b/sqlite.go index 4ee8400..2c3dbe2 100644 --- a/sqlite.go +++ b/sqlite.go @@ -210,12 +210,12 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { } if r := sqlt.call(sqlt.api.errstr, rc); r != 0 { - err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) + err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) } if handle != 0 { if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 { - err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) + err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) } if sql != nil { diff --git a/stmt.go b/stmt.go index 98a2728..087497c 100644 --- a/stmt.go +++ b/stmt.go @@ -123,7 +123,7 @@ func (s *Stmt) BindName(param int) string { if ptr == 0 { return "" } - return util.ReadString(s.c.mod, ptr, _MAX_STRING) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // BindBool binds a bool to the prepared statement. @@ -173,6 +173,9 @@ func (s *Stmt) BindFloat(param int, value float64) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindText(param int, value string) error { + if len(value) > _MAX_LENGTH { + return TOOBIG + } ptr := s.c.newString(value) r := s.c.call(s.c.api.bindText, uint64(s.handle), uint64(param), @@ -186,6 +189,9 @@ func (s *Stmt) BindText(param int, value string) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindRawText(param int, value []byte) error { + if len(value) > _MAX_LENGTH { + return TOOBIG + } ptr := s.c.newBytes(value) r := s.c.call(s.c.api.bindText, uint64(s.handle), uint64(param), @@ -200,6 +206,9 @@ func (s *Stmt) BindRawText(param int, value []byte) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindBlob(param int, value []byte) error { + if len(value) > _MAX_LENGTH { + return TOOBIG + } ptr := s.c.newBytes(value) r := s.c.call(s.c.api.bindBlob, uint64(s.handle), uint64(param), @@ -309,7 +318,7 @@ func (s *Stmt) ColumnName(col int) string { if ptr == 0 { panic(util.OOMErr) } - return util.ReadString(s.c.mod, ptr, _MAX_STRING) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnType returns the initial [Datatype] of the result column. diff --git a/tests/stmt_test.go b/tests/stmt_test.go index a44abd1..c475f23 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -587,3 +587,56 @@ func TestStmt_ColumnTime(t *testing.T) { } } } + +func TestStmt_Error(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var blob [1e9 + 1]byte + + _, _, err = db.Prepare(string(blob[:])) + if err == nil { + t.Errorf("want error") + } else { + t.Log(err) + } + + stmt, _, err := db.Prepare(`SELECT ?`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + err = stmt.BindText(1, string(blob[:])) + if err == nil { + t.Errorf("want error") + } else { + t.Log(err) + } + + err = stmt.BindBlob(1, blob[:]) + if err == nil { + t.Errorf("want error") + } else { + t.Log(err) + } + + err = stmt.BindRawText(1, blob[:]) + if err == nil { + t.Errorf("want error") + } else { + t.Log(err) + } + + err = stmt.BindZeroBlob(1, 1e9+1) + if err == nil { + t.Errorf("want error") + } else { + t.Log(err) + } +} diff --git a/vfs/const.go b/vfs/const.go index 49f2a5e..02e8f06 100644 --- a/vfs/const.go +++ b/vfs/const.go @@ -3,7 +3,7 @@ package vfs import "github.com/ncruces/go-sqlite3/internal/util" const ( - _MAX_STRING = 512 // Used for short strings: names, error messages… + _MAX_NAME = 512 // Used for short strings: names, error messages… _MAX_PATHNAME = 512 _DEFAULT_SECTOR_SIZE = 4096 ) diff --git a/vfs/readervfs/reader.go b/vfs/readervfs/reader.go index 53b8ba0..3e47eb6 100644 --- a/vfs/readervfs/reader.go +++ b/vfs/readervfs/reader.go @@ -7,7 +7,6 @@ import ( type readerVFS struct{} -// Open implements the [vfs.VFS] interface. func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) { if flags&vfs.OPEN_MAIN_DB == 0 { return nil, flags, sqlite3.CANTOPEN @@ -20,17 +19,14 @@ func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, return nil, flags, sqlite3.CANTOPEN } -// Delete implements the [vfs.VFS] interface. func (readerVFS) Delete(name string, dirSync bool) error { return sqlite3.IOERR_DELETE } -// Access implements the [vfs.VFS] interface. func (readerVFS) Access(name string, flag vfs.AccessFlag) (bool, error) { return false, nil } -// FullPathname implements the [vfs.VFS] interface. func (readerVFS) FullPathname(name string) (string, error) { return name, nil } diff --git a/vfs/vfs.go b/vfs/vfs.go index d945b67..51bcb5b 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -45,7 +45,7 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder } func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 { - name := util.ReadString(mod, zVfsName, _MAX_STRING) + name := util.ReadString(mod, zVfsName, _MAX_NAME) if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) { return 1 } @@ -371,7 +371,7 @@ func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags O if stack[0] == 0 { return params } - key := util.ReadString(mod, uint32(stack[0]), _MAX_STRING) + key := util.ReadString(mod, uint32(stack[0]), _MAX_NAME) if params.Has(key) { continue } @@ -384,7 +384,7 @@ func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags O if params == nil { params = url.Values{} } - params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_STRING)) + params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_NAME)) } } @@ -392,7 +392,7 @@ func vfsGet(mod api.Module, pVfs uint32) VFS { var name string if pVfs != 0 { const zNameOffset = 16 - name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_STRING) + name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_NAME) } if vfs := Find(name); vfs != nil { return vfs diff --git a/vtab.go b/vtab.go index a242fb0..7274221 100644 --- a/vtab.go +++ b/vtab.go @@ -70,7 +70,7 @@ func implements[T any](typ reflect.Type) bool { // // https://sqlite.org/c3ref/declare_vtab.html func (c *Conn) DeclareVtab(sql string) error { - // defer c.arena.reset() + // The arena will be cleared by the prepare or exec method. sqlPtr := c.arena.string(sql) r := c.call(c.api.declareVTab, uint64(c.handle), uint64(sqlPtr)) return c.error(r) @@ -255,7 +255,7 @@ type IndexConstraintUsage struct { // // https://sqlite.org/c3ref/vtab_rhs_value.html func (idx *IndexInfo) RHSValue(column int) (*Value, error) { - // defer idx.c.arena.reset() + // The arena will be cleared by the prepare or exec method. valPtr := idx.c.arena.new(ptrlen) r := idx.c.call(idx.c.api.vtabRHSValue, uint64(idx.handle), uint64(column), uint64(valPtr)) @@ -369,7 +369,7 @@ func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _, for i := uint32(0); i < argc; i++ { ptr := util.ReadUint32(mod, argv+i*ptrlen) - arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING)) + arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH)) } module := vtabGetHandle(ctx, mod, pMod) @@ -425,13 +425,13 @@ func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, argc, argv, func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew uint32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabRenamer) - err := vtab.Rename(util.ReadString(mod, zNew, _MAX_STRING)) + err := vtab.Rename(util.ReadString(mod, zNew, _MAX_NAME)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab, nArg, zName, pxFunc uint32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabOverloader) - fn, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_STRING)) + fn, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_NAME)) if fn != nil { handle := util.AddHandle(ctx, fn) util.WriteUint32(mod, pxFunc, handle) @@ -444,8 +444,8 @@ func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab, nArg, zNam func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabChecker) - schema := util.ReadString(mod, zSchema, _MAX_STRING) - table := util.ReadString(mod, zTabName, _MAX_STRING) + schema := util.ReadString(mod, zSchema, _MAX_NAME) + table := util.ReadString(mod, zTabName, _MAX_NAME) err := vtab.Integrity(schema, table, int(mFlags)) // xIntegrity should return OK - even if it finds problems in the content of the virtual table. // https://sqlite.org/vtab.html#xintegrity @@ -518,7 +518,7 @@ func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idx args := callbackArgs(db, argc, argv) var idxName string if idxStr != 0 { - idxName = util.ReadString(mod, idxStr, _MAX_STRING) + idxName = util.ReadString(mod, idxStr, _MAX_NAME) } err := cursor.Filter(int(idxNum), idxName, args...) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)