From 4ac2ccf47300d474ad10c2a5d8828b6eb08a14e1 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 18 Feb 2023 00:47:56 +0000 Subject: [PATCH] Named parameters. --- api.go | 8 ++++++-- conn.go | 4 ++-- const.go | 1 + embed/build.sh | 4 +++- embed/sqlite3.wasm | Bin 726666 -> 726846 bytes stmt.go | 37 +++++++++++++++++++++++++++++++++++-- stmt_test.go | 32 ++++++++++++++++++++++++++++++++ 7 files changed, 79 insertions(+), 7 deletions(-) diff --git a/api.go b/api.go index 294a367..76f91bf 100644 --- a/api.go +++ b/api.go @@ -45,12 +45,14 @@ func newConn(ctx context.Context, module api.Module) (_ *Conn, err error) { exec: getFun("sqlite3_exec"), clearBindings: getFun("sqlite3_clear_bindings"), bindCount: getFun("sqlite3_bind_parameter_count"), + bindIndex: getFun("sqlite3_bind_parameter_index"), + bindName: getFun("sqlite3_bind_parameter_name"), + bindNull: getFun("sqlite3_bind_null"), bindInteger: getFun("sqlite3_bind_int64"), bindFloat: getFun("sqlite3_bind_double"), bindText: getFun("sqlite3_bind_text64"), bindBlob: getFun("sqlite3_bind_blob64"), bindZeroBlob: getFun("sqlite3_bind_zeroblob64"), - bindNull: getFun("sqlite3_bind_null"), columnCount: getFun("sqlite3_column_count"), columnName: getFun("sqlite3_column_name"), columnType: getFun("sqlite3_column_type"), @@ -86,13 +88,15 @@ type sqliteAPI struct { step api.Function exec api.Function clearBindings api.Function + bindNull api.Function bindCount api.Function + bindIndex api.Function + bindName api.Function bindInteger api.Function bindFloat api.Function bindText api.Function bindBlob api.Function bindZeroBlob api.Function - bindNull api.Function columnCount api.Function columnName api.Function columnType api.Function diff --git a/conn.go b/conn.go index 95f425d..774506b 100644 --- a/conn.go +++ b/conn.go @@ -236,12 +236,12 @@ func (c *Conn) error(rc uint64, sql ...string) error { r, _ = c.api.errstr.Call(c.ctx, rc) if r != nil { - err.str = c.mem.readString(uint32(r[0]), 512) + err.str = c.mem.readString(uint32(r[0]), _MAX_STRING) } r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) if r != nil { - err.msg = c.mem.readString(uint32(r[0]), 512) + err.msg = c.mem.readString(uint32(r[0]), _MAX_STRING) } if sql != nil { diff --git a/const.go b/const.go index a5f2ddf..e79da4f 100644 --- a/const.go +++ b/const.go @@ -9,6 +9,7 @@ const ( _UTF8 = 1 + _MAX_STRING = 512 // Used for short strings: names, error messages… _MAX_PATHNAME = 512 ptrlen = 4 diff --git a/embed/build.sh b/embed/build.sh index ee8b186..b6448c9 100755 --- a/embed/build.sh +++ b/embed/build.sh @@ -29,12 +29,14 @@ zig cc --target=wasm32-wasi -flto -g0 -Os \ -Wl,--export=sqlite3_exec \ -Wl,--export=sqlite3_clear_bindings \ -Wl,--export=sqlite3_bind_parameter_count \ + -Wl,--export=sqlite3_bind_parameter_index \ + -Wl,--export=sqlite3_bind_parameter_name \ + -Wl,--export=sqlite3_bind_null \ -Wl,--export=sqlite3_bind_int64 \ -Wl,--export=sqlite3_bind_double \ -Wl,--export=sqlite3_bind_text64 \ -Wl,--export=sqlite3_bind_blob64 \ -Wl,--export=sqlite3_bind_zeroblob64 \ - -Wl,--export=sqlite3_bind_null \ -Wl,--export=sqlite3_column_count \ -Wl,--export=sqlite3_column_name \ -Wl,--export=sqlite3_column_type \ diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 59fae1433f53e87c3c0f78a04be136d163cb0ad9..b0c688ca8df3dd8a0dbf40523b694202073b6c5e 100755 GIT binary patch delta 459 zcmZ9IIZFdk6ouc-jLUdmbR=)k;1*b9nFYh@*Z#`p|wSbT8rIS20LY#j}r9}k}773Bq=i~GuBjZ0B>vf7Zc zUF$ENYEJ-d;=zM0^=2r_$s4v;x4@zq;j5&K)LC9yjPtQszRJVo9UFgk+m1T(q=LHz z^r$33y<(Z6w~r)g^LHk^5Hjbv6C5Y&)cqqaSq>+u9Uo;%nlb?fvI&x*P2T}-+hHrh zNtXUS=EQB1`wdw!ZqBZZl8s`9wyNs=YwvTapyvuYNL&&{qDpEcwURnXeW{>plV=Qx zkN7~(Yx)R&&3mQWOKtkXeNzVYvk5pX!)7Tc8m`tLP_2Z0%EMtR8V#8dSx3yW<7IK@ zL9bCQ>!0;5+zdOQA_f?Cxe`>&Fzj|EsfaP`7uO6?958H!l#u^)LR#ENA2OoC&<|PR SH_!(;F=b#EY?S5%UzzD=lK+FupEI`Z(#B4y!4#XVW9|>?e?%i&E zj??`lZ{uG^mb}E=RK|;c7~8k^Pv^Q(GyRYVw+!Ro>Gwsr<(d9*PZQ-_l%GfgfrYN@pW9#%^qTKR~ LZQG^9xK$Vd*tdr9 diff --git a/stmt.go b/stmt.go index 029ee1f..7f1d93e 100644 --- a/stmt.go +++ b/stmt.go @@ -94,7 +94,7 @@ func (s *Stmt) Exec() error { return s.Reset() } -// BindCount gets the number of SQL parameters in a prepared statement. +// BindCount returns the number of SQL parameters in the prepared statement. // // https://www.sqlite.org/c3ref/bind_parameter_count.html func (s *Stmt) BindCount() int { @@ -106,6 +106,39 @@ func (s *Stmt) BindCount() int { return int(r[0]) } +// BindIndex returns the index of a parameter in the prepared statement +// given its name. +// +// https://www.sqlite.org/c3ref/bind_parameter_index.html +func (s *Stmt) BindIndex(name string) int { + defer s.c.arena.reset() + namePtr := s.c.arena.string(name) + r, err := s.c.api.bindIndex.Call(s.c.ctx, + uint64(s.handle), uint64(namePtr)) + if err != nil { + panic(err) + } + return int(r[0]) +} + +// BindName returns the name of a parameter in the prepared statement. +// The leftmost SQL parameter has an index of 1. +// +// https://www.sqlite.org/c3ref/bind_parameter_name.html +func (s *Stmt) BindName(param int) string { + r, err := s.c.api.bindName.Call(s.c.ctx, + uint64(s.handle), uint64(param)) + if err != nil { + panic(err) + } + + ptr := uint32(r[0]) + if ptr == 0 { + return "" + } + return s.c.mem.readString(ptr, _MAX_STRING) +} + // BindBool binds a bool to the prepared statement. // The leftmost SQL parameter has an index of 1. // SQLite does not have a separate boolean storage class. @@ -226,7 +259,7 @@ func (s *Stmt) ColumnName(col int) string { if ptr == 0 { return "" } - return s.c.mem.readString(ptr, 512) + return s.c.mem.readString(ptr, _MAX_STRING) } // ColumnType returns the initial [Datatype] of the result column. diff --git a/stmt_test.go b/stmt_test.go index eee9769..8d8cb91 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -366,3 +366,35 @@ func TestStmt_Close(t *testing.T) { var stmt *Stmt stmt.Close() } + +func TestStmt_BindName(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + want := []string{"", "", "", "", "?5", ":AAA", "@AAA", "$AAA"} + stmt, _, err := db.Prepare(`SELECT ?, ?5, :AAA, @AAA, $AAA`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if got := stmt.BindCount(); got != len(want) { + t.Errorf("got %d, want %d", got, len(want)) + } + + for i, name := range want { + id := i + 1 + if got := stmt.BindName(id); got != name { + t.Errorf("got %q, want %q", got, name) + } + if name == "" { + id = 0 + } + if got := stmt.BindIndex(name); got != id { + t.Errorf("got %d, want %d", got, id) + } + } +}