mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Automatically load extensions. (#115)
This commit is contained in:
3
conn.go
3
conn.go
@@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
c.arena = c.newArena(1024)
|
||||
c.ctx = context.WithValue(c.ctx, connKey{}, c)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err == nil {
|
||||
err = initExtensions(c)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
// The argument must be bound to a Go slice or array of
|
||||
// ints, floats, bools, strings or byte slices,
|
||||
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
|
||||
func Register(db *sqlite3.Conn) {
|
||||
sqlite3.CreateModule(db, "array", nil,
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return sqlite3.CreateModule(db, "array", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
|
||||
return array{}, err
|
||||
|
||||
@@ -15,10 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func Example_driver() {
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
array.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", array.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -53,14 +50,14 @@ func Example_driver() {
|
||||
}
|
||||
|
||||
func Example() {
|
||||
sqlite3.AutoExtension(array.Register)
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
array.Register(db)
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT name
|
||||
FROM pragma_function_list
|
||||
@@ -91,10 +88,7 @@ func Example() {
|
||||
func Test_cursor_Column(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
array.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", array.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
array.Register(db)
|
||||
err = array.Register(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT * FROM array()`)
|
||||
if err == nil {
|
||||
|
||||
@@ -29,10 +29,11 @@ import (
|
||||
// along with the [sqlite3.Blob] handle.
|
||||
//
|
||||
// https://sqlite.org/c3ref/blob.html
|
||||
func Register(db *sqlite3.Conn) {
|
||||
db.CreateFunction("readblob", 6, 0, readblob)
|
||||
db.CreateFunction("writeblob", 6, 0, writeblob)
|
||||
db.CreateFunction("openblob", -1, 0, openblob)
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return errors.Join(
|
||||
db.CreateFunction("readblob", 6, 0, readblob),
|
||||
db.CreateFunction("writeblob", 6, 0, writeblob),
|
||||
db.CreateFunction("openblob", -1, 0, openblob))
|
||||
}
|
||||
|
||||
// OpenCallback is the type for the openblob callback.
|
||||
|
||||
@@ -18,10 +18,7 @@ import (
|
||||
|
||||
func Example() {
|
||||
// Open the database, registering the extension.
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
|
||||
blobio.Register(conn)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -60,6 +57,11 @@ func Example() {
|
||||
// Hello BLOB!
|
||||
}
|
||||
|
||||
func init() {
|
||||
sqlite3.AutoExtension(blobio.Register)
|
||||
sqlite3.AutoExtension(array.Register)
|
||||
}
|
||||
|
||||
func Test_readblob(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
blobio.Register(db)
|
||||
array.Register(db)
|
||||
|
||||
err = db.Exec(`SELECT readblob()`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
@@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
blobio.Register(db)
|
||||
array.Register(db)
|
||||
|
||||
err = db.Exec(`SELECT openblob()`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
// Register registers the bloom_filter virtual table:
|
||||
//
|
||||
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
|
||||
func Register(db *sqlite3.Conn) {
|
||||
sqlite3.CreateModule(db, "bloom_filter", create, connect)
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
|
||||
}
|
||||
|
||||
type bloom struct {
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sqlite3.AutoExtension(bloom.Register)
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -21,8 +25,6 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
bloom.Register(db)
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20);
|
||||
INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo')
|
||||
@@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
bloom.Register(db)
|
||||
|
||||
query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
|
||||
// Register registers the CSV virtual table.
|
||||
// If a filename is specified, [os.Open] is used to open the file.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
RegisterFS(db, osutil.FS{})
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return RegisterFS(db, osutil.FS{})
|
||||
}
|
||||
|
||||
// RegisterFS registers the CSV virtual table.
|
||||
// If a filename is specified, fsys is used to open the file.
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
|
||||
var (
|
||||
filename string
|
||||
@@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
|
||||
return table, nil
|
||||
}
|
||||
|
||||
sqlite3.CreateModule(db, "csv", declare, declare)
|
||||
return sqlite3.CreateModule(db, "csv", declare, declare)
|
||||
}
|
||||
|
||||
type table struct {
|
||||
|
||||
@@ -18,7 +18,10 @@ func Example() {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
csv.Register(db)
|
||||
err = csv.Register(db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE eurofxref USING csv(
|
||||
@@ -51,6 +54,10 @@ func Example() {
|
||||
// On Twosday, 1€ = $1.1342
|
||||
}
|
||||
|
||||
func init() {
|
||||
sqlite3.AutoExtension(csv.Register)
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -60,8 +67,6 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
csv.Register(db)
|
||||
|
||||
const data = `
|
||||
# Comment
|
||||
"Rob" "Pike" rob
|
||||
@@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
csv.Register(db)
|
||||
|
||||
const data = "01\n0.10\ne"
|
||||
err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE temp.nums USING csv(
|
||||
@@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
csv.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
|
||||
@@ -14,24 +14,26 @@ import (
|
||||
|
||||
// Register registers SQL functions readfile, writefile, lsmode,
|
||||
// and the table-valued function fsdir.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
RegisterFS(db, nil)
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return RegisterFS(db, nil)
|
||||
}
|
||||
|
||||
// Register registers SQL functions readfile, lsmode,
|
||||
// and the table-valued function fsdir;
|
||||
// fsys will be used to read files and list directories.
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
|
||||
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode)
|
||||
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
var err error
|
||||
if fsys == nil {
|
||||
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
|
||||
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
|
||||
}
|
||||
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
|
||||
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
|
||||
return fsdir{fsys}, err
|
||||
})
|
||||
return errors.Join(err,
|
||||
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
|
||||
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
|
||||
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
|
||||
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
|
||||
return fsdir{fsys}, err
|
||||
}))
|
||||
}
|
||||
|
||||
func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
|
||||
@@ -17,10 +17,7 @@ import (
|
||||
func Test_lsmode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
fileio.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", fileio.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,10 @@ func Test_fsdir_errors(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fileio.Register(db)
|
||||
err = fileio.Register(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`SELECT name FROM fsdir()`)
|
||||
if err == nil {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
@@ -16,10 +15,7 @@ import (
|
||||
func Test_writefile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -21,47 +21,60 @@ package hash
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Register registers cryptographic hash functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
|
||||
var errs util.ErrorJoiner
|
||||
if crypto.MD4.Available() {
|
||||
db.CreateFunction("md4", 1, flags, md4Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("md4", 1, flags, md4Func))
|
||||
}
|
||||
if crypto.MD5.Available() {
|
||||
db.CreateFunction("md5", 1, flags, md5Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("md5", 1, flags, md5Func))
|
||||
}
|
||||
if crypto.SHA1.Available() {
|
||||
db.CreateFunction("sha1", 1, flags, sha1Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("sha1", 1, flags, sha1Func))
|
||||
}
|
||||
if crypto.SHA3_512.Available() {
|
||||
db.CreateFunction("sha3", 1, flags, sha3Func)
|
||||
db.CreateFunction("sha3", 2, flags, sha3Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("sha3", 1, flags, sha3Func),
|
||||
db.CreateFunction("sha3", 2, flags, sha3Func))
|
||||
}
|
||||
if crypto.SHA256.Available() {
|
||||
db.CreateFunction("sha224", 1, flags, sha224Func)
|
||||
db.CreateFunction("sha256", 1, flags, sha256Func)
|
||||
db.CreateFunction("sha256", 2, flags, sha256Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("sha224", 1, flags, sha224Func),
|
||||
db.CreateFunction("sha256", 1, flags, sha256Func),
|
||||
db.CreateFunction("sha256", 2, flags, sha256Func))
|
||||
}
|
||||
if crypto.SHA512.Available() {
|
||||
db.CreateFunction("sha384", 1, flags, sha384Func)
|
||||
db.CreateFunction("sha512", 1, flags, sha512Func)
|
||||
db.CreateFunction("sha512", 2, flags, sha512Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("sha384", 1, flags, sha384Func),
|
||||
db.CreateFunction("sha512", 1, flags, sha512Func),
|
||||
db.CreateFunction("sha512", 2, flags, sha512Func))
|
||||
}
|
||||
if crypto.BLAKE2s_256.Available() {
|
||||
db.CreateFunction("blake2s", 1, flags, blake2sFunc)
|
||||
errs.Join(
|
||||
db.CreateFunction("blake2s", 1, flags, blake2sFunc))
|
||||
}
|
||||
if crypto.BLAKE2b_512.Available() {
|
||||
db.CreateFunction("blake2b", 1, flags, blake2bFunc)
|
||||
db.CreateFunction("blake2b", 2, flags, blake2bFunc)
|
||||
errs.Join(
|
||||
db.CreateFunction("blake2b", 1, flags, blake2bFunc),
|
||||
db.CreateFunction("blake2b", 2, flags, blake2bFunc))
|
||||
}
|
||||
if crypto.RIPEMD160.Available() {
|
||||
db.CreateFunction("ripemd160", 1, flags, ripemd160Func)
|
||||
errs.Join(
|
||||
db.CreateFunction("ripemd160", 1, flags, ripemd160Func))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func md4Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
_ "crypto/sha512"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
@@ -53,10 +52,7 @@ func TestRegister(t *testing.T) {
|
||||
{"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"},
|
||||
}
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ package lines
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -25,27 +26,28 @@ import (
|
||||
// The lines function reads from a database blob or text.
|
||||
// The lines_read function reads from a file or an [io.Reader].
|
||||
// If a filename is specified, [os.Open] is used to open the file.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
RegisterFS(db, osutil.FS{})
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return RegisterFS(db, osutil.FS{})
|
||||
}
|
||||
|
||||
// RegisterFS registers the lines and lines_read table-valued functions.
|
||||
// The lines function reads from a database blob or text.
|
||||
// The lines_read function reads from a file or an [io.Reader].
|
||||
// If a filename is specified, fsys is used to open the file.
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
|
||||
sqlite3.CreateModule(db, "lines", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
|
||||
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
|
||||
return lines{}, err
|
||||
})
|
||||
sqlite3.CreateModule(db, "lines_read", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
|
||||
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
|
||||
return lines{fsys}, err
|
||||
})
|
||||
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
|
||||
return errors.Join(
|
||||
sqlite3.CreateModule(db, "lines", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
|
||||
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
|
||||
return lines{}, err
|
||||
}),
|
||||
sqlite3.CreateModule(db, "lines_read", nil,
|
||||
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
|
||||
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
|
||||
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
|
||||
return lines{fsys}, err
|
||||
}))
|
||||
}
|
||||
|
||||
type lines struct {
|
||||
|
||||
@@ -18,10 +18,7 @@ import (
|
||||
)
|
||||
|
||||
func Example() {
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
lines.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -70,10 +67,7 @@ func Example() {
|
||||
func Test_lines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
lines.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -103,10 +97,7 @@ func Test_lines(t *testing.T) {
|
||||
func Test_lines_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
lines.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -130,10 +121,7 @@ func Test_lines_error(t *testing.T) {
|
||||
func Test_lines_read(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
lines.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -164,10 +152,7 @@ func Test_lines_read(t *testing.T) {
|
||||
func Test_lines_test(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
lines.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", lines.Register)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
// Register registers the pivot virtual table.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
sqlite3.CreateModule(db, "pivot", declare, declare)
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return sqlite3.CreateModule(db, "pivot", declare, declare)
|
||||
}
|
||||
|
||||
type table struct {
|
||||
|
||||
@@ -14,14 +14,14 @@ import (
|
||||
|
||||
// https://antonz.org/sqlite-pivot-table/
|
||||
func Example() {
|
||||
sqlite3.AutoExtension(pivot.Register)
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
pivot.Register(db)
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE TABLE sales(product TEXT, year INT, income DECIMAL);
|
||||
INSERT INTO sales(product, year, income) VALUES
|
||||
@@ -83,6 +83,10 @@ func Example() {
|
||||
// gamma 80 75 78 80
|
||||
}
|
||||
|
||||
func init() {
|
||||
sqlite3.AutoExtension(pivot.Register)
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -92,8 +96,6 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
pivot.Register(db)
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE TABLE r AS
|
||||
SELECT 1 id UNION SELECT 2 UNION SELECT 3;
|
||||
@@ -153,8 +155,6 @@ func TestRegister_errors(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
pivot.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
|
||||
@@ -12,19 +12,20 @@
|
||||
package regexp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
|
||||
db.CreateFunction("regexp", 2, flags, regex)
|
||||
db.CreateFunction("regexp_like", 2, flags, regexLike)
|
||||
db.CreateFunction("regexp_substr", 2, flags, regexSubstr)
|
||||
db.CreateFunction("regexp_replace", 3, flags, regexReplace)
|
||||
return errors.Join(
|
||||
db.CreateFunction("regexp", 2, flags, regex),
|
||||
db.CreateFunction("regexp_like", 2, flags, regexLike),
|
||||
db.CreateFunction("regexp_substr", 2, flags, regexSubstr),
|
||||
db.CreateFunction("regexp_replace", 3, flags, regexReplace))
|
||||
}
|
||||
|
||||
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package regexp
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
@@ -12,10 +11,7 @@ import (
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -50,10 +46,7 @@ func TestRegister(t *testing.T) {
|
||||
func TestRegister_errors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
)
|
||||
|
||||
// Register registers the statement virtual table.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
sqlite3.CreateModule(db, "statement", declare, declare)
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
return sqlite3.CreateModule(db, "statement", declare, declare)
|
||||
}
|
||||
|
||||
type table struct {
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
)
|
||||
|
||||
func Example() {
|
||||
sqlite3.AutoExtension(statement.Register)
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
statement.Register(db)
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE split_date USING statement((
|
||||
SELECT
|
||||
@@ -48,6 +48,10 @@ func Example() {
|
||||
// Twosday was 2022-2-22
|
||||
}
|
||||
|
||||
func init() {
|
||||
sqlite3.AutoExtension(statement.Register)
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -57,8 +61,6 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
statement.Register(db)
|
||||
|
||||
err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE arguments USING statement((SELECT ? AS a, ? AS b, ? AS c))
|
||||
`)
|
||||
@@ -107,8 +109,6 @@ func TestRegister_errors(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
statement.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement()`)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/stats"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
@@ -18,8 +17,6 @@ func TestRegister_boolean(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/stats"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
@@ -19,8 +18,6 @@ func TestRegister_percentile(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -44,33 +44,38 @@
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
|
||||
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
|
||||
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx))
|
||||
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy))
|
||||
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy))
|
||||
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx))
|
||||
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy))
|
||||
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope))
|
||||
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
|
||||
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
|
||||
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
|
||||
db.CreateWindowFunction("median", 1, flags, newPercentile(median))
|
||||
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont))
|
||||
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc))
|
||||
db.CreateWindowFunction("every", 1, flags, newBoolean(every))
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some))
|
||||
return errors.Join(
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
|
||||
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2)),
|
||||
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx)),
|
||||
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy)),
|
||||
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy)),
|
||||
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx)),
|
||||
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy)),
|
||||
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)),
|
||||
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
|
||||
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
|
||||
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
|
||||
db.CreateWindowFunction("median", 1, flags, newPercentile(median)),
|
||||
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)),
|
||||
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)),
|
||||
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -10,6 +10,10 @@ import (
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sqlite3.AutoExtension(stats.Register)
|
||||
}
|
||||
|
||||
func TestRegister_variance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -19,8 +23,6 @@ func TestRegister_variance(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -88,8 +90,6 @@ func TestRegister_covariance(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE data (y, x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -217,8 +217,6 @@ func Benchmark_variance(b *testing.B) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
||||
@@ -18,6 +18,7 @@ package unicode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
@@ -30,29 +31,29 @@ import (
|
||||
)
|
||||
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
return errors.Join(
|
||||
db.CreateFunction("like", 2, flags, like),
|
||||
db.CreateFunction("like", 3, flags, like),
|
||||
db.CreateFunction("upper", 1, flags, upper),
|
||||
db.CreateFunction("upper", 2, flags, upper),
|
||||
db.CreateFunction("lower", 1, flags, lower),
|
||||
db.CreateFunction("lower", 2, flags, lower),
|
||||
db.CreateFunction("regexp", 2, flags, regex),
|
||||
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
name := arg[1].Text()
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
db.CreateFunction("like", 2, flags, like)
|
||||
db.CreateFunction("like", 3, flags, like)
|
||||
db.CreateFunction("upper", 1, flags, upper)
|
||||
db.CreateFunction("upper", 2, flags, upper)
|
||||
db.CreateFunction("lower", 1, flags, lower)
|
||||
db.CreateFunction("lower", 2, flags, lower)
|
||||
db.CreateFunction("regexp", 2, flags, regex)
|
||||
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
name := arg[1].Text()
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := RegisterCollation(db, arg[0].Text(), name)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
err := RegisterCollation(db, arg[0].Text(), name)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// RegisterCollation registers a Unicode collation sequence for a database connection.
|
||||
|
||||
@@ -5,6 +5,7 @@ package uuid
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -25,14 +26,15 @@ import (
|
||||
// uuid_blob(u)
|
||||
//
|
||||
// Converts a UUID into a 16-byte blob.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate)
|
||||
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate)
|
||||
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate)
|
||||
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate)
|
||||
db.CreateFunction("uuid_str", 1, flags, toString)
|
||||
db.CreateFunction("uuid_blob", 1, flags, toBLOB)
|
||||
return errors.Join(
|
||||
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate),
|
||||
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate),
|
||||
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
|
||||
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
|
||||
db.CreateFunction("uuid_str", 1, flags, toString),
|
||||
db.CreateFunction("uuid_blob", 1, flags, toBlob))
|
||||
}
|
||||
|
||||
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
@@ -147,7 +149,7 @@ func fromValue(arg sqlite3.Value) (u uuid.UUID, err error) {
|
||||
return u, err
|
||||
}
|
||||
|
||||
func toBLOB(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
func toBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
u, err := fromValue(arg[0])
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
@@ -13,10 +12,7 @@ import (
|
||||
func Test_generate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -135,10 +131,7 @@ func Test_generate(t *testing.T) {
|
||||
func Test_convert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -4,15 +4,18 @@
|
||||
package zorder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Register registers the zorder and unzorder SQL functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateFunction("zorder", -1, flags, zorder)
|
||||
db.CreateFunction("unzorder", 3, flags, unzorder)
|
||||
return errors.Join(
|
||||
db.CreateFunction("zorder", -1, flags, zorder),
|
||||
db.CreateFunction("unzorder", 3, flags, unzorder))
|
||||
}
|
||||
|
||||
func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package zorder_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/zorder"
|
||||
@@ -13,10 +12,7 @@ import (
|
||||
func TestRegister_zorder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
zorder.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -60,10 +56,7 @@ func TestRegister_zorder(t *testing.T) {
|
||||
func TestRegister_unzorder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
zorder.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -90,10 +83,7 @@ func TestRegister_unzorder(t *testing.T) {
|
||||
func TestRegister_error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
|
||||
zorder.Register(c)
|
||||
return nil
|
||||
})
|
||||
db, err := driver.Open(":memory:", zorder.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
5
func.go
5
func.go
@@ -31,8 +31,9 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
|
||||
//
|
||||
// This can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c *Conn) AnyCollationNeeded() {
|
||||
c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
|
||||
func (c Conn) AnyCollationNeeded() error {
|
||||
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// CreateCollation defines a new collating sequence.
|
||||
|
||||
@@ -104,3 +104,13 @@ func ErrorCodeString(rc uint32) string {
|
||||
}
|
||||
return "sqlite3: unknown error"
|
||||
}
|
||||
|
||||
type ErrorJoiner []error
|
||||
|
||||
func (j *ErrorJoiner) Join(errs ...error) {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
*j = append(*j, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
30
registry.go
Normal file
30
registry.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package sqlite3
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
// +checklocks:extRegistryMtx
|
||||
extRegistry []func(*Conn) error
|
||||
extRegistryMtx sync.RWMutex
|
||||
)
|
||||
|
||||
// AutoExtension causes the entryPoint function to be invoked
|
||||
// for each new database connection that is created.
|
||||
//
|
||||
// https://sqlite.org/c3ref/auto_extension.html
|
||||
func AutoExtension(entryPoint func(*Conn) error) {
|
||||
extRegistryMtx.Lock()
|
||||
defer extRegistryMtx.Unlock()
|
||||
extRegistry = append(extRegistry, entryPoint)
|
||||
}
|
||||
|
||||
func initExtensions(c *Conn) error {
|
||||
extRegistryMtx.RLock()
|
||||
defer extRegistryMtx.RUnlock()
|
||||
for _, f := range extRegistry {
|
||||
if err := f(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -207,7 +207,10 @@ func TestAnyCollationNeeded(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.AnyCollationNeeded()
|
||||
err = db.AnyCollationNeeded()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user