This commit is contained in:
Nuno Cruces
2023-11-24 17:25:02 +00:00
parent 1acb95917a
commit 5653efa70e
10 changed files with 87 additions and 25 deletions

View File

@@ -170,6 +170,9 @@ func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
//
// https://sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
if len(sql) > _MAX_LENGTH {
return nil, "", TOOBIG
}
if emptyStatement(sql) {
return nil, "", nil
}

View File

@@ -9,8 +9,9 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_NAME = 512 // Used for short strings: names, error messages…
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
_MAX_ALLOCATION_SIZE = 0x7ffffeff
ptrlen = 4

View File

@@ -136,7 +136,7 @@ func Test_ErrorCode_Error(t *testing.T) {
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
got := ErrorCode(i).Error()
if got != want {
@@ -158,7 +158,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
want += util.ReadString(db.mod, uint32(r), _MAX_NAME)
got := ExtendedErrorCode(i).Error()
if got != want {

View File

@@ -210,12 +210,12 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
}
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
}
if handle != 0 {
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
}
if sql != nil {

13
stmt.go
View File

@@ -123,7 +123,7 @@ func (s *Stmt) BindName(param int) string {
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// BindBool binds a bool to the prepared statement.
@@ -173,6 +173,9 @@ func (s *Stmt) BindFloat(param int, value float64) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindText(param int, value string) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
ptr := s.c.newString(value)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
@@ -186,6 +189,9 @@ func (s *Stmt) BindText(param int, value string) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindRawText(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
ptr := s.c.newBytes(value)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
@@ -200,6 +206,9 @@ func (s *Stmt) BindRawText(param int, value []byte) error {
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
ptr := s.c.newBytes(value)
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
@@ -309,7 +318,7 @@ func (s *Stmt) ColumnName(col int) string {
if ptr == 0 {
panic(util.OOMErr)
}
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnType returns the initial [Datatype] of the result column.

View File

@@ -587,3 +587,56 @@ func TestStmt_ColumnTime(t *testing.T) {
}
}
}
func TestStmt_Error(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var blob [1e9 + 1]byte
_, _, err = db.Prepare(string(blob[:]))
if err == nil {
t.Errorf("want error")
} else {
t.Log(err)
}
stmt, _, err := db.Prepare(`SELECT ?`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = stmt.BindText(1, string(blob[:]))
if err == nil {
t.Errorf("want error")
} else {
t.Log(err)
}
err = stmt.BindBlob(1, blob[:])
if err == nil {
t.Errorf("want error")
} else {
t.Log(err)
}
err = stmt.BindRawText(1, blob[:])
if err == nil {
t.Errorf("want error")
} else {
t.Log(err)
}
err = stmt.BindZeroBlob(1, 1e9+1)
if err == nil {
t.Errorf("want error")
} else {
t.Log(err)
}
}

View File

@@ -3,7 +3,7 @@ package vfs
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_NAME = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_DEFAULT_SECTOR_SIZE = 4096
)

View File

@@ -7,7 +7,6 @@ import (
type readerVFS struct{}
// Open implements the [vfs.VFS] interface.
func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 {
return nil, flags, sqlite3.CANTOPEN
@@ -20,17 +19,14 @@ func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag,
return nil, flags, sqlite3.CANTOPEN
}
// Delete implements the [vfs.VFS] interface.
func (readerVFS) Delete(name string, dirSync bool) error {
return sqlite3.IOERR_DELETE
}
// Access implements the [vfs.VFS] interface.
func (readerVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return false, nil
}
// FullPathname implements the [vfs.VFS] interface.
func (readerVFS) FullPathname(name string) (string, error) {
return name, nil
}

View File

@@ -45,7 +45,7 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder
}
func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 {
name := util.ReadString(mod, zVfsName, _MAX_STRING)
name := util.ReadString(mod, zVfsName, _MAX_NAME)
if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) {
return 1
}
@@ -371,7 +371,7 @@ func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags O
if stack[0] == 0 {
return params
}
key := util.ReadString(mod, uint32(stack[0]), _MAX_STRING)
key := util.ReadString(mod, uint32(stack[0]), _MAX_NAME)
if params.Has(key) {
continue
}
@@ -384,7 +384,7 @@ func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags O
if params == nil {
params = url.Values{}
}
params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_STRING))
params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_NAME))
}
}
@@ -392,7 +392,7 @@ func vfsGet(mod api.Module, pVfs uint32) VFS {
var name string
if pVfs != 0 {
const zNameOffset = 16
name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_STRING)
name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_NAME)
}
if vfs := Find(name); vfs != nil {
return vfs

16
vtab.go
View File

@@ -70,7 +70,7 @@ func implements[T any](typ reflect.Type) bool {
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVtab(sql string) error {
// defer c.arena.reset()
// The arena will be cleared by the prepare or exec method.
sqlPtr := c.arena.string(sql)
r := c.call(c.api.declareVTab, uint64(c.handle), uint64(sqlPtr))
return c.error(r)
@@ -255,7 +255,7 @@ type IndexConstraintUsage struct {
//
// https://sqlite.org/c3ref/vtab_rhs_value.html
func (idx *IndexInfo) RHSValue(column int) (*Value, error) {
// defer idx.c.arena.reset()
// The arena will be cleared by the prepare or exec method.
valPtr := idx.c.arena.new(ptrlen)
r := idx.c.call(idx.c.api.vtabRHSValue,
uint64(idx.handle), uint64(column), uint64(valPtr))
@@ -369,7 +369,7 @@ func vtabModuleCallback(i int) func(_ context.Context, _ api.Module, _, _, _, _,
for i := uint32(0); i < argc; i++ {
ptr := util.ReadUint32(mod, argv+i*ptrlen)
arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING))
arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH))
}
module := vtabGetHandle(ctx, mod, pMod)
@@ -425,13 +425,13 @@ func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, argc, argv,
func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew uint32) uint32 {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabRenamer)
err := vtab.Rename(util.ReadString(mod, zNew, _MAX_STRING))
err := vtab.Rename(util.ReadString(mod, zNew, _MAX_NAME))
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab, nArg, zName, pxFunc uint32) uint32 {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabOverloader)
fn, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_STRING))
fn, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_NAME))
if fn != nil {
handle := util.AddHandle(ctx, fn)
util.WriteUint32(mod, pxFunc, handle)
@@ -444,8 +444,8 @@ func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab, nArg, zNam
func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabChecker)
schema := util.ReadString(mod, zSchema, _MAX_STRING)
table := util.ReadString(mod, zTabName, _MAX_STRING)
schema := util.ReadString(mod, zSchema, _MAX_NAME)
table := util.ReadString(mod, zTabName, _MAX_NAME)
err := vtab.Integrity(schema, table, int(mFlags))
// xIntegrity should return OK - even if it finds problems in the content of the virtual table.
// https://sqlite.org/vtab.html#xintegrity
@@ -518,7 +518,7 @@ func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idx
args := callbackArgs(db, argc, argv)
var idxName string
if idxStr != 0 {
idxName = util.ReadString(mod, idxStr, _MAX_STRING)
idxName = util.ReadString(mod, idxStr, _MAX_NAME)
}
err := cursor.Filter(int(idxNum), idxName, args...)
return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)