Driver API.

This commit is contained in:
Nuno Cruces
2023-09-20 02:41:09 +01:00
parent ac6dd1aa5f
commit d9fcf60b7d
6 changed files with 96 additions and 27 deletions

View File

@@ -40,14 +40,34 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
var driverName = "sqlite3"
func init() {
sql.Register("sqlite3", sqlite{})
if driverName != "" {
sql.Register(driverName, sqlite{})
}
}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to database/sql.
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite{}.OpenConnector(name)
c, err := newConnector(name, nil)
if err != nil {
return nil, err
}
@@ -55,7 +75,11 @@ func (sqlite) Open(name string) (driver.Conn, error) {
}
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
c := connector{name: name}
return newConnector(name, nil)
}
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
@@ -73,6 +97,7 @@ type connector struct {
name string
txlock string
pragmas bool
init func(ctx context.Context, conn *sqlite3.Conn) error
}
func (n *connector) Driver() driver.Driver {
@@ -126,6 +151,12 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
return nil, err
}
}
if n.init != nil {
err = n.init(ctx, c.Conn)
if err != nil {
return nil, err
}
}
return &c, nil
}
@@ -140,12 +171,17 @@ type conn struct {
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
_ driver.ConnPrepareContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
func (c *conn) IsValid() bool {
return c.reusable
}
@@ -190,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
if err != nil && !c.Conn.GetAutocommit() {
c.Rollback()
}
return err