UUID extension (#113)

This commit is contained in:
Nuno Cruces
2024-07-04 15:28:49 +01:00
committed by GitHub
parent 72f8ad0f14
commit da6e4d8b86
11 changed files with 363 additions and 16 deletions

Binary file not shown.

View File

@@ -7,7 +7,6 @@
package bloom
import (
"errors"
"fmt"
"io"
"math"
@@ -15,6 +14,7 @@ import (
"github.com/dchest/siphash"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the bloom_filter virtual table:
@@ -47,7 +47,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
return nil, err
}
if nelem <= 0 {
return nil, errors.New("bloom: number of elements in filter must be positive")
return nil, util.ErrorString("bloom: number of elements in filter must be positive")
}
} else {
nelem = 100
@@ -59,7 +59,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
return nil, err
}
if t.prob <= 0 || t.prob >= 1 {
return nil, errors.New("bloom: probability must be in the range (0,1)")
return nil, util.ErrorString("bloom: probability must be in the range (0,1)")
}
} else {
t.prob = 0.01
@@ -71,7 +71,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
return nil, err
}
if t.hashes <= 0 {
return nil, errors.New("bloom: number of hash functions must be positive")
return nil, util.ErrorString("bloom: number of hash functions must be positive")
}
} else {
t.hashes = max(1, numHashes(t.prob))
@@ -171,7 +171,7 @@ func (t *bloom) Integrity(schema, table string, flags int) error {
}
defer load.Close()
err = errors.New("bloom: invalid parameters")
err = util.ErrorString("bloom: invalid parameters")
if !load.Step() {
return err
}
@@ -213,9 +213,9 @@ func (b *bloom) BestIndex(idx *sqlite3.IndexInfo) error {
func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) {
if arg[0].Type() != sqlite3.NULL {
if len(arg) == 1 {
return 0, errors.New("bloom: elements cannot be deleted")
return 0, util.ErrorString("bloom: elements cannot be deleted")
}
return 0, errors.New("bloom: elements cannot be updated")
return 0, util.ErrorString("bloom: elements cannot be updated")
}
blob := arg[2].RawBlob()

View File

@@ -9,7 +9,6 @@ package csv
import (
"bufio"
"encoding/csv"
"errors"
"fmt"
"io"
"io/fs"
@@ -17,6 +16,7 @@ import (
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
@@ -73,7 +73,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
}
if (filename == "") == (data == "") {
return nil, errors.New(`csv: must specify either "filename" or "data" but not both`)
return nil, util.ErrorString(`csv: must specify either "filename" or "data" but not both`)
}
table := &table{

View File

@@ -9,6 +9,7 @@ import (
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the pivot virtual table.
@@ -65,7 +66,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err
}
if stmt.ColumnCount() != 2 {
return nil, errors.New("pivot: column definition query expects 2 result columns")
return nil, util.ErrorString("pivot: column definition query expects 2 result columns")
}
for stmt.Step() {
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
@@ -83,7 +84,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err
}
if stmt.ColumnCount() != 1 {
return nil, errors.New("pivot: cell query expects 1 result columns")
return nil, util.ErrorString("pivot: cell query expects 1 result columns")
}
if stmt.BindCount() != len(table.keys)+1 {
return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1)

View File

@@ -8,12 +8,12 @@ package statement
import (
"encoding/json"
"errors"
"strconv"
"strings"
"unsafe"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the statement virtual table.
@@ -29,7 +29,7 @@ type table struct {
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, errors.New("statement: wrong number of arguments")
return nil, util.ErrorString("statement: wrong number of arguments")
}
sql := "SELECT * FROM\n" + arg[0]

166
ext/uuid/uuid.go Normal file
View File

@@ -0,0 +1,166 @@
// Package uuid provides functions to generate RFC 4122 UUIDs.
//
// https://sqlite.org/src/file/ext/misc/uuid.c
package uuid
import (
"bytes"
"fmt"
"github.com/google/uuid"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the SQL functions:
//
// uuid([version], [domain/namespace], [id/data])
//
// Generates a UUID as a string.
//
// uuid_str(u)
//
// Converts a UUID into a well-formed UUID string.
//
// uuid_blob(u)
//
// Converts a UUID into a 16-byte blob.
func Register(db *sqlite3.Conn) {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid_str", 1, flags, toString)
db.CreateFunction("uuid_blob", 1, flags, toBLOB)
}
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
var (
ver int
err error
u uuid.UUID
)
if len(arg) > 0 {
ver = arg[0].Int()
} else {
ver = 4
}
switch ver {
case 1:
u, err = uuid.NewUUID()
case 4:
u, err = uuid.NewRandom()
case 6:
u, err = uuid.NewV6()
case 7:
u, err = uuid.NewV7()
case 2:
var domain uuid.Domain
if len(arg) > 1 {
domain = uuid.Domain(arg[1].Int64())
if domain == 0 {
if txt := arg[1].RawText(); len(txt) > 0 {
switch txt[0] | 0x20 {
case 'g': // group
domain = 1
case 'o': // org
domain = 2
}
}
}
}
if len(arg) > 2 {
id := uint32(arg[2].Int64())
u, err = uuid.NewDCESecurity(domain, id)
} else if domain == uuid.Person {
u, err = uuid.NewDCEPerson()
} else if domain == uuid.Group {
u, err = uuid.NewDCEGroup()
} else {
err = util.ErrorString("missing id")
}
case 3, 5:
if len(arg) < 2 {
err = util.ErrorString("missing data")
break
}
ns, err := fromValue(arg[1])
if err != nil {
space := arg[1].RawText()
switch {
case bytes.EqualFold(space, []byte("url")):
ns = uuid.NameSpaceURL
case bytes.EqualFold(space, []byte("oid")):
ns = uuid.NameSpaceOID
case bytes.EqualFold(space, []byte("dns")):
ns = uuid.NameSpaceDNS
case bytes.EqualFold(space, []byte("fqdn")):
ns = uuid.NameSpaceDNS
case bytes.EqualFold(space, []byte("x500")):
ns = uuid.NameSpaceX500
default:
ctx.ResultError(err)
return
}
}
if ver == 3 {
u = uuid.NewMD5(ns, arg[2].RawBlob())
} else {
u = uuid.NewSHA1(ns, arg[2].RawBlob())
}
default:
err = fmt.Errorf("invalid version: %d", ver)
}
if err != nil {
ctx.ResultError(fmt.Errorf("uuid: %w", err))
} else {
ctx.ResultText(u.String())
}
}
func fromValue(arg sqlite3.Value) (u uuid.UUID, err error) {
switch t := arg.Type(); t {
case sqlite3.TEXT:
u, err = uuid.ParseBytes(arg.RawText())
if err != nil {
err = fmt.Errorf("uuid: %w", err)
}
case sqlite3.BLOB:
blob := arg.RawBlob()
if len := len(blob); len != 16 {
err = fmt.Errorf("uuid: invalid BLOB length: %d", len)
} else {
copy(u[:], blob)
}
default:
err = fmt.Errorf("uuid: invalid type: %v", t)
}
return u, err
}
func toBLOB(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultBlob(u[:])
}
}
func toString(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultText(u.String())
}
}

179
ext/uuid/uuid_test.go Normal file
View File

@@ -0,0 +1,179 @@
package uuid
import (
"testing"
"github.com/google/uuid"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
func Test_generate(t *testing.T) {
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
Register(conn)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
var u uuid.UUID
// Version 4, SQLite compatible
err = db.QueryRow(`SELECT uuid()`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != 4 {
t.Errorf("got %d, want 4", got)
}
// Invalid version
err = db.QueryRow(`SELECT uuid(8)`).Scan(&u)
if err == nil {
t.Error("want error")
}
// Custom version, no arguments
for _, want := range []uuid.Version{1, 2, 4, 6, 7} {
err = db.QueryRow(`SELECT uuid(?)`, want).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != want {
t.Errorf("got %d, want %d", got, want)
}
}
// Version 2, custom arguments
err = db.QueryRow(`SELECT uuid(2, 4)`).Scan(&u)
if err == nil {
t.Error("want error")
}
err = db.QueryRow(`SELECT uuid(2, 'group')`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != 2 {
t.Errorf("got %d, want 2", got)
}
if got := u.Domain(); got != uuid.Group {
t.Errorf("got %d, want 1", got)
}
dce := []struct {
out uuid.Domain
in any
id uint32
}{
{uuid.Person, "user", 42},
{uuid.Group, "group", 42},
{uuid.Org, "org", 42},
{uuid.Person, 0, 42},
{uuid.Group, 1, 42},
{uuid.Org, 2, 42},
{3, 3, 42},
}
for _, tt := range dce {
err = db.QueryRow(`SELECT uuid(2, ?, ?)`, tt.in, tt.id).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != 2 {
t.Errorf("got %d, want 2", got)
}
if got := u.Domain(); got != tt.out {
t.Errorf("got %d, want %d", got, tt.out)
}
if got := u.ID(); got != tt.id {
t.Errorf("got %d, want %d", got, tt.id)
}
}
// Versions 3 and 5
err = db.QueryRow(`SELECT uuid(3)`).Scan(&u)
if err == nil {
t.Error("want error")
}
err = db.QueryRow(`SELECT uuid(3, 0, '')`).Scan(&u)
if err == nil {
t.Error("want error")
}
hash := []struct {
ver uuid.Version
ns any
data string
u uuid.UUID
}{
{3, "oid", "2.999", uuid.MustParse("31cb1efa-18c4-3d19-89ba-df6a74ddbd1d")},
{3, "dns", "www.example.com", uuid.MustParse("5df41881-3aed-3515-88a7-2f4a814cf09e")},
{3, "fqdn", "www.example.com", uuid.MustParse("5df41881-3aed-3515-88a7-2f4a814cf09e")},
{3, "url", "https://www.example.com/", uuid.MustParse("7fed185f-0864-319f-875b-a3d5458e30ac")},
{3, "x500", "CN=Test User 1, O=Example Organization, ST=California, C=US", uuid.MustParse("addf5e97-9287-3834-abfd-7edcbe7db56f")},
{3, "url", "https://www.php.net", uuid.MustParse("3f703955-aaba-3e70-a3cb-baff6aa3b28f")},
{5, "url", "https://www.php.net", uuid.MustParse("a8f6ae40-d8a7-58f0-be05-a22f94eca9ec")},
}
for _, tt := range hash {
err = db.QueryRow(`SELECT uuid(?, ?, ?)`, tt.ver, tt.ns, tt.data).Scan(&u)
if err != nil {
t.Fatal(err)
}
if u != tt.u {
t.Errorf("got %v, want %v", u, tt.u)
}
}
}
func Test_convert(t *testing.T) {
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
Register(conn)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
var u uuid.UUID
lits := []string{
"'6ba7b8119dad11d180b400c04fd430c8'",
"'6ba7b811-9dad-11d1-80b4-00c04fd430c8'",
"'{6ba7b811-9dad-11d1-80b4-00c04fd430c8}'",
"X'6ba7b8119dad11d180b400c04fd430c8'",
}
for _, tt := range lits {
err = db.QueryRow(`SELECT uuid_str(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if u != uuid.NameSpaceURL {
t.Errorf("got %v, want %v", u, uuid.NameSpaceURL)
}
}
for _, tt := range lits {
err = db.QueryRow(`SELECT uuid_blob(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if u != uuid.NameSpaceURL {
t.Errorf("got %v, want %v", u, uuid.NameSpaceURL)
}
}
err = db.QueryRow(`SELECT uuid_str(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
err = db.QueryRow(`SELECT uuid_blob(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
}

2
go.mod
View File

@@ -17,4 +17,6 @@ require (
lukechampine.com/adiantum v1.1.1
)
require github.com/google/uuid v1.6.0
retract v0.4.0 // tagged from the wrong branch

2
go.sum
View File

@@ -1,5 +1,7 @@
github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA=
github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=

View File

@@ -19,7 +19,6 @@ curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/ieee754.
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/uuid.c"
cd ~-
cd ../vfs/tests/mptest/testdata/

View File

@@ -8,7 +8,6 @@
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
// Bindings
#include "column.c"
#include "func.c"
@@ -28,6 +27,5 @@ __attribute__((constructor)) void init() {
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}