Automatically load extensions. (#115)

This commit is contained in:
Nuno Cruces
2024-07-08 12:06:57 +01:00
committed by GitHub
parent fff8b1c74f
commit b5f746aadf
36 changed files with 261 additions and 245 deletions

View File

@@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
}
if err != nil {
return nil, err
}

View File

@@ -15,8 +15,8 @@ import (
// The argument must be bound to a Go slice or array of
// ints, floats, bools, strings or byte slices,
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "array", nil,
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "array", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
return array{}, err

View File

@@ -15,10 +15,7 @@ import (
)
func Example_driver() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
log.Fatal(err)
}
@@ -53,14 +50,14 @@ func Example_driver() {
}
func Example() {
sqlite3.AutoExtension(array.Register)
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
array.Register(db)
stmt, _, err := db.Prepare(`
SELECT name
FROM pragma_function_list
@@ -91,10 +88,7 @@ func Example() {
func Test_cursor_Column(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
t.Fatal(err)
}
@@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) {
}
defer db.Close()
array.Register(db)
err = array.Register(db)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT * FROM array()`)
if err == nil {

View File

@@ -29,10 +29,11 @@ import (
// along with the [sqlite3.Blob] handle.
//
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) {
db.CreateFunction("readblob", 6, 0, readblob)
db.CreateFunction("writeblob", 6, 0, writeblob)
db.CreateFunction("openblob", -1, 0, openblob)
func Register(db *sqlite3.Conn) error {
return errors.Join(
db.CreateFunction("readblob", 6, 0, readblob),
db.CreateFunction("writeblob", 6, 0, writeblob),
db.CreateFunction("openblob", -1, 0, openblob))
}
// OpenCallback is the type for the openblob callback.

View File

@@ -18,10 +18,7 @@ import (
func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blobio.Register(conn)
return nil
})
db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register)
if err != nil {
log.Fatal(err)
@@ -60,6 +57,11 @@ func Example() {
// Hello BLOB!
}
func init() {
sqlite3.AutoExtension(blobio.Register)
sqlite3.AutoExtension(array.Register)
}
func Test_readblob(t *testing.T) {
t.Parallel()
@@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) {
}
defer db.Close()
blobio.Register(db)
array.Register(db)
err = db.Exec(`SELECT readblob()`)
if err == nil {
t.Fatal("want error")
@@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) {
}
defer db.Close()
blobio.Register(db)
array.Register(db)
err = db.Exec(`SELECT openblob()`)
if err == nil {
t.Fatal("want error")

View File

@@ -20,8 +20,8 @@ import (
// Register registers the bloom_filter virtual table:
//
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "bloom_filter", create, connect)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
}
type bloom struct {

View File

@@ -12,6 +12,10 @@ import (
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func init() {
sqlite3.AutoExtension(bloom.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -21,8 +25,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
bloom.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20);
INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo')
@@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) {
}
defer db.Close()
bloom.Register(db)
query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`)
if err != nil {
t.Fatal(err)

View File

@@ -23,13 +23,13 @@ import (
// Register registers the CSV virtual table.
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
@@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
return table, nil
}
sqlite3.CreateModule(db, "csv", declare, declare)
return sqlite3.CreateModule(db, "csv", declare, declare)
}
type table struct {

View File

@@ -18,7 +18,10 @@ func Example() {
}
defer db.Close()
csv.Register(db)
err = csv.Register(db)
if err != nil {
log.Fatal(err)
}
err = db.Exec(`
CREATE VIRTUAL TABLE eurofxref USING csv(
@@ -51,6 +54,10 @@ func Example() {
// On Twosday, 1€ = $1.1342
}
func init() {
sqlite3.AutoExtension(csv.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -60,8 +67,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
csv.Register(db)
const data = `
# Comment
"Rob" "Pike" rob
@@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) {
}
defer db.Close()
csv.Register(db)
const data = "01\n0.10\ne"
err = db.Exec(`
CREATE VIRTUAL TABLE temp.nums USING csv(
@@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()
csv.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
if err == nil {
t.Fatal("want error")

View File

@@ -14,24 +14,26 @@ import (
// Register registers SQL functions readfile, writefile, lsmode,
// and the table-valued function fsdir.
func Register(db *sqlite3.Conn) {
RegisterFS(db, nil)
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, nil)
}
// Register registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode)
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
var err error
if fsys == nil {
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
})
return errors.Join(err,
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
}))
}
func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -17,10 +17,7 @@ import (
func Test_lsmode(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.Register(c)
return nil
})
db, err := driver.Open(":memory:", fileio.Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -68,7 +68,10 @@ func Test_fsdir_errors(t *testing.T) {
}
defer db.Close()
fileio.Register(db)
err = fileio.Register(db)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`SELECT name FROM fsdir()`)
if err == nil {

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -16,10 +15,7 @@ import (
func Test_writefile(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -21,47 +21,60 @@ package hash
import (
"crypto"
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers cryptographic hash functions for a database connection.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
var errs util.ErrorJoiner
if crypto.MD4.Available() {
db.CreateFunction("md4", 1, flags, md4Func)
errs.Join(
db.CreateFunction("md4", 1, flags, md4Func))
}
if crypto.MD5.Available() {
db.CreateFunction("md5", 1, flags, md5Func)
errs.Join(
db.CreateFunction("md5", 1, flags, md5Func))
}
if crypto.SHA1.Available() {
db.CreateFunction("sha1", 1, flags, sha1Func)
errs.Join(
db.CreateFunction("sha1", 1, flags, sha1Func))
}
if crypto.SHA3_512.Available() {
db.CreateFunction("sha3", 1, flags, sha3Func)
db.CreateFunction("sha3", 2, flags, sha3Func)
errs.Join(
db.CreateFunction("sha3", 1, flags, sha3Func),
db.CreateFunction("sha3", 2, flags, sha3Func))
}
if crypto.SHA256.Available() {
db.CreateFunction("sha224", 1, flags, sha224Func)
db.CreateFunction("sha256", 1, flags, sha256Func)
db.CreateFunction("sha256", 2, flags, sha256Func)
errs.Join(
db.CreateFunction("sha224", 1, flags, sha224Func),
db.CreateFunction("sha256", 1, flags, sha256Func),
db.CreateFunction("sha256", 2, flags, sha256Func))
}
if crypto.SHA512.Available() {
db.CreateFunction("sha384", 1, flags, sha384Func)
db.CreateFunction("sha512", 1, flags, sha512Func)
db.CreateFunction("sha512", 2, flags, sha512Func)
errs.Join(
db.CreateFunction("sha384", 1, flags, sha384Func),
db.CreateFunction("sha512", 1, flags, sha512Func),
db.CreateFunction("sha512", 2, flags, sha512Func))
}
if crypto.BLAKE2s_256.Available() {
db.CreateFunction("blake2s", 1, flags, blake2sFunc)
errs.Join(
db.CreateFunction("blake2s", 1, flags, blake2sFunc))
}
if crypto.BLAKE2b_512.Available() {
db.CreateFunction("blake2b", 1, flags, blake2bFunc)
db.CreateFunction("blake2b", 2, flags, blake2bFunc)
errs.Join(
db.CreateFunction("blake2b", 1, flags, blake2bFunc),
db.CreateFunction("blake2b", 2, flags, blake2bFunc))
}
if crypto.RIPEMD160.Available() {
db.CreateFunction("ripemd160", 1, flags, ripemd160Func)
errs.Join(
db.CreateFunction("ripemd160", 1, flags, ripemd160Func))
}
return errors.Join(errs...)
}
func md4Func(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -7,7 +7,6 @@ import (
_ "crypto/sha512"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -53,10 +52,7 @@ func TestRegister(t *testing.T) {
{"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"},
}
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -13,6 +13,7 @@ package lines
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/fs"
@@ -25,27 +26,28 @@ import (
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
sqlite3.CreateModule(db, "lines", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
return lines{}, err
})
sqlite3.CreateModule(db, "lines_read", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return lines{fsys}, err
})
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
return errors.Join(
sqlite3.CreateModule(db, "lines", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
return lines{}, err
}),
sqlite3.CreateModule(db, "lines_read", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return lines{fsys}, err
}))
}
type lines struct {

View File

@@ -18,10 +18,7 @@ import (
)
func Example() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -70,10 +67,7 @@ func Example() {
func Test_lines(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -103,10 +97,7 @@ func Test_lines(t *testing.T) {
func Test_lines_error(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -130,10 +121,7 @@ func Test_lines_error(t *testing.T) {
func Test_lines_read(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}
@@ -164,10 +152,7 @@ func Test_lines_read(t *testing.T) {
func Test_lines_test(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
lines.Register(c)
return nil
})
db, err := driver.Open(":memory:", lines.Register)
if err != nil {
log.Fatal(err)
}

View File

@@ -13,8 +13,8 @@ import (
)
// Register registers the pivot virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "pivot", declare, declare)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "pivot", declare, declare)
}
type table struct {

View File

@@ -14,14 +14,14 @@ import (
// https://antonz.org/sqlite-pivot-table/
func Example() {
sqlite3.AutoExtension(pivot.Register)
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE sales(product TEXT, year INT, income DECIMAL);
INSERT INTO sales(product, year, income) VALUES
@@ -83,6 +83,10 @@ func Example() {
// gamma 80 75 78 80
}
func init() {
sqlite3.AutoExtension(pivot.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -92,8 +96,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`
CREATE TABLE r AS
SELECT 1 id UNION SELECT 2 UNION SELECT 3;
@@ -153,8 +155,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()
pivot.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`)
if err == nil {
t.Fatal("want error")

View File

@@ -12,19 +12,20 @@
package regexp
import (
"errors"
"regexp"
"github.com/ncruces/go-sqlite3"
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("regexp", 2, flags, regex)
db.CreateFunction("regexp_like", 2, flags, regexLike)
db.CreateFunction("regexp_substr", 2, flags, regexSubstr)
db.CreateFunction("regexp_replace", 3, flags, regexReplace)
return errors.Join(
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("regexp_like", 2, flags, regexLike),
db.CreateFunction("regexp_substr", 2, flags, regexSubstr),
db.CreateFunction("regexp_replace", 3, flags, regexReplace))
}
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {

View File

@@ -3,7 +3,6 @@ package regexp
import (
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -12,10 +11,7 @@ import (
func TestRegister(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
Register(conn)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
@@ -50,10 +46,7 @@ func TestRegister(t *testing.T) {
func TestRegister_errors(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
Register(conn)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -17,8 +17,8 @@ import (
)
// Register registers the statement virtual table.
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "statement", declare, declare)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "statement", declare, declare)
}
type table struct {

View File

@@ -12,14 +12,14 @@ import (
)
func Example() {
sqlite3.AutoExtension(statement.Register)
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE split_date USING statement((
SELECT
@@ -48,6 +48,10 @@ func Example() {
// Twosday was 2022-2-22
}
func init() {
sqlite3.AutoExtension(statement.Register)
}
func TestRegister(t *testing.T) {
t.Parallel()
@@ -57,8 +61,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()
statement.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE arguments USING statement((SELECT ? AS a, ? AS b, ? AS c))
`)
@@ -107,8 +109,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()
statement.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement()`)
if err == nil {
t.Fatal("want error")

View File

@@ -5,7 +5,6 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
@@ -18,8 +17,6 @@ func TestRegister_boolean(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)

View File

@@ -6,7 +6,6 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
@@ -19,8 +18,6 @@ func TestRegister_percentile(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)

View File

@@ -44,33 +44,38 @@
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import "github.com/ncruces/go-sqlite3"
import (
"errors"
"github.com/ncruces/go-sqlite3"
)
// Register registers statistics functions.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx))
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy))
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy))
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx))
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy))
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope))
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
db.CreateWindowFunction("median", 1, flags, newPercentile(median))
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont))
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc))
db.CreateWindowFunction("every", 1, flags, newBoolean(every))
db.CreateWindowFunction("some", 1, flags, newBoolean(some))
return errors.Join(
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2)),
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx)),
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy)),
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy)),
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx)),
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy)),
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)),
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
db.CreateWindowFunction("median", 1, flags, newPercentile(median)),
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)),
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
}
const (

View File

@@ -10,6 +10,10 @@ import (
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)
func init() {
sqlite3.AutoExtension(stats.Register)
}
func TestRegister_variance(t *testing.T) {
t.Parallel()
@@ -19,8 +23,6 @@ func TestRegister_variance(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)
@@ -88,8 +90,6 @@ func TestRegister_covariance(t *testing.T) {
}
defer db.Close()
stats.Register(db)
err = db.Exec(`CREATE TABLE data (y, x)`)
if err != nil {
t.Fatal(err)
@@ -217,8 +217,6 @@ func Benchmark_variance(b *testing.B) {
}
defer db.Close()
stats.Register(db)
stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`)
if err != nil {
b.Fatal(err)

View File

@@ -18,6 +18,7 @@ package unicode
import (
"bytes"
"errors"
"regexp"
"strings"
"unicode/utf8"
@@ -30,29 +31,29 @@ import (
)
// Register registers Unicode aware functions for a database connection.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("like", 2, flags, like),
db.CreateFunction("like", 3, flags, like),
db.CreateFunction("upper", 1, flags, upper),
db.CreateFunction("upper", 2, flags, upper),
db.CreateFunction("lower", 1, flags, lower),
db.CreateFunction("lower", 2, flags, lower),
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
db.CreateFunction("like", 2, flags, like)
db.CreateFunction("like", 3, flags, like)
db.CreateFunction("upper", 1, flags, upper)
db.CreateFunction("upper", 2, flags, upper)
db.CreateFunction("lower", 1, flags, lower)
db.CreateFunction("lower", 2, flags, lower)
db.CreateFunction("regexp", 2, flags, regex)
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
}
})
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
}
}))
}
// RegisterCollation registers a Unicode collation sequence for a database connection.

View File

@@ -5,6 +5,7 @@ package uuid
import (
"bytes"
"errors"
"fmt"
"github.com/google/uuid"
@@ -25,14 +26,15 @@ import (
// uuid_blob(u)
//
// Converts a UUID into a 16-byte blob.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate)
db.CreateFunction("uuid_str", 1, flags, toString)
db.CreateFunction("uuid_blob", 1, flags, toBLOB)
return errors.Join(
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid_str", 1, flags, toString),
db.CreateFunction("uuid_blob", 1, flags, toBlob))
}
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
@@ -147,7 +149,7 @@ func fromValue(arg sqlite3.Value) (u uuid.UUID, err error) {
return u, err
}
func toBLOB(ctx sqlite3.Context, arg ...sqlite3.Value) {
func toBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)

View File

@@ -4,7 +4,6 @@ import (
"testing"
"github.com/google/uuid"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -13,10 +12,7 @@ import (
func Test_generate(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
Register(conn)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
@@ -135,10 +131,7 @@ func Test_generate(t *testing.T) {
func Test_convert(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
Register(conn)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -4,15 +4,18 @@
package zorder
import (
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the zorder and unzorder SQL functions.
func Register(db *sqlite3.Conn) {
func Register(db *sqlite3.Conn) error {
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
db.CreateFunction("zorder", -1, flags, zorder)
db.CreateFunction("unzorder", 3, flags, unzorder)
return errors.Join(
db.CreateFunction("zorder", -1, flags, zorder),
db.CreateFunction("unzorder", 3, flags, unzorder))
}
func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {

View File

@@ -3,7 +3,6 @@ package zorder_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/zorder"
@@ -13,10 +12,7 @@ import (
func TestRegister_zorder(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
zorder.Register(c)
return nil
})
db, err := driver.Open(":memory:", zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -60,10 +56,7 @@ func TestRegister_zorder(t *testing.T) {
func TestRegister_unzorder(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
zorder.Register(c)
return nil
})
db, err := driver.Open(":memory:", zorder.Register)
if err != nil {
t.Fatal(err)
}
@@ -90,10 +83,7 @@ func TestRegister_unzorder(t *testing.T) {
func TestRegister_error(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
zorder.Register(c)
return nil
})
db, err := driver.Open(":memory:", zorder.Register)
if err != nil {
t.Fatal(err)
}

View File

@@ -31,8 +31,9 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
//
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
func (c Conn) AnyCollationNeeded() error {
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
return c.error(r)
}
// CreateCollation defines a new collating sequence.

View File

@@ -104,3 +104,13 @@ func ErrorCodeString(rc uint32) string {
}
return "sqlite3: unknown error"
}
type ErrorJoiner []error
func (j *ErrorJoiner) Join(errs ...error) {
for _, err := range errs {
if err != nil {
*j = append(*j, err)
}
}
}

30
registry.go Normal file
View File

@@ -0,0 +1,30 @@
package sqlite3
import "sync"
var (
// +checklocks:extRegistryMtx
extRegistry []func(*Conn) error
extRegistryMtx sync.RWMutex
)
// AutoExtension causes the entryPoint function to be invoked
// for each new database connection that is created.
//
// https://sqlite.org/c3ref/auto_extension.html
func AutoExtension(entryPoint func(*Conn) error) {
extRegistryMtx.Lock()
defer extRegistryMtx.Unlock()
extRegistry = append(extRegistry, entryPoint)
}
func initExtensions(c *Conn) error {
extRegistryMtx.RLock()
defer extRegistryMtx.RUnlock()
for _, f := range extRegistry {
if err := f(c); err != nil {
return err
}
}
return nil
}

View File

@@ -207,7 +207,10 @@ func TestAnyCollationNeeded(t *testing.T) {
t.Fatal(err)
}
db.AnyCollationNeeded()
err = db.AnyCollationNeeded()
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
if err != nil {