Compare commits

..

12 Commits

Author SHA1 Message Date
Nuno Cruces
b5f746aadf Automatically load extensions. (#115) 2024-07-08 12:06:57 +01:00
dependabot[bot]
fff8b1c74f Bump golang.org/x/crypto from 0.24.0 to 0.25.0 (#116)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.24.0 to 0.25.0.
- [Commits](https://github.com/golang/crypto/compare/v0.24.0...v0.25.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-06 00:33:50 +01:00
Nuno Cruces
d27da3f390 Fix flaky test. 2024-07-05 00:49:22 +01:00
Nuno Cruces
a1fae26b66 Regular expression extension. (#114) 2024-07-05 00:12:26 +01:00
Nuno Cruces
806cc6677d Updated dependencies. 2024-07-04 19:38:26 +01:00
Nuno Cruces
da6e4d8b86 UUID extension (#113) 2024-07-04 15:28:49 +01:00
Nuno Cruces
72f8ad0f14 Toolchain. 2024-07-03 15:02:58 +01:00
Nuno Cruces
5a4c7a58c4 Refactor CREATE parser. (#111) 2024-07-03 14:06:07 +01:00
Nuno Cruces
90f7e502be Tweaks. 2024-07-02 15:42:20 +01:00
Nuno Cruces
c0b289d000 More BSDs. 2024-06-26 14:56:36 +01:00
Nuno Cruces
a84d905d8c Fix go:linkname for mmap (#107) 2024-06-25 10:31:11 +01:00
Nuno Cruces
aa7edb1848 Tests. 2024-06-21 16:23:56 +01:00
64 changed files with 1149 additions and 370 deletions

View File

@@ -83,6 +83,18 @@ jobs:
run: go test -v ./...
test-bsd:
strategy:
matrix:
os:
- name: freebsd
version: '14.0'
flags: '-test.v'
- name: openbsd
version: '7.5'
flags: '-test.v -test.short'
- name: netbsd
version: '10.0'
flags: '-test.v -test.short'
runs-on: ubuntu-latest
needs: test
@@ -96,15 +108,15 @@ jobs:
- name: Build
env:
GOOS: freebsd
TESTFLAGS: '-test.v'
GOOS: ${{ matrix.os.name }}
TESTFLAGS: ${{ matrix.os.flags }}
run: .github/workflows/build-test.sh
- name: Test
uses: cross-platform-actions/action@v0.24.0
with:
operating_system: freebsd
version: '14.0'
operating_system: ${{ matrix.os.name }}
version: ${{ matrix.os.version }}
shell: bash
run: . ./test.sh
sync_files: runner-to-vm

View File

@@ -45,12 +45,16 @@ Go, wazero and [`x/sys`](https://pkg.go.dev/golang.org/x/sys) are the _only_ run
reads data [line-by-line](https://github.com/asg017/sqlite-lines).
- [`github.com/ncruces/go-sqlite3/ext/pivot`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/pivot)
creates [pivot tables](https://github.com/jakethaw/pivot_vtab).
- [`github.com/ncruces/go-sqlite3/ext/regexp`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/regexp)
provides regular expression functions.
- [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement)
creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab).
- [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)
provides [statistics](https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html) functions.
- [`github.com/ncruces/go-sqlite3/ext/unicode`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/unicode)
provides [Unicode aware](https://sqlite.org/src/dir/ext/icu) functions.
- [`github.com/ncruces/go-sqlite3/ext/uuid`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/uuid)
generates [UUIDs](https://en.wikipedia.org/wiki/Universally_unique_identifier).
- [`github.com/ncruces/go-sqlite3/ext/zorder`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/zorder)
maps multidimensional data to one dimension.
- [`github.com/ncruces/go-sqlite3/vfs/adiantum`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/adiantum)
@@ -91,7 +95,8 @@ It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and
Every commit is [tested](.github/workflows/test.yml) on
Linux (amd64/arm64/386/riscv64/s390x), macOS (amd64/arm64),
Windows (amd64), FreeBSD (amd64), illumos (amd64), and Solaris (amd64).
Windows (amd64), FreeBSD (amd64), OpenBSD (amd64), NetBSD (amd64),
illumos (amd64), and Solaris (amd64).
The Go VFS is tested by running SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c).

View File

@@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
}
if err != nil {
return nil, err
}

View File

@@ -7,7 +7,7 @@ ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_117/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-22.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -std=c17 -flto -g0 -O2 \
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -flto -g0 -O2 \
-Wall -Wextra -Wno-unused-parameter -Wno-unused-function \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \

Binary file not shown.

View File

@@ -15,8 +15,8 @@ import (
// The argument must be bound to a Go slice or array of
// ints, floats, bools, strings or byte slices,
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "array", nil,
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "array", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
return array{}, err

View File

@@ -15,10 +15,7 @@ import (
)
func Example_driver() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
log.Fatal(err)
}
@@ -53,14 +50,14 @@ func Example_driver() {
}
func Example() {
sqlite3.AutoExtension(array.Register)
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
array.Register(db)
stmt, _, err := db.Prepare(`
SELECT name
FROM pragma_function_list
@@ -91,10 +88,7 @@ func Example() {
func Test_cursor_Column(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
t.Fatal(err)
}
@@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) {
}
defer db.Close()
array.Register(db)
err = array.Register(db)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT * FROM array()`)
if err == nil {

View File

@@ -29,10 +29,11 @@ import (
// along with the [sqlite3.Blob] handle.
//
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) {
db.CreateFunction("readblob", 6, 0, readblob)
db.CreateFunction("writeblob", 6, 0, writeblob)
db.CreateFunction("openblob", -1, 0, openblob)
func Register(db *sqlite3.Conn) error {
return errors.Join(
db.CreateFunction("readblob", 6, 0, readblob),
db.CreateFunction("writeblob", 6, 0, writeblob),
db.CreateFunction("openblob", -1, 0, openblob))
}
// OpenCallback is the type for the openblob callback.

View File

@@ -18,10 +18,7 @@ import (
func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blobio.Register(conn)
return nil
})
db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register)
if err != nil {
log.Fatal(err)
@@ -60,6 +57,11 @@ func Example() {
// Hello BLOB!
}
func init() {
sqlite3.AutoExtension(blobio.Register)
sqlite3.AutoExtension(array.Register)
}
func Test_readblob(t *testing.T) {
t.Parallel()
@@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) {
}
defer db.Close()
blobio.Register(db)
array.Register(db)
err = db.Exec(`SELECT readblob()`)
if err == nil {
t.Fatal("want error")
@@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) {
}
defer db.Close()
blobio.Register(db)
array.Register(db)
err = db.Exec(`SELECT openblob()`)
if err == nil {
t.Fatal("want error")

View File

@@ -7,7 +7,6 @@
package bloom
import (
"errors"
"fmt"
"io"
"math"
@@ -15,13 +14,14 @@ import (
"github.com/dchest/siphash"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the bloom_filter virtual table:
//
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "bloom_filter", create, connect)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
}
type bloom struct {
@@ -47,7 +47,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
return nil, err
}
if nelem <= 0 {
return nil, errors.New("bloom: number of elements in filter must be positive")
return nil, util.ErrorString("bloom: number of elements in filter must be positive")
}
} else {
nelem = 100
@@ -59,7 +59,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
return nil, err
}
if t.prob <= 0 || t.prob >= 1 {
return nil, errors.New("bloom: probability must be in the range (0,1)")
return nil, util.ErrorString("bloom: probability must be in the range (0,1)")
}
} else {
t.prob = 0.01
@@ -71,7 +71,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom,
return nil, err
}
if t.hashes <= 0 {
return nil, errors.New("bloom: number of hash functions must be positive")
return nil, util.ErrorString("bloom: number of hash functions must be positive")
}
} else {
t.hashes = max(1, numHashes(t.prob))
@@ -171,7 +171,7 @@ func (t *bloom) Integrity(schema, table string, flags int) error {
}
defer load.Close()
err = errors.New("bloom: invalid parameters")
err = util.ErrorString("bloom: invalid parameters")
if !load.Step() {
return err
}
@@ -213,9 +213,9 @@ func (b *bloom) BestIndex(idx *sqlite3.IndexInfo) error {
func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) {
if arg[0].Type() != sqlite3.NULL {
if len(arg) == 1 {
return 0, errors.New("bloom: elements cannot be deleted")
return 0, util.ErrorString("bloom: elements cannot be deleted")
}
return 0, errors.New("bloom: elements cannot be updated")
return 0, util.ErrorString("bloom: elements cannot be updated")
}
blob := arg[2].RawBlob()
@@ -262,8 +262,8 @@ func (b *bloom) Open() (sqlite3.VTabCursor, error) {
type cursor struct {
*bloom
eof bool
arg *sqlite3.Value
eof bool
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {

View File

@@ -12,6 +12,10 @@ import (
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func init() {
sqlite3.AutoExtension(bloom.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -21,8 +25,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
bloom.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20);
INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo')
@@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) {
}
defer db.Close()
bloom.Register(db)
query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`)
if err != nil {
t.Fatal(err)

View File

@@ -9,7 +9,6 @@ package csv
import (
"bufio"
"encoding/csv"
"errors"
"fmt"
"io"
"io/fs"
@@ -17,19 +16,20 @@ import (
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
// Register registers the CSV virtual table.
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
@@ -73,7 +73,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
}
if (filename == "") == (data == "") {
return nil, errors.New(`csv: must specify either "filename" or "data" but not both`)
return nil, util.ErrorString(`csv: must specify either "filename" or "data" but not both`)
}
table := &table{
@@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
return table, nil
}
sqlite3.CreateModule(db, "csv", declare, declare)
return sqlite3.CreateModule(db, "csv", declare, declare)
}
type table struct {

View File

@@ -18,7 +18,10 @@ func Example() {
}
defer db.Close()
csv.Register(db)
err = csv.Register(db)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE eurofxref USING csv(
@@ -51,6 +54,10 @@ func Example() {
// On Twosday, 1€ = $1.1342
}
func init() {
sqlite3.AutoExtension(csv.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -60,8 +67,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
csv.Register(db)
const data = `
# Comment
"Rob" "Pike" rob
@@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) {
}
defer db.Close()
csv.Register(db)
const data = "01\n0.10\ne"
err = db.Exec(`
CREATE VIRTUAL TABLE temp.nums USING csv(
@@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()
csv.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
if err == nil {
t.Fatal("want error")

View File

@@ -1,7 +1,6 @@
package csv
import (
_ "embed"
"strings"
"github.com/ncruces/go-sqlite3/util/vtabutil"
@@ -22,12 +21,10 @@ func getColumnAffinities(schema string) ([]affinity, error) {
if err != nil {
return nil, err
}
defer tab.Close()
types := make([]affinity, tab.NumColumns())
for i := range types {
col := tab.Column(i)
types[i] = getAffinity(col.Type())
types := make([]affinity, len(tab.Columns))
for i, col := range tab.Columns {
types[i] = getAffinity(col.Type)
}
return types, nil
}

View File

@@ -1,9 +1,6 @@
package csv
import (
_ "embed"
"testing"
)
import "testing"
func Test_getAffinity(t *testing.T) {
tests := []struct {

View File

@@ -14,24 +14,26 @@ import (
// Register registers SQL functions readfile, writefile, lsmode,
// and the table-valued function fsdir.
func Register(db *sqlite3.Conn) {
RegisterFS(db, nil)
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, nil)
}
// Register registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode)
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
var err error
if fsys == nil {
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
})
return errors.Join(err,
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
}))
}
func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -17,10 +17,7 @@ import (
func Test_lsmode(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.Register(c)
return nil
})
db, err := driver.Open(":memory:", fileio.Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -68,7 +68,10 @@ func Test_fsdir_errors(t *testing.T) {
}
defer db.Close()
fileio.Register(db)
err = fileio.Register(db)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT name FROM fsdir()`)
if err == nil {

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -16,10 +15,7 @@ import (
func Test_writefile(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -21,47 +21,60 @@ package hash
import (
"crypto"
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers cryptographic hash functions for a database connection.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
var errs util.ErrorJoiner
if crypto.MD4.Available() {
db.CreateFunction("md4", 1, flags, md4Func)
errs.Join(
db.CreateFunction("md4", 1, flags, md4Func))
}
if crypto.MD5.Available() {
db.CreateFunction("md5", 1, flags, md5Func)
errs.Join(
db.CreateFunction("md5", 1, flags, md5Func))
}
if crypto.SHA1.Available() {
db.CreateFunction("sha1", 1, flags, sha1Func)
errs.Join(
db.CreateFunction("sha1", 1, flags, sha1Func))
}
if crypto.SHA3_512.Available() {
db.CreateFunction("sha3", 1, flags, sha3Func)
db.CreateFunction("sha3", 2, flags, sha3Func)
errs.Join(
db.CreateFunction("sha3", 1, flags, sha3Func),
db.CreateFunction("sha3", 2, flags, sha3Func))
}
if crypto.SHA256.Available() {
db.CreateFunction("sha224", 1, flags, sha224Func)
db.CreateFunction("sha256", 1, flags, sha256Func)
db.CreateFunction("sha256", 2, flags, sha256Func)
errs.Join(
db.CreateFunction("sha224", 1, flags, sha224Func),
db.CreateFunction("sha256", 1, flags, sha256Func),
db.CreateFunction("sha256", 2, flags, sha256Func))
}
if crypto.SHA512.Available() {
db.CreateFunction("sha384", 1, flags, sha384Func)
db.CreateFunction("sha512", 1, flags, sha512Func)
db.CreateFunction("sha512", 2, flags, sha512Func)
errs.Join(
db.CreateFunction("sha384", 1, flags, sha384Func),
db.CreateFunction("sha512", 1, flags, sha512Func),
db.CreateFunction("sha512", 2, flags, sha512Func))
}
if crypto.BLAKE2s_256.Available() {
db.CreateFunction("blake2s", 1, flags, blake2sFunc)
errs.Join(
db.CreateFunction("blake2s", 1, flags, blake2sFunc))
}
if crypto.BLAKE2b_512.Available() {
db.CreateFunction("blake2b", 1, flags, blake2bFunc)
db.CreateFunction("blake2b", 2, flags, blake2bFunc)
errs.Join(
db.CreateFunction("blake2b", 1, flags, blake2bFunc),
db.CreateFunction("blake2b", 2, flags, blake2bFunc))
}
if crypto.RIPEMD160.Available() {
db.CreateFunction("ripemd160", 1, flags, ripemd160Func)
errs.Join(
db.CreateFunction("ripemd160", 1, flags, ripemd160Func))
}
return errors.Join(errs...)
}
func md4Func(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -7,7 +7,6 @@ import (
_ "crypto/sha512"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -53,10 +52,7 @@ func TestRegister(t *testing.T) {
{"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"},
}
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -13,6 +13,7 @@ package lines
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/fs"
@@ -25,27 +26,28 @@ import (
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
sqlite3.CreateModule(db, "lines", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
return lines{}, err
})
sqlite3.CreateModule(db, "lines_read", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return lines{fsys}, err
})
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
return errors.Join(
sqlite3.CreateModule(db, "lines", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
return lines{}, err
}),
sqlite3.CreateModule(db, "lines_read", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return lines{fsys}, err
}))
}
type lines struct {

View File

@@ -18,10 +18,7 @@ import (
)
func Example() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -70,10 +67,7 @@ func Example() {
func Test_lines(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -103,10 +97,7 @@ func Test_lines(t *testing.T) {
func Test_lines_error(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -130,10 +121,7 @@ func Test_lines_error(t *testing.T) {
func Test_lines_read(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -164,10 +152,7 @@ func Test_lines_read(t *testing.T) {
func Test_lines_test(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}

View File

@@ -9,11 +9,12 @@ import (
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the pivot virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "pivot", declare, declare)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "pivot", declare, declare)
}
type table struct {
@@ -65,7 +66,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err
}
if stmt.ColumnCount() != 2 {
return nil, errors.New("pivot: column definition query expects 2 result columns")
return nil, util.ErrorString("pivot: column definition query expects 2 result columns")
}
for stmt.Step() {
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
@@ -83,7 +84,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err
}
if stmt.ColumnCount() != 1 {
return nil, errors.New("pivot: cell query expects 1 result columns")
return nil, util.ErrorString("pivot: cell query expects 1 result columns")
}
if stmt.BindCount() != len(table.keys)+1 {
return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1)

View File

@@ -14,14 +14,14 @@ import (
// https://antonz.org/sqlite-pivot-table/
func Example() {
sqlite3.AutoExtension(pivot.Register)
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE sales(product TEXT, year INT, income DECIMAL);
INSERT INTO sales(product, year, income) VALUES
@@ -83,6 +83,10 @@ func Example() {
// gamma 80 75 78 80
}
func init() {
sqlite3.AutoExtension(pivot.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -92,8 +96,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE r AS
SELECT 1 id UNION SELECT 2 UNION SELECT 3;
@@ -153,8 +155,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`)
if err == nil {
t.Fatal("want error")

78
ext/regexp/regexp.go Normal file
View File

@@ -0,0 +1,78 @@
// Package regexp provides additional regular expression functions.
//
// It provides the following Unicode aware functions:
// - regexp_like(),
// - regexp_substr(),
// - regexp_replace(),
// - and a REGEXP operator.
//
// The implementation uses Go [regexp/syntax] for regular expressions.
//
// https://github.com/nalgeon/sqlean/blob/main/docs/regexp.md
package regexp
import (
"errors"
"regexp"
"github.com/ncruces/go-sqlite3"
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("regexp_like", 2, flags, regexLike),
db.CreateFunction("regexp_substr", 2, flags, regexSubstr),
db.CreateFunction("regexp_replace", 3, flags, regexReplace))
}
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(expr)
if err != nil {
return nil, err
}
re = r
ctx.SetAuxData(0, r)
}
return re, nil
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 0, arg[0].Text())
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultBool(re.Match(arg[1].RawText()))
}
}
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultBool(re.Match(arg[0].RawText()))
}
}
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultRawText(re.Find(arg[0].RawText()))
}
}
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultRawText(re.ReplaceAll(arg[0].RawText(), arg[2].RawText()))
}
}

68
ext/regexp/regexp_test.go Normal file
View File

@@ -0,0 +1,68 @@
package regexp
import (
"testing"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func TestRegister(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tests := []struct {
test string
want string
}{
{`'Hello' REGEXP 'elo'`, "0"},
{`'Hello' REGEXP 'ell'`, "1"},
{`'Hello' REGEXP 'el.'`, "1"},
{`regexp_like('Hello', 'elo')`, "0"},
{`regexp_like('Hello', 'ell')`, "1"},
{`regexp_like('Hello', 'el.')`, "1"},
{`regexp_substr('Hello', 'el.')`, "ell"},
{`regexp_replace('Hello', 'llo', 'll')`, "Hell"},
}
for _, tt := range tests {
var got string
err := db.QueryRow(`SELECT ` + tt.test).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
}
}
func TestRegister_errors(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
tests := []string{
`'' REGEXP ?`,
`regexp_like('', ?)`,
`regexp_substr('', ?)`,
`regexp_replace('', ?, '')`,
}
for _, tt := range tests {
err := db.QueryRow(`SELECT `+tt, `\`).Scan(nil)
if err == nil {
t.Fatal("want error")
}
}
}

View File

@@ -8,17 +8,17 @@ package statement
import (
"encoding/json"
"errors"
"strconv"
"strings"
"unsafe"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the statement virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "statement", declare, declare)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "statement", declare, declare)
}
type table struct {
@@ -29,7 +29,7 @@ type table struct {
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, errors.New("statement: wrong number of arguments")
return nil, util.ErrorString("statement: wrong number of arguments")
}
sql := "SELECT * FROM\n" + arg[0]

View File

@@ -12,14 +12,14 @@ import (
)
func Example() {
sqlite3.AutoExtension(statement.Register)
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE split_date USING statement((
SELECT
@@ -48,6 +48,10 @@ func Example() {
// Twosday was 2022-2-22
}
func init() {
sqlite3.AutoExtension(statement.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -57,8 +61,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE arguments USING statement((SELECT ? AS a, ? AS b, ? AS c))
`)
@@ -107,8 +109,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()
statement.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement()`)
if err == nil {
t.Fatal("want error")

View File

@@ -5,7 +5,6 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
@@ -18,8 +17,6 @@ func TestRegister_boolean(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)

View File

@@ -6,7 +6,6 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
@@ -19,8 +18,6 @@ func TestRegister_percentile(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)

View File

@@ -44,33 +44,38 @@
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import "github.com/ncruces/go-sqlite3"
import (
"errors"
"github.com/ncruces/go-sqlite3"
)
// Register registers statistics functions.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx))
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy))
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy))
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx))
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy))
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope))
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
db.CreateWindowFunction("median", 1, flags, newPercentile(median))
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont))
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc))
db.CreateWindowFunction("every", 1, flags, newBoolean(every))
db.CreateWindowFunction("some", 1, flags, newBoolean(some))
return errors.Join(
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2)),
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx)),
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy)),
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy)),
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx)),
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy)),
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)),
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
db.CreateWindowFunction("median", 1, flags, newPercentile(median)),
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)),
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
}
const (

View File

@@ -10,6 +10,10 @@ import (
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func init() {
sqlite3.AutoExtension(stats.Register)
}
func TestRegister_variance(t *testing.T) {
t.Parallel()
@@ -19,8 +23,6 @@ func TestRegister_variance(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)
@@ -88,8 +90,6 @@ func TestRegister_covariance(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (y, x)`)
if err != nil {
t.Fatal(err)
@@ -217,8 +217,6 @@ func Benchmark_variance(b *testing.B) {
}
defer db.Close()
stats.Register(db)
stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`)
if err != nil {
b.Fatal(err)

View File

@@ -18,6 +18,7 @@ package unicode
import (
"bytes"
"errors"
"regexp"
"strings"
"unicode/utf8"
@@ -30,29 +31,29 @@ import (
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("like", 2, flags, like),
db.CreateFunction("like", 3, flags, like),
db.CreateFunction("upper", 1, flags, upper),
db.CreateFunction("upper", 2, flags, upper),
db.CreateFunction("lower", 1, flags, lower),
db.CreateFunction("lower", 2, flags, lower),
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
db.CreateFunction("like", 2, flags, like)
db.CreateFunction("like", 3, flags, like)
db.CreateFunction("upper", 1, flags, upper)
db.CreateFunction("upper", 2, flags, upper)
db.CreateFunction("lower", 1, flags, lower)
db.CreateFunction("lower", 2, flags, lower)
db.CreateFunction("regexp", 2, flags, regex)
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
}
})
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
}
}))
}
// RegisterCollation registers a Unicode collation sequence for a database connection.
@@ -111,7 +112,7 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
return
}
re = r
ctx.SetAuxData(0, re)
ctx.SetAuxData(0, r)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
}

168
ext/uuid/uuid.go Normal file
View File

@@ -0,0 +1,168 @@
// Package uuid provides functions to generate RFC 4122 UUIDs.
//
// https://sqlite.org/src/file/ext/misc/uuid.c
package uuid
import (
"bytes"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the SQL functions:
//
// uuid([version], [domain/namespace], [id/data])
//
// Generates a UUID as a string.
//
// uuid_str(u)
//
// Converts a UUID into a well-formed UUID string.
//
// uuid_blob(u)
//
// Converts a UUID into a 16-byte blob.
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid_str", 1, flags, toString),
db.CreateFunction("uuid_blob", 1, flags, toBlob))
}
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
var (
ver int
err error
u uuid.UUID
)
if len(arg) > 0 {
ver = arg[0].Int()
} else {
ver = 4
}
switch ver {
case 1:
u, err = uuid.NewUUID()
case 4:
u, err = uuid.NewRandom()
case 6:
u, err = uuid.NewV6()
case 7:
u, err = uuid.NewV7()
case 2:
var domain uuid.Domain
if len(arg) > 1 {
domain = uuid.Domain(arg[1].Int64())
if domain == 0 {
if txt := arg[1].RawText(); len(txt) > 0 {
switch txt[0] | 0x20 {
case 'g': // group
domain = 1
case 'o': // org
domain = 2
}
}
}
}
if len(arg) > 2 {
id := uint32(arg[2].Int64())
u, err = uuid.NewDCESecurity(domain, id)
} else if domain == uuid.Person {
u, err = uuid.NewDCEPerson()
} else if domain == uuid.Group {
u, err = uuid.NewDCEGroup()
} else {
err = util.ErrorString("missing id")
}
case 3, 5:
if len(arg) < 2 {
err = util.ErrorString("missing data")
break
}
ns, err := fromValue(arg[1])
if err != nil {
space := arg[1].RawText()
switch {
case bytes.EqualFold(space, []byte("url")):
ns = uuid.NameSpaceURL
case bytes.EqualFold(space, []byte("oid")):
ns = uuid.NameSpaceOID
case bytes.EqualFold(space, []byte("dns")):
ns = uuid.NameSpaceDNS
case bytes.EqualFold(space, []byte("fqdn")):
ns = uuid.NameSpaceDNS
case bytes.EqualFold(space, []byte("x500")):
ns = uuid.NameSpaceX500
default:
ctx.ResultError(err)
return
}
}
if ver == 3 {
u = uuid.NewMD5(ns, arg[2].RawBlob())
} else {
u = uuid.NewSHA1(ns, arg[2].RawBlob())
}
default:
err = fmt.Errorf("invalid version: %d", ver)
}
if err != nil {
ctx.ResultError(fmt.Errorf("uuid: %w", err))
} else {
ctx.ResultText(u.String())
}
}
func fromValue(arg sqlite3.Value) (u uuid.UUID, err error) {
switch t := arg.Type(); t {
case sqlite3.TEXT:
u, err = uuid.ParseBytes(arg.RawText())
if err != nil {
err = fmt.Errorf("uuid: %w", err)
}
case sqlite3.BLOB:
blob := arg.RawBlob()
if len := len(blob); len != 16 {
err = fmt.Errorf("uuid: invalid BLOB length: %d", len)
} else {
copy(u[:], blob)
}
default:
err = fmt.Errorf("uuid: invalid type: %v", t)
}
return u, err
}
func toBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultBlob(u[:])
}
}
func toString(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
} else {
ctx.ResultText(u.String())
}
}

177
ext/uuid/uuid_test.go Normal file
View File

@@ -0,0 +1,177 @@
package uuid
import (
"testing"
"github.com/google/uuid"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func Test_generate(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var u uuid.UUID
// Version 4, SQLite compatible
err = db.QueryRow(`SELECT uuid()`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != 4 {
t.Errorf("got %d, want 4", got)
}
// Invalid version
err = db.QueryRow(`SELECT uuid(8)`).Scan(&u)
if err == nil {
t.Error("want error")
}
// Custom version, no arguments
for _, want := range []uuid.Version{1, 2, 4, 6, 7} {
err = db.QueryRow(`SELECT uuid(?)`, want).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != want {
t.Errorf("got %d, want %d", got, want)
}
}
// Version 2, custom arguments
err = db.QueryRow(`SELECT uuid(2, 4)`).Scan(&u)
if err == nil {
t.Error("want error")
}
err = db.QueryRow(`SELECT uuid(2, 'group')`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != 2 {
t.Errorf("got %d, want 2", got)
}
if got := u.Domain(); got != uuid.Group {
t.Errorf("got %d, want 1", got)
}
dce := []struct {
out uuid.Domain
in any
id uint32
}{
{uuid.Person, "user", 42},
{uuid.Group, "group", 42},
{uuid.Org, "org", 42},
{uuid.Person, 0, 42},
{uuid.Group, 1, 42},
{uuid.Org, 2, 42},
{3, 3, 42},
}
for _, tt := range dce {
err = db.QueryRow(`SELECT uuid(2, ?, ?)`, tt.in, tt.id).Scan(&u)
if err != nil {
t.Fatal(err)
}
if got := u.Version(); got != 2 {
t.Errorf("got %d, want 2", got)
}
if got := u.Domain(); got != tt.out {
t.Errorf("got %d, want %d", got, tt.out)
}
if got := u.ID(); got != tt.id {
t.Errorf("got %d, want %d", got, tt.id)
}
}
// Versions 3 and 5
err = db.QueryRow(`SELECT uuid(3)`).Scan(&u)
if err == nil {
t.Error("want error")
}
err = db.QueryRow(`SELECT uuid(3, 0, '')`).Scan(&u)
if err == nil {
t.Error("want error")
}
hash := []struct {
ver uuid.Version
ns any
data string
u uuid.UUID
}{
{3, "oid", "2.999", uuid.MustParse("31cb1efa-18c4-3d19-89ba-df6a74ddbd1d")},
{3, "dns", "www.example.com", uuid.MustParse("5df41881-3aed-3515-88a7-2f4a814cf09e")},
{3, "fqdn", "www.example.com", uuid.MustParse("5df41881-3aed-3515-88a7-2f4a814cf09e")},
{3, "url", "https://www.example.com/", uuid.MustParse("7fed185f-0864-319f-875b-a3d5458e30ac")},
{3, "x500", "CN=Test User 1, O=Example Organization, ST=California, C=US", uuid.MustParse("addf5e97-9287-3834-abfd-7edcbe7db56f")},
{3, "url", "https://www.php.net", uuid.MustParse("3f703955-aaba-3e70-a3cb-baff6aa3b28f")},
{5, "url", "https://www.php.net", uuid.MustParse("a8f6ae40-d8a7-58f0-be05-a22f94eca9ec")},
}
for _, tt := range hash {
err = db.QueryRow(`SELECT uuid(?, ?, ?)`, tt.ver, tt.ns, tt.data).Scan(&u)
if err != nil {
t.Fatal(err)
}
if u != tt.u {
t.Errorf("got %v, want %v", u, tt.u)
}
}
}
func Test_convert(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var u uuid.UUID
lits := []string{
"'6ba7b8119dad11d180b400c04fd430c8'",
"'6ba7b811-9dad-11d1-80b4-00c04fd430c8'",
"'{6ba7b811-9dad-11d1-80b4-00c04fd430c8}'",
"X'6ba7b8119dad11d180b400c04fd430c8'",
}
for _, tt := range lits {
err = db.QueryRow(`SELECT uuid_str(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if u != uuid.NameSpaceURL {
t.Errorf("got %v, want %v", u, uuid.NameSpaceURL)
}
}
for _, tt := range lits {
err = db.QueryRow(`SELECT uuid_blob(` + tt + `)`).Scan(&u)
if err != nil {
t.Fatal(err)
}
if u != uuid.NameSpaceURL {
t.Errorf("got %v, want %v", u, uuid.NameSpaceURL)
}
}
err = db.QueryRow(`SELECT uuid_str(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
err = db.QueryRow(`SELECT uuid_blob(X'cafe')`).Scan(&u)
if err == nil {
t.Fatal("want error")
}
}

View File

@@ -4,15 +4,18 @@
package zorder
import (
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the zorder and unzorder SQL functions.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("zorder", -1, flags, zorder)
db.CreateFunction("unzorder", 3, flags, unzorder)
return errors.Join(
db.CreateFunction("zorder", -1, flags, zorder),
db.CreateFunction("unzorder", 3, flags, unzorder))
}
func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -3,7 +3,6 @@ package zorder_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/zorder"
@@ -13,10 +12,7 @@ import (
func TestRegister_zorder(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
zorder.Register(c)
return nil
})
db, err := driver.Open(":memory:", zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -60,10 +56,7 @@ func TestRegister_zorder(t *testing.T) {
func TestRegister_unzorder(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
zorder.Register(c)
return nil
})
db, err := driver.Open(":memory:", zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -90,10 +83,7 @@ func TestRegister_unzorder(t *testing.T) {
func TestRegister_error(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
zorder.Register(c)
return nil
})
db, err := driver.Open(":memory:", zorder.Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -31,8 +31,9 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
//
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
func (c Conn) AnyCollationNeeded() error {
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
return c.error(r)
}
// CreateCollation defines a new collating sequence.

View File

@@ -130,8 +130,8 @@ func ExampleContext_SetAuxData() {
ctx.ResultError(err)
return
}
ctx.SetAuxData(0, r)
re = r
ctx.SetAuxData(0, r)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
})

8
go.mod
View File

@@ -2,17 +2,21 @@ module github.com/ncruces/go-sqlite3
go 1.21
toolchain go1.22.5
require (
github.com/dchest/siphash v1.2.3
github.com/ncruces/julianday v1.0.0
github.com/ncruces/sort v0.1.2
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.7.3
golang.org/x/crypto v0.24.0
golang.org/x/crypto v0.25.0
golang.org/x/sync v0.7.0
golang.org/x/sys v0.21.0
golang.org/x/sys v0.22.0
golang.org/x/text v0.16.0
lukechampine.com/adiantum v1.1.1
)
require github.com/google/uuid v1.6.0
retract v0.4.0 // tagged from the wrong branch

10
go.sum
View File

@@ -1,5 +1,7 @@
github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA=
github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=
@@ -8,12 +10,12 @@ github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIw
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.7.3 h1:PBH5KVahrt3S2AHgEjKu4u+LlDbbk+nsGE3KLucy6Rw=
github.com/tetratelabs/wazero v1.7.3/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
lukechampine.com/adiantum v1.1.1 h1:4fp6gTxWCqpEbLy40ExiYDDED3oUNWx5cTqBCtPdZqA=

View File

@@ -3,9 +3,11 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.5/ddlmod.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.5/ddlmod_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.5/error_translator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.5/migrator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.5/sqlite.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.5/sqlite_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/ddlmod.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/ddlmod_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/error_translator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/migrator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/sqlite.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/sqlite_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/v1.5.6/sqlite_test.go"
curl -#L "https://github.com/glebarez/sqlite/raw/v1.11.0/sqlite_error_translator_test.go" > error_translator_test.go

View File

@@ -0,0 +1,48 @@
package gormlite
import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func TestErrorTranslator(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
// This is the example object for testing the unique constraint error
type Article struct {
ArticleNumber string `gorm:"unique"`
}
db, err := gorm.Open(Open(InMemoryDSN), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
TranslateError: true})
if err != nil {
t.Errorf("Expected Open to succeed; got error: %v", err)
}
if db == nil {
t.Errorf("Expected db to be non-nil.")
}
err = db.AutoMigrate(&Article{})
if err != nil {
t.Errorf("Expected to migrate database models to succeed: %v", err)
}
err = db.Create(&Article{ArticleNumber: "A00000XX"}).Error
if err != nil {
t.Errorf("Expected first create to succeed: %v", err)
}
err = db.Create(&Article{ArticleNumber: "A00000XX"}).Error
if err == nil {
t.Errorf("Expected second create to fail.")
}
if err != gorm.ErrDuplicatedKey {
t.Errorf("Expected error from second create to be gorm.ErrDuplicatedKey: %v", err)
}
}

View File

@@ -2,8 +2,10 @@ module github.com/ncruces/go-sqlite3/gormlite
go 1.21
toolchain go1.22.5
require (
github.com/ncruces/go-sqlite3 v0.16.1
github.com/ncruces/go-sqlite3 v0.16.3
gorm.io/gorm v1.25.10
)
@@ -12,5 +14,5 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.3 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/sys v0.22.0 // indirect
)

View File

@@ -2,14 +2,14 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.16.1 h1:1wHv7s8y+fWK44UIliotJ42ZV41A5T0sjIAqGmnMrkc=
github.com/ncruces/go-sqlite3 v0.16.1/go.mod h1:feFXbBcbLtxNk6XWG1ROt8MS9+E45yCW3G8o4ixIqZ8=
github.com/ncruces/go-sqlite3 v0.16.3 h1:Ky0denOdmAGOoCE6lQlw6GCJNMD8gTikNWe8rpu+Gjc=
github.com/ncruces/go-sqlite3 v0.16.3/go.mod h1:sAU/vQwBmZ2hq5BlW/KTzqRFizL43bv2JQoBLgXhcMI=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.7.3 h1:PBH5KVahrt3S2AHgEjKu4u+LlDbbk+nsGE3KLucy6Rw=
github.com/tetratelabs/wazero v1.7.3/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=

View File

@@ -104,3 +104,13 @@ func ErrorCodeString(rc uint32) string {
}
return "sqlite3: unknown error"
}
type ErrorJoiner []error
func (j *ErrorJoiner) Join(errs ...error) {
for _, err := range errs {
if err != nil {
*j = append(*j, err)
}
}
}

View File

@@ -46,7 +46,7 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped
// Save the newly allocated region.
ptr := uint32(stack[0])
buf := View(mod, ptr, uint64(size))
addr := uintptr(unsafe.Pointer(&buf[0]))
addr := unsafe.Pointer(&buf[0])
s.regions = append(s.regions, &MappedRegion{
Ptr: ptr,
addr: addr,
@@ -56,7 +56,7 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped
}
type MappedRegion struct {
addr uintptr
addr unsafe.Pointer
Ptr uint32
size int32
used bool
@@ -76,23 +76,15 @@ func (r *MappedRegion) Unmap() error {
// We can't munmap the region, otherwise it could be remaped.
// Instead, convert it to a protected, private, anonymous mapping.
// If successful, it can be reused for a subsequent mmap.
_, err := mmap(r.addr, uintptr(r.size),
unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANON|unix.MAP_FIXED,
-1, 0)
_, err := unix.MmapPtr(-1, 0, r.addr, uintptr(r.size),
unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_FIXED|unix.MAP_ANON)
r.used = err != nil
return err
}
func (r *MappedRegion) mmap(f *os.File, offset int64, prot int) error {
_, err := mmap(r.addr, uintptr(r.size),
prot, unix.MAP_SHARED|unix.MAP_FIXED,
int(f.Fd()), offset)
_, err := unix.MmapPtr(int(f.Fd()), offset, r.addr, uintptr(r.size),
prot, unix.MAP_SHARED|unix.MAP_FIXED)
r.used = err == nil
return err
}
// We need the low level mmap for MAP_FIXED to work.
// Bind the syscall version hoping that it is more stable.
//go:linkname mmap syscall.mmap
func mmap(addr, length uintptr, prot, flag, fd int, pos int64) (*byte, error)

30
registry.go Normal file
View File

@@ -0,0 +1,30 @@
package sqlite3
import "sync"
var (
// +checklocks:extRegistryMtx
extRegistry []func(*Conn) error
extRegistryMtx sync.RWMutex
)
// AutoExtension causes the entryPoint function to be invoked
// for each new database connection that is created.
//
// https://sqlite.org/c3ref/auto_extension.html
func AutoExtension(entryPoint func(*Conn) error) {
extRegistryMtx.Lock()
defer extRegistryMtx.Unlock()
extRegistry = append(extRegistry, entryPoint)
}
func initExtensions(c *Conn) error {
extRegistryMtx.RLock()
defer extRegistryMtx.RUnlock()
for _, f := range extRegistry {
if err := f(c); err != nil {
return err
}
}
return nil
}

View File

@@ -19,7 +19,6 @@ curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/ieee754.
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/series.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.46.0/ext/misc/uuid.c"
cd ~-
cd ../vfs/tests/mptest/testdata/

View File

@@ -8,7 +8,6 @@
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
// Bindings
#include "column.c"
#include "func.c"
@@ -28,6 +27,5 @@ __attribute__((constructor)) void init() {
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}

View File

@@ -159,7 +159,7 @@ static int go_vtab_integrity_wrapper(sqlite3_vtab *pVTab, const char *zSchema,
return rc;
}
static int go_vtab_shadown_name_wrapper(const char *zName) { return 1; }
static int go_vtab_shadown_name_wrapper(const char *zName) { return true; }
int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags,
go_handle handle) {

View File

@@ -207,7 +207,10 @@ func TestAnyCollationNeeded(t *testing.T) {
t.Fatal(err)
}
db.AnyCollationNeeded()
err = db.AnyCollationNeeded()
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
if err != nil {

View File

@@ -618,6 +618,9 @@ func TestStmt_ColumnTime(t *testing.T) {
}
func TestStmt_Error(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
if bits.UintSize < 64 {
t.Skip("skipping on 32-bit")
}

61
util/vtabutil/const.go Normal file
View File

@@ -0,0 +1,61 @@
package vtabutil
const (
_NONE = iota
_MEMORY
_SYNTAX
_UNSUPPORTEDSQL
)
type ConflictClause uint32
const (
CONFLICT_NONE ConflictClause = iota
CONFLICT_ROLLBACK
CONFLICT_ABORT
CONFLICT_FAIL
CONFLICT_IGNORE
CONFLICT_REPLACE
)
type OrderClause uint32
const (
ORDER_NONE OrderClause = iota
ORDER_ASC
ORDER_DESC
)
type FKAction uint32
const (
FKACTION_NONE FKAction = iota
FKACTION_SETNULL
FKACTION_SETDEFAULT
FKACTION_CASCADE
FKACTION_RESTRICT
FKACTION_NOACTION
)
type FKDefType uint32
const (
DEFTYPE_NONE FKDefType = iota
DEFTYPE_DEFERRABLE
DEFTYPE_DEFERRABLE_INITIALLY_DEFERRED
DEFTYPE_DEFERRABLE_INITIALLY_IMMEDIATE
DEFTYPE_NOTDEFERRABLE
DEFTYPE_NOTDEFERRABLE_INITIALLY_DEFERRED
DEFTYPE_NOTDEFERRABLE_INITIALLY_IMMEDIATE
)
type StatementType uint32
const (
CREATE_UNKNOWN StatementType = iota
CREATE_TABLE
ALTER_RENAME_TABLE
ALTER_RENAME_COLUMN
ALTER_ADD_COLUMN
ALTER_DROP_COLUMN
)

View File

@@ -12,60 +12,50 @@ import (
)
const (
_NONE = iota
_MEMORY
_SYNTAX
_UNSUPPORTEDSQL
codeptr = 4
baseptr = 8
errp = 4
sqlp = 8
)
var (
//go:embed parse/sql3parse_table.wasm
binary []byte
ctx context.Context
once sync.Once
runtime wazero.Runtime
module wazero.CompiledModule
binary []byte
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
)
// Table holds metadata about a table.
type Table struct {
mod api.Module
ptr uint32
sql string
}
// Parse parses a [CREATE] or [ALTER TABLE] command.
//
// [CREATE]: https://sqlite.org/lang_createtable.html
// [ALTER TABLE]: https://sqlite.org/lang_altertable.html
func Parse(sql string) (_ *Table, err error) {
once.Do(func() {
ctx = context.Background()
cfg := wazero.NewRuntimeConfigInterpreter().WithDebugInfoEnabled(false)
ctx := context.Background()
cfg := wazero.NewRuntimeConfigInterpreter()
runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
module, err = runtime.CompileModule(ctx, binary)
compiled, err = runtime.CompileModule(ctx, binary)
})
if err != nil {
return nil, err
}
mod, err := runtime.InstantiateModule(ctx, module, wazero.NewModuleConfig().WithName(""))
ctx := context.Background()
mod, err := runtime.InstantiateModule(ctx, compiled, wazero.NewModuleConfig().WithName(""))
if err != nil {
return nil, err
}
defer mod.Close(ctx)
if buf, ok := mod.Memory().Read(baseptr, uint32(len(sql))); ok {
if buf, ok := mod.Memory().Read(sqlp, uint32(len(sql))); ok {
copy(buf, sql)
}
r, err := mod.ExportedFunction("sql3parse_table").Call(ctx, baseptr, uint64(len(sql)), codeptr)
r, err := mod.ExportedFunction("sql3parse_table").Call(ctx, sqlp, uint64(len(sql)), errp)
if err != nil {
return nil, err
}
c, _ := mod.Memory().ReadUint32Le(codeptr)
c, _ := mod.Memory().ReadUint32Le(errp)
switch c {
case _MEMORY:
panic(util.OOMErr)
@@ -74,68 +64,146 @@ func Parse(sql string) (_ *Table, err error) {
case _UNSUPPORTEDSQL:
return nil, util.ErrorString("sql3parse: unsupported SQL")
}
if r[0] == 0 {
return nil, nil
}
return &Table{
sql: sql,
mod: mod,
ptr: uint32(r[0]),
}, nil
var tab Table
tab.load(mod, uint32(r[0]), sql)
return &tab, nil
}
// Close closes a table handle.
func (t *Table) Close() error {
mod := t.mod
t.mod = nil
return mod.Close(ctx)
// Table holds metadata about a table.
type Table struct {
Name string
Schema string
Comment string
IsTemporary bool
IsIfNotExists bool
IsWithoutRowID bool
IsStrict bool
Columns []Column
Type StatementType
CurrentName string
NewName string
}
// NumColumns returns the number of columns of the table.
func (t *Table) NumColumns() int {
r, err := t.mod.ExportedFunction("sql3table_num_columns").Call(ctx, uint64(t.ptr))
if err != nil {
panic(err)
}
return int(int32(r[0]))
}
func (t *Table) load(mod api.Module, ptr uint32, sql string) {
t.Name = loadString(mod, ptr+0, sql)
t.Schema = loadString(mod, ptr+8, sql)
t.Comment = loadString(mod, ptr+16, sql)
// Column returns data for the ith column of the table.
//
// https://sqlite.org/lang_createtable.html#column_definitions
func (t *Table) Column(i int) Column {
r, err := t.mod.ExportedFunction("sql3table_get_column").Call(ctx, uint64(t.ptr), uint64(i))
if err != nil {
panic(err)
}
return Column{
tab: t,
ptr: uint32(r[0]),
}
}
t.IsTemporary = loadBool(mod, ptr+24)
t.IsIfNotExists = loadBool(mod, ptr+25)
t.IsWithoutRowID = loadBool(mod, ptr+26)
t.IsStrict = loadBool(mod, ptr+27)
func (t *Table) string(ptr uint32) string {
if ptr == 0 {
return ""
}
off, _ := t.mod.Memory().ReadUint32Le(ptr + 0)
len, _ := t.mod.Memory().ReadUint32Le(ptr + 4)
return t.sql[off-baseptr : off+len-baseptr]
t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, res *Column) {
p, _ := mod.Memory().ReadUint32Le(ptr)
res.load(mod, p, sql)
})
t.Type = loadEnum[StatementType](mod, ptr+44)
t.CurrentName = loadString(mod, ptr+48, sql)
t.NewName = loadString(mod, ptr+56, sql)
}
// Column holds metadata about a column.
type Column struct {
tab *Table
ptr uint32
Name string
Type string
Length string
ConstraintName string
Comment string
IsPrimaryKey bool
IsAutoIncrement bool
IsNotNull bool
IsUnique bool
PKOrder OrderClause
PKConflictClause ConflictClause
NotNullConflictClause ConflictClause
UniqueConflictClause ConflictClause
CheckExpr string
DefaultExpr string
CollateName string
ForeignKeyClause *ForeignKey
}
// Type returns the declared type of a column.
//
// https://sqlite.org/lang_createtable.html#column_data_types
func (c Column) Type() string {
r, err := c.tab.mod.ExportedFunction("sql3column_type").Call(ctx, uint64(c.ptr))
if err != nil {
panic(err)
func (c *Column) load(mod api.Module, ptr uint32, sql string) {
c.Name = loadString(mod, ptr+0, sql)
c.Type = loadString(mod, ptr+8, sql)
c.Length = loadString(mod, ptr+16, sql)
c.ConstraintName = loadString(mod, ptr+24, sql)
c.Comment = loadString(mod, ptr+32, sql)
c.IsPrimaryKey = loadBool(mod, ptr+40)
c.IsAutoIncrement = loadBool(mod, ptr+41)
c.IsNotNull = loadBool(mod, ptr+42)
c.IsUnique = loadBool(mod, ptr+43)
c.PKOrder = loadEnum[OrderClause](mod, ptr+44)
c.PKConflictClause = loadEnum[ConflictClause](mod, ptr+48)
c.NotNullConflictClause = loadEnum[ConflictClause](mod, ptr+52)
c.UniqueConflictClause = loadEnum[ConflictClause](mod, ptr+56)
c.CheckExpr = loadString(mod, ptr+60, sql)
c.DefaultExpr = loadString(mod, ptr+68, sql)
c.CollateName = loadString(mod, ptr+76, sql)
if ptr, _ := mod.Memory().ReadUint32Le(ptr + 84); ptr != 0 {
c.ForeignKeyClause = &ForeignKey{}
c.ForeignKeyClause.load(mod, ptr, sql)
}
return c.tab.string(uint32(r[0]))
}
type ForeignKey struct {
Table string
Columns []string
OnDelete FKAction
OnUpdate FKAction
Match string
Deferrable FKDefType
}
func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) {
f.Table = loadString(mod, ptr+0, sql)
f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, res *string) {
*res = loadString(mod, ptr, sql)
})
f.OnDelete = loadEnum[FKAction](mod, ptr+16)
f.OnUpdate = loadEnum[FKAction](mod, ptr+20)
f.Match = loadString(mod, ptr+24, sql)
f.Deferrable = loadEnum[FKDefType](mod, ptr+32)
}
func loadString(mod api.Module, ptr uint32, sql string) string {
off, _ := mod.Memory().ReadUint32Le(ptr + 0)
if off == 0 {
return ""
}
len, _ := mod.Memory().ReadUint32Le(ptr + 4)
return sql[off-sqlp : off+len-sqlp]
}
func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T)) []T {
ref, _ := mod.Memory().ReadUint32Le(ptr + 4)
if ref == 0 {
return nil
}
len, _ := mod.Memory().ReadUint32Le(ptr + 0)
res := make([]T, len)
for i := range res {
fn(ref, &res[i])
ref += 4
}
return res
}
func loadEnum[T ~uint32](mod api.Module, ptr uint32) T {
val, _ := mod.Memory().ReadUint32Le(ptr)
return T(val)
}
func loadBool(mod api.Module, ptr uint32) bool {
val, _ := mod.Memory().ReadByte(ptr)
return val != 0
}

View File

@@ -7,9 +7,8 @@ ROOT=../../../
BINARYEN="$ROOT/tools/binaryen-version_117/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-22.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -std=c17 -flto -g0 -Oz \
-Wall -Wextra -Wno-unused-parameter -Wno-unused-function \
-o sql3parse_table.wasm sql3parse_table.c \
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -flto -g0 -Oz \
-Wall -Wextra -o sql3parse_table.wasm main.c \
-mexec-model=reactor \
-msimd128 -mmutable-globals -mmultivalue \
-mbulk-memory -mreference-types \
@@ -17,11 +16,11 @@ WASI_SDK="$ROOT/tools/wasi-sdk-22.0/bin"
-fno-stack-protector -fno-stack-clash-protection \
-Wl,--stack-first \
-Wl,--import-undefined \
$(awk '{print "-Wl,--export="$0}' exports.txt)
-Wl,--export=sql3parse_table
trap 'rm -f sql3parse_table.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sql3parse_table.wasm -o sql3parse_table.tmp
"$BINARYEN/wasm-opt" -g --strip --strip-producers -c -Oz \
"$BINARYEN/wasm-ctor-eval" -c _initialize sql3parse_table.wasm -o sql3parse_table.tmp
"$BINARYEN/wasm-opt" --strip --strip-debug --strip-producers -c -Oz \
sql3parse_table.tmp -o sql3parse_table.wasm \
--enable-simd --enable-mutable-globals --enable-multivalue \
--enable-bulk-memory --enable-reference-types \

View File

@@ -1,4 +0,0 @@
sql3parse_table
sql3table_get_column
sql3table_num_columns
sql3column_type

View File

@@ -0,0 +1,42 @@
#include <stddef.h>
#include "sql3parse_table.c"
static_assert(offsetof(sql3table, name) == 0, "Unexpected offset");
static_assert(offsetof(sql3table, schema) == 8, "Unexpected offset");
static_assert(offsetof(sql3table, comment) == 16, "Unexpected offset");
static_assert(offsetof(sql3table, is_temporary) == 24, "Unexpected offset");
static_assert(offsetof(sql3table, is_ifnotexists) == 25, "Unexpected offset");
static_assert(offsetof(sql3table, is_withoutrowid) == 26, "Unexpected offset");
static_assert(offsetof(sql3table, is_strict) == 27, "Unexpected offset");
static_assert(offsetof(sql3table, num_columns) == 28, "Unexpected offset");
static_assert(offsetof(sql3table, columns) == 32, "Unexpected offset");
static_assert(offsetof(sql3table, type) == 44, "Unexpected offset");
static_assert(offsetof(sql3table, current_name) == 48, "Unexpected offset");
static_assert(offsetof(sql3table, new_name) == 56, "Unexpected offset");
static_assert(offsetof(sql3column, name) == 0, "Unexpected offset");
static_assert(offsetof(sql3column, type) == 8, "Unexpected offset");
static_assert(offsetof(sql3column, length) == 16, "Unexpected offset");
static_assert(offsetof(sql3column, constraint_name) == 24, "Unexpected offset");
static_assert(offsetof(sql3column, comment) == 32, "Unexpected offset");
static_assert(offsetof(sql3column, is_primarykey) == 40, "Unexpected offset");
static_assert(offsetof(sql3column, is_autoincrement) == 41, "Unexpected offset");
static_assert(offsetof(sql3column, is_notnull) == 42, "Unexpected offset");
static_assert(offsetof(sql3column, is_unique) == 43, "Unexpected offset");
static_assert(offsetof(sql3column, pk_order) == 44, "Unexpected offset");
static_assert(offsetof(sql3column, pk_conflictclause) == 48, "Unexpected offset");
static_assert(offsetof(sql3column, notnull_conflictclause) == 52, "Unexpected offset");
static_assert(offsetof(sql3column, unique_conflictclause) == 56, "Unexpected offset");
static_assert(offsetof(sql3column, check_expr) == 60, "Unexpected offset");
static_assert(offsetof(sql3column, default_expr) == 68, "Unexpected offset");
static_assert(offsetof(sql3column, collate_name) == 76, "Unexpected offset");
static_assert(offsetof(sql3column, foreignkey_clause) == 84, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, table) == 0, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, num_columns) == 8, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, column_name) == 12, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, on_delete) == 16, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, on_update) == 20, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, match) == 24, "Unexpected offset");
static_assert(offsetof(sql3foreignkey, deferrable) == 32, "Unexpected offset");

View File

@@ -0,0 +1,31 @@
package vtabutil_test
import (
"testing"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
func TestParse(t *testing.T) {
tab, err := vtabutil.Parse(`CREATE TABLE child(x REFERENCES parent)`)
if err != nil {
t.Fatal(err)
}
if got := tab.Name; got != "child" {
t.Errorf("got %s, want child", got)
}
if got := len(tab.Columns); got != 1 {
t.Errorf("got %d, want 1", got)
}
col := tab.Columns[0]
if got := col.Name; got != "x" {
t.Errorf("got %s, want x", got)
}
fk := col.ForeignKeyClause
if got := fk.Table; got != "parent" {
t.Errorf("got %s, want parent", got)
}
}

View File

@@ -7,7 +7,7 @@ ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_117/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-22.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -std=c17 -flto -g0 -O2 \
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -flto -g0 -O2 \
-o mptest.wasm main.c \
-I"$ROOT/sqlite3" \
-msimd128 -mmutable-globals \

View File

@@ -74,6 +74,8 @@ func initFlags() {
// keep test flags
os.Args[i] = arg
i++
case arg == "--":
// ignore this
default:
// collect everything else
options = append(options, arg)

View File

@@ -7,7 +7,7 @@ ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_117/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-22.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -std=c17 -flto -g0 -O2 \
"$WASI_SDK/clang" --target=wasm32-wasi -std=c23 -flto -g0 -O2 \
-o speedtest1.wasm main.c \
-I"$ROOT/sqlite3" \
-msimd128 -mmutable-globals \