export the database/sql driver type and global instance

This commit is contained in:
kim
2024-04-19 09:32:20 +01:00
committed by Nuno Cruces
parent 62b79d2ac3
commit d4027b0133

View File

@@ -63,39 +63,52 @@ var driverName = "sqlite3"
func init() {
if driverName != "" {
sql.Register(driverName, sqlite{})
sql.Register(driverName, &SQLiteDriver)
}
}
// SQLiteDriver is a global Driver{} instance
// registered under [database/sql] as "sqlite3".
var SQLiteDriver = Driver{}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to [database/sql].
func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
c, err := newConnector(dataSourceName, init)
d := Driver{Init: init}
c, err := d.OpenConnector(dataSourceName)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
type sqlite struct{}
type Driver struct {
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := newConnector(name, nil)
// The init function is called by the driver on new connections.
// The conn can be used to execute queries, register functions, etc.
// Any error return closes the conn and passes the error to [database/sql].
Init func(*sqlite3.Conn) error
}
// Open: implements [database/sql.Driver].
func (d *Driver) Open(name string) (driver.Conn, error) {
c, err := d.newConnector(name)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
return newConnector(name, nil)
// OpenConnector: implements [database/sql.DriverContext].
func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
return d.newConnector(name)
}
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
func (d *Driver) newConnector(name string) (*connector, error) {
c := connector{driver: d, name: name}
var txlock, timefmt string
if strings.HasPrefix(name, "file:") {
@@ -137,7 +150,7 @@ func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, erro
}
type connector struct {
init func(*sqlite3.Conn) error
driver *Driver
name string
txBegin string
tmRead sqlite3.TimeFormat
@@ -146,7 +159,7 @@ type connector struct {
}
func (n *connector) Driver() driver.Driver {
return sqlite{}
return n.driver
}
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
@@ -175,13 +188,13 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
return nil, err
}
}
if n.init != nil {
err = n.init(c.Conn)
if n.driver.Init != nil {
err = n.driver.Init(c.Conn)
if err != nil {
return nil, err
}
}
if n.pragmas || n.init != nil {
if n.pragmas || n.driver.Init != nil {
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
if err != nil {
return nil, err