mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Quote values, identifiers.
This commit is contained in:
@@ -23,6 +23,7 @@ const (
|
||||
OffsetErr = ErrorString("sqlite3: invalid offset")
|
||||
TailErr = ErrorString("sqlite3: multiple statements")
|
||||
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
|
||||
ValueErr = ErrorString("sqlite3: unsupported value")
|
||||
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
|
||||
)
|
||||
|
||||
|
||||
112
quote.go
Normal file
112
quote.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Quote escapes and quotes a value
|
||||
// making it safe to embed in SQL text.
|
||||
func Quote(value any) string {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "NULL"
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
} else {
|
||||
return "0"
|
||||
}
|
||||
|
||||
case int:
|
||||
return strconv.Itoa(v)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case float64:
|
||||
switch {
|
||||
case math.IsNaN(v):
|
||||
return "NULL"
|
||||
case math.IsInf(v, 1):
|
||||
return "9.0e999"
|
||||
case math.IsInf(v, -1):
|
||||
return "-9.0e999"
|
||||
}
|
||||
return strconv.FormatFloat(v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
return "'" + v.Format(time.RFC3339Nano) + "'"
|
||||
|
||||
case string:
|
||||
if strings.IndexByte(v, 0) >= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
|
||||
buf[0] = '\''
|
||||
i := 1
|
||||
for _, b := range []byte(v) {
|
||||
if b == '\'' {
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
|
||||
case []byte:
|
||||
buf := make([]byte, 3+2*len(v))
|
||||
buf[0] = 'x'
|
||||
buf[1] = '\''
|
||||
i := 2
|
||||
for _, b := range v {
|
||||
const hex = "0123456789ABCDEF"
|
||||
buf[i+0] = hex[b/16]
|
||||
buf[i+1] = hex[b%16]
|
||||
i += 2
|
||||
}
|
||||
buf[i] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
|
||||
case ZeroBlob:
|
||||
if v > ZeroBlob(1e9-3)/2 {
|
||||
break
|
||||
}
|
||||
|
||||
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
|
||||
buf[0] = 'x'
|
||||
buf[1] = '\''
|
||||
buf[len(buf)-1] = '\''
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
}
|
||||
|
||||
panic(util.ValueErr)
|
||||
}
|
||||
|
||||
// QuoteIdentifier escapes and quotes an identifier
|
||||
// making it safe to embed in SQL text.
|
||||
func QuoteIdentifier(id string) string {
|
||||
if strings.IndexByte(id, 0) >= 0 {
|
||||
panic(util.ValueErr)
|
||||
}
|
||||
|
||||
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
|
||||
buf[0] = '"'
|
||||
i := 1
|
||||
for _, b := range []byte(id) {
|
||||
if b == '"' {
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = b
|
||||
i += 1
|
||||
}
|
||||
buf[i] = '"'
|
||||
return unsafe.String(&buf[0], len(buf))
|
||||
}
|
||||
82
tests/quote_test.go
Normal file
82
tests/quote_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
val any
|
||||
want string
|
||||
}{
|
||||
{`abc`, "'abc'"},
|
||||
{`a"bc`, "'a\"bc'"},
|
||||
{`a'bc`, "'a''bc'"},
|
||||
{"\x07bc", "'\abc'"},
|
||||
{"\x1c\n", "'\x1c\n'"},
|
||||
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
|
||||
{"\xB0\x00\x0B", ""},
|
||||
|
||||
{0, "0"},
|
||||
{true, "1"},
|
||||
{false, "0"},
|
||||
{nil, "NULL"},
|
||||
{math.NaN(), "NULL"},
|
||||
{math.Inf(1), "9.0e999"},
|
||||
{math.Inf(-1), "-9.0e999"},
|
||||
{math.Pi, "3.141592653589793"},
|
||||
{int64(math.MaxInt64), "9223372036854775807"},
|
||||
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
|
||||
{sqlite3.ZeroBlob(4), "x'00000000'"},
|
||||
{sqlite3.ZeroBlob(1e9), ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil && tt.want != "" {
|
||||
t.Errorf("Quote(%q) = %v", tt.val, r)
|
||||
}
|
||||
}()
|
||||
|
||||
got := sqlite3.Quote(tt.val)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{`abc`, `"abc"`},
|
||||
{`a"bc`, `"a""bc"`},
|
||||
{`a'bc`, `"a'bc"`},
|
||||
{"\x07bc", "\"\abc\""},
|
||||
{"\x1c\n", "\"\x1c\n\""},
|
||||
{"\xB0\x00\x0B", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil && tt.want != "" {
|
||||
t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r)
|
||||
}
|
||||
}()
|
||||
|
||||
got := sqlite3.QuoteIdentifier(tt.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user