diff --git a/embed/bcw2/init.go b/embed/bcw2/init.go index 4c674ec..fe04e79 100644 --- a/embed/bcw2/init.go +++ b/embed/bcw2/init.go @@ -6,7 +6,7 @@ // import _ "github.com/ncruces/go-sqlite3/embed/bcw2" // // [BEGIN CONCURRENT]: https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md -// [Wal2]: https://www.sqlite.org/cgi/src/doc/wal2/doc/wal2.md +// [Wal2]: https://sqlite.org/cgi/src/doc/wal2/doc/wal2.md package bcw2 import ( diff --git a/quote.go b/quote.go index 8d1be0c..abe516d 100644 --- a/quote.go +++ b/quote.go @@ -3,6 +3,7 @@ package sqlite3 import ( "bytes" "math" + "reflect" "strconv" "strings" "time" @@ -13,6 +14,9 @@ import ( // Quote escapes and quotes a value // making it safe to embed in SQL text. +// Strings with embedded NUL characters are truncated. +// +// https://sqlite.org/lang_corefunc.html#quote func Quote(value any) string { switch v := value.(type) { case nil: @@ -42,8 +46,8 @@ func Quote(value any) string { return "'" + v.Format(time.RFC3339Nano) + "'" case string: - if strings.IndexByte(v, 0) >= 0 { - break + if i := strings.IndexByte(v, 0); i >= 0 { + v = v[:i] } buf := make([]byte, 2+len(v)+strings.Count(v, "'")) @@ -75,10 +79,6 @@ func Quote(value any) string { return unsafe.String(&buf[0], len(buf)) case ZeroBlob: - if v > ZeroBlob(1e9-3)/2 { - break - } - buf := bytes.Repeat([]byte("0"), int(3+2*int64(v))) buf[1] = '\'' buf[0] = 'x' @@ -86,11 +86,39 @@ func Quote(value any) string { return unsafe.String(&buf[0], len(buf)) } + v := reflect.ValueOf(value) + k := v.Kind() + + if k == reflect.Interface || k == reflect.Pointer { + if v.IsNil() { + return "NULL" + } + v = v.Elem() + k = v.Kind() + } + + switch { + case v.CanInt(): + return strconv.FormatInt(v.Int(), 10) + case v.CanUint(): + return strconv.FormatUint(v.Uint(), 10) + case v.CanFloat(): + return Quote(v.Float()) + case k == reflect.Bool: + return Quote(v.Bool()) + case k == reflect.String: + return Quote(v.String()) + case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) && + v.Type().Elem().Kind() == reflect.Uint8: + return Quote(v.Bytes()) + } + panic(util.ValueErr) } // QuoteIdentifier escapes and quotes an identifier // making it safe to embed in SQL text. +// Strings with embedded NUL characters panic. func QuoteIdentifier(id string) string { if strings.IndexByte(id, 0) >= 0 { panic(util.ValueErr) diff --git a/stmt.go b/stmt.go index 82f9fb7..9da2a2e 100644 --- a/stmt.go +++ b/stmt.go @@ -379,7 +379,7 @@ func (s *Stmt) BindValue(param int, value Value) error { // DataCount resets the number of columns in a result set. // -// https://www.sqlite.org/c3ref/data_count.html +// https://sqlite.org/c3ref/data_count.html func (s *Stmt) DataCount() int { r := s.c.call("sqlite3_data_count", uint64(s.handle)) diff --git a/tests/quote_test.go b/tests/quote_test.go index bd322d6..095d83e 100644 --- a/tests/quote_test.go +++ b/tests/quote_test.go @@ -1,6 +1,8 @@ package tests import ( + "database/sql" + "encoding/json" "math" "reflect" "testing" @@ -19,8 +21,8 @@ func TestQuote(t *testing.T) { {`a'bc`, "'a''bc'"}, {"\x07bc", "'\abc'"}, {"\x1c\n", "'\x1c\n'"}, + {"\xB0\x00\x0B", "'\xB0'"}, {[]byte("\xB0\x00\x0B"), "x'B0000B'"}, - {"\xB0\x00\x0B", ""}, {0, "0"}, {true, "1"}, @@ -33,7 +35,13 @@ func TestQuote(t *testing.T) { {int64(math.MaxInt64), "9223372036854775807"}, {time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"}, {sqlite3.ZeroBlob(4), "x'00000000'"}, - {sqlite3.ZeroBlob(1e9), ""}, + {int8(0), "0"}, + {uint(0), "0"}, + {float32(0), "0"}, + {(*string)(nil), "NULL"}, + {json.Number("0"), "'0'"}, + {&sql.RawBytes{'0'}, "x'30'"}, + {t, ""}, // panic } for _, tt := range tests { @@ -62,7 +70,7 @@ func TestQuoteIdentifier(t *testing.T) { {`a'bc`, `"a'bc"`}, {"\x07bc", "\"\abc\""}, {"\x1c\n", "\"\x1c\n\""}, - {"\xB0\x00\x0B", ""}, + {"\xB0\x00\x0B", ""}, // panic } for _, tt := range tests {