mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Driver API.
This commit is contained in:
18
conn.go
18
conn.go
@@ -2,7 +2,6 @@ package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -240,6 +239,11 @@ func (c *Conn) Changes() int64 {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/interrupt.html
|
||||
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
// Is it the same context?
|
||||
if ctx == c.interrupt {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Is a waiter running?
|
||||
if c.waiter != nil {
|
||||
c.waiter <- struct{}{} // Cancel the waiter.
|
||||
@@ -331,15 +335,5 @@ func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
// [online backup]: https://www.sqlite.org/backup.html
|
||||
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
|
||||
type DriverConn interface {
|
||||
driver.Conn
|
||||
driver.ConnBeginTx
|
||||
driver.ExecerContext
|
||||
driver.ConnPrepareContext
|
||||
|
||||
SetInterrupt(ctx context.Context) (old context.Context)
|
||||
|
||||
Savepoint() Savepoint
|
||||
Backup(srcDB, dstURI string) error
|
||||
Restore(dstDB, srcURI string) error
|
||||
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
|
||||
Raw() *Conn
|
||||
}
|
||||
|
||||
@@ -40,14 +40,34 @@ import (
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// This variable can be replaced with -ldflags:
|
||||
//
|
||||
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
|
||||
var driverName = "sqlite3"
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", sqlite{})
|
||||
if driverName != "" {
|
||||
sql.Register(driverName, sqlite{})
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
|
||||
//
|
||||
// The init function is called by the driver on new connections.
|
||||
// The conn can be used to execute queries, register functions, etc.
|
||||
// Any error return closes the conn and passes the error to database/sql.
|
||||
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
|
||||
c, err := newConnector(dataSourceName, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.OpenDB(c), nil
|
||||
}
|
||||
|
||||
type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
c, err := sqlite{}.OpenConnector(name)
|
||||
c, err := newConnector(name, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -55,7 +75,11 @@ func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
}
|
||||
|
||||
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
|
||||
c := connector{name: name}
|
||||
return newConnector(name, nil)
|
||||
}
|
||||
|
||||
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
|
||||
c := connector{name: name, init: init}
|
||||
if strings.HasPrefix(name, "file:") {
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, err := url.ParseQuery(after)
|
||||
@@ -73,6 +97,7 @@ type connector struct {
|
||||
name string
|
||||
txlock string
|
||||
pragmas bool
|
||||
init func(ctx context.Context, conn *sqlite3.Conn) error
|
||||
}
|
||||
|
||||
func (n *connector) Driver() driver.Driver {
|
||||
@@ -126,6 +151,12 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if n.init != nil {
|
||||
err = n.init(ctx, c.Conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
@@ -140,12 +171,17 @@ type conn struct {
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
_ driver.ConnPrepareContext = &conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
func (c *conn) Raw() *sqlite3.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
}
|
||||
@@ -190,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
|
||||
func (c *conn) Commit() error {
|
||||
err := c.Conn.Exec(c.txCommit)
|
||||
if err != nil && !c.GetAutocommit() {
|
||||
if err != nil && !c.Conn.GetAutocommit() {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
|
||||
@@ -47,7 +47,7 @@ func ExampleDriverConn() {
|
||||
}
|
||||
|
||||
err = conn.Raw(func(driverConn any) error {
|
||||
conn := driverConn.(sqlite3.DriverConn)
|
||||
conn := driverConn.(sqlite3.DriverConn).Raw()
|
||||
savept := conn.Savepoint()
|
||||
defer savept.Release(&err)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
@@ -33,7 +32,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
conn, err := sql.Open("sqlite3", dialector.DSN)
|
||||
conn, err := driver.Open(dialector.DSN, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
@@ -13,6 +16,17 @@ func TestDialector(t *testing.T) {
|
||||
// This is the DSN of the in-memory SQLite database for these tests.
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
// Custom connection with a custom function called "my_custom_function".
|
||||
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
|
||||
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
|
||||
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultText("my-result")
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector *Dialector
|
||||
@@ -29,6 +43,33 @@ func TestDialector(t *testing.T) {
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom function",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: false,
|
||||
},
|
||||
{
|
||||
description: "Custom connection",
|
||||
dialector: &Dialector{
|
||||
Conn: conn,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
{
|
||||
description: "Custom connection, custom function",
|
||||
dialector: &Dialector{
|
||||
Conn: conn,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT my_custom_function()",
|
||||
querySuccess: true,
|
||||
},
|
||||
}
|
||||
for rowIndex, row := range rows {
|
||||
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
|
||||
|
||||
@@ -2,10 +2,9 @@ package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
db, err := driver.Open(":memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user