Improved quoting.

This commit is contained in:
Nuno Cruces
2024-10-07 14:42:09 +01:00
parent 3469460635
commit 911e497891
4 changed files with 47 additions and 11 deletions

View File

@@ -6,7 +6,7 @@
// import _ "github.com/ncruces/go-sqlite3/embed/bcw2" // import _ "github.com/ncruces/go-sqlite3/embed/bcw2"
// //
// [BEGIN CONCURRENT]: https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md // [BEGIN CONCURRENT]: https://sqlite.org/src/doc/begin-concurrent/doc/begin_concurrent.md
// [Wal2]: https://www.sqlite.org/cgi/src/doc/wal2/doc/wal2.md // [Wal2]: https://sqlite.org/cgi/src/doc/wal2/doc/wal2.md
package bcw2 package bcw2
import ( import (

View File

@@ -3,6 +3,7 @@ package sqlite3
import ( import (
"bytes" "bytes"
"math" "math"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -13,6 +14,9 @@ import (
// Quote escapes and quotes a value // Quote escapes and quotes a value
// making it safe to embed in SQL text. // making it safe to embed in SQL text.
// Strings with embedded NUL characters are truncated.
//
// https://sqlite.org/lang_corefunc.html#quote
func Quote(value any) string { func Quote(value any) string {
switch v := value.(type) { switch v := value.(type) {
case nil: case nil:
@@ -42,8 +46,8 @@ func Quote(value any) string {
return "'" + v.Format(time.RFC3339Nano) + "'" return "'" + v.Format(time.RFC3339Nano) + "'"
case string: case string:
if strings.IndexByte(v, 0) >= 0 { if i := strings.IndexByte(v, 0); i >= 0 {
break v = v[:i]
} }
buf := make([]byte, 2+len(v)+strings.Count(v, "'")) buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
@@ -75,10 +79,6 @@ func Quote(value any) string {
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
case ZeroBlob: case ZeroBlob:
if v > ZeroBlob(1e9-3)/2 {
break
}
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v))) buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[1] = '\'' buf[1] = '\''
buf[0] = 'x' buf[0] = 'x'
@@ -86,11 +86,39 @@ func Quote(value any) string {
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
} }
v := reflect.ValueOf(value)
k := v.Kind()
if k == reflect.Interface || k == reflect.Pointer {
if v.IsNil() {
return "NULL"
}
v = v.Elem()
k = v.Kind()
}
switch {
case v.CanInt():
return strconv.FormatInt(v.Int(), 10)
case v.CanUint():
return strconv.FormatUint(v.Uint(), 10)
case v.CanFloat():
return Quote(v.Float())
case k == reflect.Bool:
return Quote(v.Bool())
case k == reflect.String:
return Quote(v.String())
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
v.Type().Elem().Kind() == reflect.Uint8:
return Quote(v.Bytes())
}
panic(util.ValueErr) panic(util.ValueErr)
} }
// QuoteIdentifier escapes and quotes an identifier // QuoteIdentifier escapes and quotes an identifier
// making it safe to embed in SQL text. // making it safe to embed in SQL text.
// Strings with embedded NUL characters panic.
func QuoteIdentifier(id string) string { func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 { if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr) panic(util.ValueErr)

View File

@@ -379,7 +379,7 @@ func (s *Stmt) BindValue(param int, value Value) error {
// DataCount resets the number of columns in a result set. // DataCount resets the number of columns in a result set.
// //
// https://www.sqlite.org/c3ref/data_count.html // https://sqlite.org/c3ref/data_count.html
func (s *Stmt) DataCount() int { func (s *Stmt) DataCount() int {
r := s.c.call("sqlite3_data_count", r := s.c.call("sqlite3_data_count",
uint64(s.handle)) uint64(s.handle))

View File

@@ -1,6 +1,8 @@
package tests package tests
import ( import (
"database/sql"
"encoding/json"
"math" "math"
"reflect" "reflect"
"testing" "testing"
@@ -19,8 +21,8 @@ func TestQuote(t *testing.T) {
{`a'bc`, "'a''bc'"}, {`a'bc`, "'a''bc'"},
{"\x07bc", "'\abc'"}, {"\x07bc", "'\abc'"},
{"\x1c\n", "'\x1c\n'"}, {"\x1c\n", "'\x1c\n'"},
{"\xB0\x00\x0B", "'\xB0'"},
{[]byte("\xB0\x00\x0B"), "x'B0000B'"}, {[]byte("\xB0\x00\x0B"), "x'B0000B'"},
{"\xB0\x00\x0B", ""},
{0, "0"}, {0, "0"},
{true, "1"}, {true, "1"},
@@ -33,7 +35,13 @@ func TestQuote(t *testing.T) {
{int64(math.MaxInt64), "9223372036854775807"}, {int64(math.MaxInt64), "9223372036854775807"},
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"}, {time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
{sqlite3.ZeroBlob(4), "x'00000000'"}, {sqlite3.ZeroBlob(4), "x'00000000'"},
{sqlite3.ZeroBlob(1e9), ""}, {int8(0), "0"},
{uint(0), "0"},
{float32(0), "0"},
{(*string)(nil), "NULL"},
{json.Number("0"), "'0'"},
{&sql.RawBytes{'0'}, "x'30'"},
{t, ""}, // panic
} }
for _, tt := range tests { for _, tt := range tests {
@@ -62,7 +70,7 @@ func TestQuoteIdentifier(t *testing.T) {
{`a'bc`, `"a'bc"`}, {`a'bc`, `"a'bc"`},
{"\x07bc", "\"\abc\""}, {"\x07bc", "\"\abc\""},
{"\x1c\n", "\"\x1c\n\""}, {"\x1c\n", "\"\x1c\n\""},
{"\xB0\x00\x0B", ""}, {"\xB0\x00\x0B", ""}, // panic
} }
for _, tt := range tests { for _, tt := range tests {