Driver API.

This commit is contained in:
Nuno Cruces
2023-09-20 02:41:09 +01:00
parent ac6dd1aa5f
commit d9fcf60b7d
6 changed files with 96 additions and 27 deletions

18
conn.go
View File

@@ -2,7 +2,6 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"net/url"
@@ -240,6 +239,11 @@ func (c *Conn) Changes() int64 {
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is it the same context?
if ctx == c.interrupt {
return ctx
}
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
@@ -331,15 +335,5 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
SetInterrupt(ctx context.Context) (old context.Context)
Savepoint() Savepoint
Backup(srcDB, dstURI string) error
Restore(dstDB, srcURI string) error
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
Raw() *Conn
}

View File

@@ -40,14 +40,34 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
var driverName = "sqlite3"
func init() {
sql.Register("sqlite3", sqlite{})
if driverName != "" {
sql.Register(driverName, sqlite{})
}
}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to database/sql.
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite{}.OpenConnector(name)
c, err := newConnector(name, nil)
if err != nil {
return nil, err
}
@@ -55,7 +75,11 @@ func (sqlite) Open(name string) (driver.Conn, error) {
}
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
c := connector{name: name}
return newConnector(name, nil)
}
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
@@ -73,6 +97,7 @@ type connector struct {
name string
txlock string
pragmas bool
init func(ctx context.Context, conn *sqlite3.Conn) error
}
func (n *connector) Driver() driver.Driver {
@@ -126,6 +151,12 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
return nil, err
}
}
if n.init != nil {
err = n.init(ctx, c.Conn)
if err != nil {
return nil, err
}
}
return &c, nil
}
@@ -140,12 +171,17 @@ type conn struct {
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
_ driver.ConnPrepareContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
func (c *conn) IsValid() bool {
return c.reusable
}
@@ -190,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
if err != nil && !c.Conn.GetAutocommit() {
c.Rollback()
}
return err

View File

@@ -47,7 +47,7 @@ func ExampleDriverConn() {
}
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
conn := driverConn.(sqlite3.DriverConn).Raw()
savept := conn.Savepoint()
defer savept.Release(&err)

View File

@@ -3,7 +3,6 @@ package gormlite
import (
"context"
"database/sql"
"strconv"
"gorm.io/gorm"
@@ -13,7 +12,7 @@ import (
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
@@ -33,7 +32,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
conn, err := driver.Open(dialector.DSN, nil)
if err != nil {
return err
}

View File

@@ -1,11 +1,14 @@
package gormlite
import (
"context"
"fmt"
"testing"
"gorm.io/gorm"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -13,6 +16,17 @@ func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
// Custom connection with a custom function called "my_custom_function".
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText("my-result")
})
})
if err != nil {
t.Fatal(err)
}
rows := []struct {
description string
dialector *Dialector
@@ -29,6 +43,33 @@ func TestDialector(t *testing.T) {
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom function",
dialector: &Dialector{
DSN: InMemoryDSN,
},
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: false,
},
{
description: "Custom connection",
dialector: &Dialector{
Conn: conn,
},
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
{
description: "Custom connection, custom function",
dialector: &Dialector{
Conn: conn,
},
openSuccess: true,
query: "SELECT my_custom_function()",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {

View File

@@ -2,10 +2,9 @@ package tests
import (
"context"
"database/sql"
"testing"
_ "github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", ":memory:")
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}