Quote values, identifiers.

This commit is contained in:
Nuno Cruces
2023-11-03 07:43:05 -07:00
parent a7c00eb150
commit a9e2cbbfc5
6 changed files with 198 additions and 3 deletions

View File

@@ -23,6 +23,7 @@ const (
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
ValueErr = ErrorString("sqlite3: unsupported value")
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
)

112
quote.go Normal file
View File

@@ -0,0 +1,112 @@
package sqlite3
import (
"bytes"
"math"
"strconv"
"strings"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Quote escapes and quotes a value
// making it safe to embed in SQL text.
func Quote(value any) string {
switch v := value.(type) {
case nil:
return "NULL"
case bool:
if v {
return "1"
} else {
return "0"
}
case int:
return strconv.Itoa(v)
case int64:
return strconv.FormatInt(v, 10)
case float64:
switch {
case math.IsNaN(v):
return "NULL"
case math.IsInf(v, 1):
return "9.0e999"
case math.IsInf(v, -1):
return "-9.0e999"
}
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return "'" + v.Format(time.RFC3339Nano) + "'"
case string:
if strings.IndexByte(v, 0) >= 0 {
break
}
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
buf[0] = '\''
i := 1
for _, b := range []byte(v) {
if b == '\'' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[i] = '\''
return unsafe.String(&buf[0], len(buf))
case []byte:
buf := make([]byte, 3+2*len(v))
buf[0] = 'x'
buf[1] = '\''
i := 2
for _, b := range v {
const hex = "0123456789ABCDEF"
buf[i+0] = hex[b/16]
buf[i+1] = hex[b%16]
i += 2
}
buf[i] = '\''
return unsafe.String(&buf[0], len(buf))
case ZeroBlob:
if v > ZeroBlob(1e9-3)/2 {
break
}
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[0] = 'x'
buf[1] = '\''
buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf))
}
panic(util.ValueErr)
}
// QuoteIdentifier escapes and quotes an identifier
// making it safe to embed in SQL text.
func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr)
}
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
buf[0] = '"'
i := 1
for _, b := range []byte(id) {
if b == '"' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[i] = '"'
return unsafe.String(&buf[0], len(buf))
}

82
tests/quote_test.go Normal file
View File

@@ -0,0 +1,82 @@
package tests
import (
"math"
"reflect"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
)
func TestQuote(t *testing.T) {
tests := []struct {
val any
want string
}{
{`abc`, "'abc'"},
{`a"bc`, "'a\"bc'"},
{`a'bc`, "'a''bc'"},
{"\x07bc", "'\abc'"},
{"\x1c\n", "'\x1c\n'"},
{[]byte("\xB0\x00\x0B"), "x'B0000B'"},
{"\xB0\x00\x0B", ""},
{0, "0"},
{true, "1"},
{false, "0"},
{nil, "NULL"},
{math.NaN(), "NULL"},
{math.Inf(1), "9.0e999"},
{math.Inf(-1), "-9.0e999"},
{math.Pi, "3.141592653589793"},
{int64(math.MaxInt64), "9223372036854775807"},
{time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"},
{sqlite3.ZeroBlob(4), "x'00000000'"},
{sqlite3.ZeroBlob(1e9), ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && tt.want != "" {
t.Errorf("Quote(%q) = %v", tt.val, r)
}
}()
got := sqlite3.Quote(tt.val)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want)
}
})
}
}
func TestQuoteIdentifier(t *testing.T) {
tests := []struct {
id string
want string
}{
{`abc`, `"abc"`},
{`a"bc`, `"a""bc"`},
{`a'bc`, `"a'bc"`},
{"\x07bc", "\"\abc\""},
{"\x1c\n", "\"\x1c\n\""},
{"\xB0\x00\x0B", ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && tt.want != "" {
t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r)
}
}()
got := sqlite3.QuoteIdentifier(tt.id)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want)
}
})
}
}