From d9fcf60b7d4b89819e7a863821db1dcb7c9e4f76 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 20 Sep 2023 02:41:09 +0100 Subject: [PATCH] Driver API. --- conn.go | 18 +++++--------- driver/driver.go | 52 ++++++++++++++++++++++++++++++++++------- driver_test.go | 2 +- gormlite/sqlite.go | 5 ++-- gormlite/sqlite_test.go | 41 ++++++++++++++++++++++++++++++++ tests/driver_test.go | 5 ++-- 6 files changed, 96 insertions(+), 27 deletions(-) diff --git a/conn.go b/conn.go index 9c15a85..3d19b70 100644 --- a/conn.go +++ b/conn.go @@ -2,7 +2,6 @@ package sqlite3 import ( "context" - "database/sql/driver" "errors" "fmt" "net/url" @@ -240,6 +239,11 @@ func (c *Conn) Changes() int64 { // // https://www.sqlite.org/c3ref/interrupt.html func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { + // Is it the same context? + if ctx == c.interrupt { + return ctx + } + // Is a waiter running? if c.waiter != nil { c.waiter <- struct{}{} // Cancel the waiter. @@ -331,15 +335,5 @@ func (c *Conn) error(rc uint64, sql ...string) error { // [online backup]: https://www.sqlite.org/backup.html // [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html type DriverConn interface { - driver.Conn - driver.ConnBeginTx - driver.ExecerContext - driver.ConnPrepareContext - - SetInterrupt(ctx context.Context) (old context.Context) - - Savepoint() Savepoint - Backup(srcDB, dstURI string) error - Restore(dstDB, srcURI string) error - OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) + Raw() *Conn } diff --git a/driver/driver.go b/driver/driver.go index e89adc7..220263b 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -40,14 +40,34 @@ import ( "github.com/ncruces/go-sqlite3/internal/util" ) +// This variable can be replaced with -ldflags: +// +// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite" +var driverName = "sqlite3" + func init() { - sql.Register("sqlite3", sqlite{}) + if driverName != "" { + sql.Register(driverName, sqlite{}) + } +} + +// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB]. +// +// The init function is called by the driver on new connections. +// The conn can be used to execute queries, register functions, etc. +// Any error return closes the conn and passes the error to database/sql. +func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) { + c, err := newConnector(dataSourceName, init) + if err != nil { + return nil, err + } + return sql.OpenDB(c), nil } type sqlite struct{} func (sqlite) Open(name string) (driver.Conn, error) { - c, err := sqlite{}.OpenConnector(name) + c, err := newConnector(name, nil) if err != nil { return nil, err } @@ -55,7 +75,11 @@ func (sqlite) Open(name string) (driver.Conn, error) { } func (sqlite) OpenConnector(name string) (driver.Connector, error) { - c := connector{name: name} + return newConnector(name, nil) +} + +func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) { + c := connector{name: name, init: init} if strings.HasPrefix(name, "file:") { if _, after, ok := strings.Cut(name, "?"); ok { query, err := url.ParseQuery(after) @@ -73,6 +97,7 @@ type connector struct { name string txlock string pragmas bool + init func(ctx context.Context, conn *sqlite3.Conn) error } func (n *connector) Driver() driver.Driver { @@ -126,6 +151,12 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { return nil, err } } + if n.init != nil { + err = n.init(ctx, c.Conn) + if err != nil { + return nil, err + } + } return &c, nil } @@ -140,12 +171,17 @@ type conn struct { var ( // Ensure these interfaces are implemented: - _ driver.ExecerContext = &conn{} - _ driver.ConnBeginTx = &conn{} - _ driver.Validator = &conn{} - _ sqlite3.DriverConn = &conn{} + _ driver.ConnPrepareContext = &conn{} + _ driver.ExecerContext = &conn{} + _ driver.ConnBeginTx = &conn{} + _ driver.Validator = &conn{} + _ sqlite3.DriverConn = &conn{} ) +func (c *conn) Raw() *sqlite3.Conn { + return c.Conn +} + func (c *conn) IsValid() bool { return c.reusable } @@ -190,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e func (c *conn) Commit() error { err := c.Conn.Exec(c.txCommit) - if err != nil && !c.GetAutocommit() { + if err != nil && !c.Conn.GetAutocommit() { c.Rollback() } return err diff --git a/driver_test.go b/driver_test.go index 3b5bad5..63b0c75 100644 --- a/driver_test.go +++ b/driver_test.go @@ -47,7 +47,7 @@ func ExampleDriverConn() { } err = conn.Raw(func(driverConn any) error { - conn := driverConn.(sqlite3.DriverConn) + conn := driverConn.(sqlite3.DriverConn).Raw() savept := conn.Savepoint() defer savept.Release(&err) diff --git a/gormlite/sqlite.go b/gormlite/sqlite.go index 960385c..76433d4 100644 --- a/gormlite/sqlite.go +++ b/gormlite/sqlite.go @@ -3,7 +3,6 @@ package gormlite import ( "context" - "database/sql" "strconv" "gorm.io/gorm" @@ -13,7 +12,7 @@ import ( "gorm.io/gorm/migrator" "gorm.io/gorm/schema" - _ "github.com/ncruces/go-sqlite3/driver" + "github.com/ncruces/go-sqlite3/driver" ) type Dialector struct { @@ -33,7 +32,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { if dialector.Conn != nil { db.ConnPool = dialector.Conn } else { - conn, err := sql.Open("sqlite3", dialector.DSN) + conn, err := driver.Open(dialector.DSN, nil) if err != nil { return err } diff --git a/gormlite/sqlite_test.go b/gormlite/sqlite_test.go index dc917c7..f2ab58a 100644 --- a/gormlite/sqlite_test.go +++ b/gormlite/sqlite_test.go @@ -1,11 +1,14 @@ package gormlite import ( + "context" "fmt" "testing" "gorm.io/gorm" + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" ) @@ -13,6 +16,17 @@ func TestDialector(t *testing.T) { // This is the DSN of the in-memory SQLite database for these tests. const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared" + // Custom connection with a custom function called "my_custom_function". + conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error { + return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC, + func(ctx sqlite3.Context, arg ...sqlite3.Value) { + ctx.ResultText("my-result") + }) + }) + if err != nil { + t.Fatal(err) + } + rows := []struct { description string dialector *Dialector @@ -29,6 +43,33 @@ func TestDialector(t *testing.T) { query: "SELECT 1", querySuccess: true, }, + { + description: "Custom function", + dialector: &Dialector{ + DSN: InMemoryDSN, + }, + openSuccess: true, + query: "SELECT my_custom_function()", + querySuccess: false, + }, + { + description: "Custom connection", + dialector: &Dialector{ + Conn: conn, + }, + openSuccess: true, + query: "SELECT 1", + querySuccess: true, + }, + { + description: "Custom connection, custom function", + dialector: &Dialector{ + Conn: conn, + }, + openSuccess: true, + query: "SELECT my_custom_function()", + querySuccess: true, + }, } for rowIndex, row := range rows { t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) { diff --git a/tests/driver_test.go b/tests/driver_test.go index 94a5c97..d46e35d 100644 --- a/tests/driver_test.go +++ b/tests/driver_test.go @@ -2,10 +2,9 @@ package tests import ( "context" - "database/sql" "testing" - _ "github.com/ncruces/go-sqlite3/driver" + "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" ) @@ -15,7 +14,7 @@ func TestDriver(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - db, err := sql.Open("sqlite3", ":memory:") + db, err := driver.Open(":memory:", nil) if err != nil { t.Fatal(err) }