From a9e2cbbfc5ff61778ef3f6ca20d9e8135f1e601b Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 3 Nov 2023 07:43:05 -0700 Subject: [PATCH] Quote values, identifiers. --- internal/util/error.go | 1 + quote.go | 112 +++++++++++++++++++++++++++++++++++++++++ sqlite3/func.c | 2 +- sqlite3/main.c | 2 +- sqlite3/progress.c | 2 +- tests/quote_test.go | 82 ++++++++++++++++++++++++++++++ 6 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 quote.go create mode 100644 tests/quote_test.go diff --git a/internal/util/error.go b/internal/util/error.go index 7bcd45a..c2b4e21 100644 --- a/internal/util/error.go +++ b/internal/util/error.go @@ -23,6 +23,7 @@ const ( OffsetErr = ErrorString("sqlite3: invalid offset") TailErr = ErrorString("sqlite3: multiple statements") IsolationErr = ErrorString("sqlite3: unsupported isolation level") + ValueErr = ErrorString("sqlite3: unsupported value") NoVFSErr = ErrorString("sqlite3: no such vfs: ") ) diff --git a/quote.go b/quote.go new file mode 100644 index 0000000..d1cd6fa --- /dev/null +++ b/quote.go @@ -0,0 +1,112 @@ +package sqlite3 + +import ( + "bytes" + "math" + "strconv" + "strings" + "time" + "unsafe" + + "github.com/ncruces/go-sqlite3/internal/util" +) + +// Quote escapes and quotes a value +// making it safe to embed in SQL text. +func Quote(value any) string { + switch v := value.(type) { + case nil: + return "NULL" + case bool: + if v { + return "1" + } else { + return "0" + } + + case int: + return strconv.Itoa(v) + case int64: + return strconv.FormatInt(v, 10) + case float64: + switch { + case math.IsNaN(v): + return "NULL" + case math.IsInf(v, 1): + return "9.0e999" + case math.IsInf(v, -1): + return "-9.0e999" + } + return strconv.FormatFloat(v, 'g', -1, 64) + case time.Time: + return "'" + v.Format(time.RFC3339Nano) + "'" + + case string: + if strings.IndexByte(v, 0) >= 0 { + break + } + + buf := make([]byte, 2+len(v)+strings.Count(v, "'")) + buf[0] = '\'' + i := 1 + for _, b := range []byte(v) { + if b == '\'' { + buf[i] = b + i += 1 + } + buf[i] = b + i += 1 + } + buf[i] = '\'' + return unsafe.String(&buf[0], len(buf)) + + case []byte: + buf := make([]byte, 3+2*len(v)) + buf[0] = 'x' + buf[1] = '\'' + i := 2 + for _, b := range v { + const hex = "0123456789ABCDEF" + buf[i+0] = hex[b/16] + buf[i+1] = hex[b%16] + i += 2 + } + buf[i] = '\'' + return unsafe.String(&buf[0], len(buf)) + + case ZeroBlob: + if v > ZeroBlob(1e9-3)/2 { + break + } + + buf := bytes.Repeat([]byte("0"), int(3+2*int64(v))) + buf[0] = 'x' + buf[1] = '\'' + buf[len(buf)-1] = '\'' + return unsafe.String(&buf[0], len(buf)) + } + + panic(util.ValueErr) +} + +// QuoteIdentifier escapes and quotes an identifier +// making it safe to embed in SQL text. +func QuoteIdentifier(id string) string { + if strings.IndexByte(id, 0) >= 0 { + panic(util.ValueErr) + } + + buf := make([]byte, 2+len(id)+strings.Count(id, `"`)) + buf[0] = '"' + i := 1 + for _, b := range []byte(id) { + if b == '"' { + buf[i] = b + i += 1 + } + buf[i] = b + i += 1 + } + buf[i] = '"' + return unsafe.String(&buf[0], len(buf)) +} diff --git a/sqlite3/func.c b/sqlite3/func.c index 550e165..8451033 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -38,4 +38,4 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg, void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) { sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy); -} +} \ No newline at end of file diff --git a/sqlite3/main.c b/sqlite3/main.c index 894464b..81b2e21 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -23,4 +23,4 @@ __attribute__((constructor)) void init() { sqlite3_auto_extension((void (*)(void))sqlite3_uint_init); sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init); sqlite3_auto_extension((void (*)(void))sqlite3_time_init); -} +} \ No newline at end of file diff --git a/sqlite3/progress.c b/sqlite3/progress.c index 781b1a3..11c1551 100644 --- a/sqlite3/progress.c +++ b/sqlite3/progress.c @@ -6,4 +6,4 @@ int go_progress(void *); void sqlite3_progress_handler_go(sqlite3 *db, int n) { sqlite3_progress_handler(db, n, go_progress, /*arg=*/NULL); -} +} \ No newline at end of file diff --git a/tests/quote_test.go b/tests/quote_test.go new file mode 100644 index 0000000..bd322d6 --- /dev/null +++ b/tests/quote_test.go @@ -0,0 +1,82 @@ +package tests + +import ( + "math" + "reflect" + "testing" + "time" + + "github.com/ncruces/go-sqlite3" +) + +func TestQuote(t *testing.T) { + tests := []struct { + val any + want string + }{ + {`abc`, "'abc'"}, + {`a"bc`, "'a\"bc'"}, + {`a'bc`, "'a''bc'"}, + {"\x07bc", "'\abc'"}, + {"\x1c\n", "'\x1c\n'"}, + {[]byte("\xB0\x00\x0B"), "x'B0000B'"}, + {"\xB0\x00\x0B", ""}, + + {0, "0"}, + {true, "1"}, + {false, "0"}, + {nil, "NULL"}, + {math.NaN(), "NULL"}, + {math.Inf(1), "9.0e999"}, + {math.Inf(-1), "-9.0e999"}, + {math.Pi, "3.141592653589793"}, + {int64(math.MaxInt64), "9223372036854775807"}, + {time.Unix(0, 0).UTC(), "'1970-01-01T00:00:00Z'"}, + {sqlite3.ZeroBlob(4), "x'00000000'"}, + {sqlite3.ZeroBlob(1e9), ""}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + defer func() { + if r := recover(); r != nil && tt.want != "" { + t.Errorf("Quote(%q) = %v", tt.val, r) + } + }() + + got := sqlite3.Quote(tt.val) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Quote(%v) = %q, want %q", tt.val, got, tt.want) + } + }) + } +} + +func TestQuoteIdentifier(t *testing.T) { + tests := []struct { + id string + want string + }{ + {`abc`, `"abc"`}, + {`a"bc`, `"a""bc"`}, + {`a'bc`, `"a'bc"`}, + {"\x07bc", "\"\abc\""}, + {"\x1c\n", "\"\x1c\n\""}, + {"\xB0\x00\x0B", ""}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + defer func() { + if r := recover(); r != nil && tt.want != "" { + t.Errorf("QuoteIdentifier(%q) = %v", tt.id, r) + } + }() + + got := sqlite3.QuoteIdentifier(tt.id) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("QuoteIdentifier(%v) = %q, want %q", tt.id, got, tt.want) + } + }) + } +}