diff --git a/error.go b/error.go index 9048272..ca86b66 100644 --- a/error.go +++ b/error.go @@ -143,6 +143,8 @@ func errorCode(err error, def ErrorCode) (msg string, code uint32) { return "", uint32(code) case ExtendedErrorCode: return "", uint32(code) + case *Error: + return code.msg, uint32(code.code) case nil: return "", _OK } diff --git a/ext/blob/blob.go b/ext/blob/blob.go index d0093f1..82bd78d 100644 --- a/ext/blob/blob.go +++ b/ext/blob/blob.go @@ -24,7 +24,7 @@ func Register(db *sqlite3.Conn) { func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) { if len(arg) < 6 { - ctx.ResultError(util.ErrorString("wrong number of arguments to function blob_open()")) + ctx.ResultError(util.ErrorString("blob_open: wrong number of arguments")) return } diff --git a/ext/statement/stmt.go b/ext/statement/stmt.go index 81ef4e3..951280a 100644 --- a/ext/statement/stmt.go +++ b/ext/statement/stmt.go @@ -1,4 +1,4 @@ -// Package statement defines virtual tables and table-valued functions natively using SQL. +// Package statement defines table-valued functions natively using SQL. // // https://github.com/0x09/sqlite-statement-vtab package statement @@ -15,24 +15,53 @@ import ( // Register registers the statement virtual table. func Register(db *sqlite3.Conn) { - declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { - if len(arg) == 0 || len(arg[0]) < 3 { - return nil, fmt.Errorf("statement: no statement provided") - } - sql := arg[0] - if len := len(sql); sql[0] != '(' || sql[len-1] != ')' { - return nil, fmt.Errorf("statement: statement must be parenthesized") - } else { - sql = sql[1 : len-1] + declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) { + if len(arg) != 1 { + return nil, fmt.Errorf("statement: wrong number of arguments") } - table := &table{sql: sql} - err = table.declare(db) + sql := "SELECT * FROM\n" + arg[0] + + stmt, _, err := db.Prepare(sql) if err != nil { - table.Close() return nil, err } - return table, nil + + var sep = "" + var str strings.Builder + str.WriteString(`CREATE TABLE x(`) + outputs := stmt.ColumnCount() + for i := 0; i < outputs; i++ { + name := sqlite3.QuoteIdentifier(stmt.ColumnName(i)) + str.WriteString(sep) + str.WriteString(name) + str.WriteByte(' ') + str.WriteString(stmt.ColumnDeclType(i)) + sep = "," + } + inputs := stmt.BindCount() + for i := 1; i <= inputs; i++ { + str.WriteString(sep) + name := stmt.BindName(i) + if name == "" { + str.WriteString("[") + str.WriteString(strconv.Itoa(i)) + str.WriteString("] HIDDEN") + } else { + str.WriteString(sqlite3.QuoteIdentifier(name[1:])) + str.WriteString(" HIDDEN") + } + sep = "," + } + str.WriteByte(')') + + err = db.DeclareVtab(str.String()) + if err != nil { + stmt.Close() + return nil, err + } + + return &table{sql: sql, stmt: stmt}, nil } sqlite3.CreateModule(db, "statement", declare, declare) @@ -44,49 +73,6 @@ type table struct { inuse bool } -func (t *table) declare(db *sqlite3.Conn) (err error) { - var tail string - t.stmt, tail, err = db.Prepare(t.sql) - if err != nil { - return err - } - if tail != "" { - return fmt.Errorf("statement: multiple statements") - } - if !t.stmt.ReadOnly() { - return fmt.Errorf("statement: statement must be read only") - } - - var sep = "" - var str strings.Builder - str.WriteString(`CREATE TABLE x(`) - outputs := t.stmt.ColumnCount() - for i := 0; i < outputs; i++ { - str.WriteString(sep) - name := t.stmt.ColumnName(i) - str.WriteString(sqlite3.QuoteIdentifier(name)) - str.WriteByte(' ') - str.WriteString(t.stmt.ColumnDeclType(i)) - sep = "," - } - inputs := t.stmt.BindCount() - for i := 1; i <= inputs; i++ { - str.WriteString(sep) - name := t.stmt.BindName(i) - if name == "" { - str.WriteString("[") - str.WriteString(strconv.Itoa(i)) - str.WriteString("] HIDDEN") - } else { - str.WriteString(sqlite3.QuoteIdentifier(name[1:])) - str.WriteString(" HIDDEN") - } - sep = "," - } - str.WriteByte(')') - return db.DeclareVtab(str.String()) -} - func (t *table) Close() error { return t.stmt.Close() } diff --git a/tests/stmt_test.go b/tests/stmt_test.go index a39acc7..0c0250c 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -30,6 +30,10 @@ func TestStmt(t *testing.T) { } defer stmt.Close() + if got := stmt.ReadOnly(); got != false { + t.Error("got true, want false") + } + if got := stmt.BindCount(); got != 1 { t.Errorf("got %d, want 1", got) } @@ -137,6 +141,10 @@ func TestStmt(t *testing.T) { } defer stmt.Close() + if got := stmt.ReadOnly(); got != true { + t.Error("got false, want true") + } + if stmt.Step() { if got := stmt.ColumnType(0); got != sqlite3.INTEGER { t.Errorf("got %v, want INTEGER", got)