Improved error handling.

This commit is contained in:
Nuno Cruces
2023-12-05 16:11:56 +00:00
parent 06d2ff6752
commit 8b45cac16b
4 changed files with 54 additions and 58 deletions

View File

@@ -143,6 +143,8 @@ func errorCode(err error, def ErrorCode) (msg string, code uint32) {
return "", uint32(code)
case ExtendedErrorCode:
return "", uint32(code)
case *Error:
return code.msg, uint32(code.code)
case nil:
return "", _OK
}

View File

@@ -24,7 +24,7 @@ func Register(db *sqlite3.Conn) {
func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(util.ErrorString("wrong number of arguments to function blob_open()"))
ctx.ResultError(util.ErrorString("blob_open: wrong number of arguments"))
return
}

View File

@@ -1,4 +1,4 @@
// Package statement defines virtual tables and table-valued functions natively using SQL.
// Package statement defines table-valued functions natively using SQL.
//
// https://github.com/0x09/sqlite-statement-vtab
package statement
@@ -15,24 +15,53 @@ import (
// Register registers the statement virtual table.
func Register(db *sqlite3.Conn) {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
if len(arg) == 0 || len(arg[0]) < 3 {
return nil, fmt.Errorf("statement: no statement provided")
}
sql := arg[0]
if len := len(sql); sql[0] != '(' || sql[len-1] != ')' {
return nil, fmt.Errorf("statement: statement must be parenthesized")
} else {
sql = sql[1 : len-1]
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, fmt.Errorf("statement: wrong number of arguments")
}
table := &table{sql: sql}
err = table.declare(db)
sql := "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(sql)
if err != nil {
table.Close()
return nil, err
}
return table, nil
var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
outputs := stmt.ColumnCount()
for i := 0; i < outputs; i++ {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
str.WriteByte(' ')
str.WriteString(stmt.ColumnDeclType(i))
sep = ","
}
inputs := stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
str.WriteString("] HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
err = db.DeclareVtab(str.String())
if err != nil {
stmt.Close()
return nil, err
}
return &table{sql: sql, stmt: stmt}, nil
}
sqlite3.CreateModule(db, "statement", declare, declare)
@@ -44,49 +73,6 @@ type table struct {
inuse bool
}
func (t *table) declare(db *sqlite3.Conn) (err error) {
var tail string
t.stmt, tail, err = db.Prepare(t.sql)
if err != nil {
return err
}
if tail != "" {
return fmt.Errorf("statement: multiple statements")
}
if !t.stmt.ReadOnly() {
return fmt.Errorf("statement: statement must be read only")
}
var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
outputs := t.stmt.ColumnCount()
for i := 0; i < outputs; i++ {
str.WriteString(sep)
name := t.stmt.ColumnName(i)
str.WriteString(sqlite3.QuoteIdentifier(name))
str.WriteByte(' ')
str.WriteString(t.stmt.ColumnDeclType(i))
sep = ","
}
inputs := t.stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := t.stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
str.WriteString("] HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
return db.DeclareVtab(str.String())
}
func (t *table) Close() error {
return t.stmt.Close()
}

View File

@@ -30,6 +30,10 @@ func TestStmt(t *testing.T) {
}
defer stmt.Close()
if got := stmt.ReadOnly(); got != false {
t.Error("got true, want false")
}
if got := stmt.BindCount(); got != 1 {
t.Errorf("got %d, want 1", got)
}
@@ -137,6 +141,10 @@ func TestStmt(t *testing.T) {
}
defer stmt.Close()
if got := stmt.ReadOnly(); got != true {
t.Error("got false, want true")
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)