diff --git a/context.go b/context.go index 637ddc2..abee4ec 100644 --- a/context.go +++ b/context.go @@ -89,20 +89,26 @@ func (ctx Context) ResultText(value string) { } // ResultRawText sets the text result of the function to a []byte. -// Returning a nil slice is the same as calling [Context.ResultNull]. // // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultRawText(value []byte) { + if len(value) == 0 { + ctx.ResultText("") + return + } ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_text_go", stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultBlob sets the result of the function to a []byte. -// Returning a nil slice is the same as calling [Context.ResultNull]. // // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultBlob(value []byte) { + if len(value) == 0 { + ctx.ResultZeroBlob(0) + return + } ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_blob_go", stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) diff --git a/sqlite.go b/sqlite.go index df32271..c05a86f 100644 --- a/sqlite.go +++ b/sqlite.go @@ -212,14 +212,10 @@ func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t { } func (sqlt *sqlite) newBytes(b []byte) ptr_t { - if (*[0]byte)(b) == nil { + if len(b) == 0 { return 0 } - size := len(b) - if size == 0 { - size = 1 - } - ptr := sqlt.new(int64(size)) + ptr := sqlt.new(int64(len(b))) util.WriteBytes(sqlt.mod, ptr, b) return ptr } @@ -288,7 +284,7 @@ func (a *arena) new(size int64) ptr_t { } func (a *arena) bytes(b []byte) ptr_t { - if (*[0]byte)(b) == nil { + if len(b) == 0 { return 0 } ptr := a.new(int64(len(b))) diff --git a/sqlite3/sqlite_opt.h b/sqlite3/sqlite_opt.h index c67f271..974efdc 100644 --- a/sqlite3/sqlite_opt.h +++ b/sqlite3/sqlite_opt.h @@ -16,6 +16,9 @@ // #define SQLITE_OMIT_DECLTYPE // #define SQLITE_OMIT_PROGRESS_CALLBACK +// TODO add this: +// #define SQLITE_ENABLE_API_ARMOR + // Other Options #define SQLITE_ALLOW_URI_AUTHORITY diff --git a/sqlite_test.go b/sqlite_test.go index 8321275..122abf6 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -125,11 +125,6 @@ func Test_sqlite_newBytes(t *testing.T) { if got := util.View(sqlite.mod, ptr, int64(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } - - ptr = sqlite.newBytes(buf[:0]) - if ptr == 0 { - t.Fatal("got nullptr, want a pointer") - } } func Test_sqlite_newString(t *testing.T) { diff --git a/stmt.go b/stmt.go index c176102..1ea726e 100644 --- a/stmt.go +++ b/stmt.go @@ -265,13 +265,15 @@ func (s *Stmt) BindText(param int, value string) error { // BindRawText binds a []byte to the prepared statement as text. // The leftmost SQL parameter has an index of 1. -// Binding a nil slice is the same as calling [Stmt.BindNull]. // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindRawText(param int, value []byte) error { if len(value) > _MAX_LENGTH { return TOOBIG } + if len(value) == 0 { + return s.BindText(param, "") + } ptr := s.c.newBytes(value) rc := res_t(s.c.call("sqlite3_bind_text_go", stk_t(s.handle), stk_t(param), @@ -281,13 +283,15 @@ func (s *Stmt) BindRawText(param int, value []byte) error { // BindBlob binds a []byte to the prepared statement. // The leftmost SQL parameter has an index of 1. -// Binding a nil slice is the same as calling [Stmt.BindNull]. // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindBlob(param int, value []byte) error { if len(value) > _MAX_LENGTH { return TOOBIG } + if len(value) == 0 { + return s.BindZeroBlob(param, 0) + } ptr := s.c.newBytes(value) rc := res_t(s.c.call("sqlite3_bind_blob_go", stk_t(s.handle), stk_t(param), diff --git a/tests/stmt_test.go b/tests/stmt_test.go index d057405..2ef0bee 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -89,6 +89,13 @@ func TestStmt(t *testing.T) { t.Fatal(err) } + if err := stmt.BindRawText(1, nil); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { + t.Fatal(err) + } + if err := stmt.BindBlob(1, []byte("")); err != nil { t.Fatal(err) } @@ -103,13 +110,6 @@ func TestStmt(t *testing.T) { t.Fatal(err) } - if err := stmt.BindBlob(1, nil); err != nil { - t.Fatal(err) - } - if err := stmt.Exec(); err != nil { - t.Fatal(err) - } - if err := stmt.BindZeroBlob(1, 4); err != nil { t.Fatal(err) } @@ -353,6 +353,31 @@ func TestStmt(t *testing.T) { } } + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.TEXT { + t.Errorf("got %v, want TEXT", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "" { + t.Errorf("got %q, want empty", got) + } + if got := stmt.ColumnBlob(0, nil); got != nil { + t.Errorf("got %q, want nil", got) + } + var got any + if err := stmt.ColumnJSON(0, &got); err == nil { + t.Errorf("got %v, want error", got) + } + } + if stmt.Step() { if got := stmt.ColumnType(0); got != sqlite3.BLOB { t.Errorf("got %v, want BLOB", got) @@ -403,33 +428,6 @@ func TestStmt(t *testing.T) { } } - if stmt.Step() { - if got := stmt.ColumnType(0); got != sqlite3.NULL { - t.Errorf("got %v, want NULL", got) - } - if got := stmt.ColumnBool(0); got != false { - t.Errorf("got %v, want false", got) - } - if got := stmt.ColumnInt(0); got != 0 { - t.Errorf("got %v, want zero", got) - } - if got := stmt.ColumnFloat(0); got != 0 { - t.Errorf("got %v, want zero", got) - } - if got := stmt.ColumnText(0); got != "" { - t.Errorf("got %q, want empty", got) - } - if got := stmt.ColumnBlob(0, nil); got != nil { - t.Errorf("got %q, want nil", got) - } - var got any = 1 - if err := stmt.ColumnJSON(0, &got); err != nil { - t.Error(err) - } else if got != nil { - t.Errorf("got %v, want NULL", got) - } - } - if stmt.Step() { if got := stmt.ColumnType(0); got != sqlite3.BLOB { t.Errorf("got %v, want BLOB", got)