From a44690035f329d874c150cd55e4094126bd515e1 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 28 Nov 2024 00:12:11 +0000 Subject: [PATCH] Refactor. --- driver/driver.go | 44 +++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index 88c4c50..150a01a 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -107,17 +107,17 @@ func init() { // The second callback is called before the driver closes a connection. // The [sqlite3.Conn] can be used to execute queries, register functions, etc. func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, error) { - var drv SQLite if len(fn) > 2 { return nil, sqlite3.MISUSE } + var init, term func(*sqlite3.Conn) error if len(fn) > 1 { - drv.term = fn[1] + term = fn[1] } if len(fn) > 0 { - drv.init = fn[0] + init = fn[0] } - c, err := drv.OpenConnector(dataSourceName) + c, err := newConnector(dataSourceName, init, term) if err != nil { return nil, err } @@ -125,10 +125,7 @@ func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, erro } // SQLite implements [database/sql/driver.Driver]. -type SQLite struct { - init func(*sqlite3.Conn) error - term func(*sqlite3.Conn) error -} +type SQLite struct{} var ( // Ensure these interfaces are implemented: @@ -137,7 +134,7 @@ var ( // Open implements [database/sql/driver.Driver]. func (d *SQLite) Open(name string) (driver.Conn, error) { - c, err := d.newConnector(name) + c, err := newConnector(name, nil, nil) if err != nil { return nil, err } @@ -146,11 +143,11 @@ func (d *SQLite) Open(name string) (driver.Conn, error) { // OpenConnector implements [database/sql/driver.DriverContext]. func (d *SQLite) OpenConnector(name string) (driver.Connector, error) { - return d.newConnector(name) + return newConnector(name, nil, nil) } -func (d *SQLite) newConnector(name string) (*connector, error) { - c := connector{driver: d, name: name} +func newConnector(name string, init, term func(*sqlite3.Conn) error) (*connector, error) { + c := connector{name: name, init: init, term: term} var txlock, timefmt string if strings.HasPrefix(name, "file:") { @@ -190,7 +187,8 @@ func (d *SQLite) newConnector(name string) (*connector, error) { } type connector struct { - driver *SQLite + init func(*sqlite3.Conn) error + term func(*sqlite3.Conn) error name string txLock string tmRead sqlite3.TimeFormat @@ -199,7 +197,7 @@ type connector struct { } func (n *connector) Driver() driver.Driver { - return n.driver + return &SQLite{} } func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { @@ -228,13 +226,13 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { return nil, err } } - if n.driver.init != nil { - err = n.driver.init(c.Conn) + if n.init != nil { + err = n.init(c.Conn) if err != nil { return nil, err } } - if n.pragmas || n.driver.init != nil { + if n.pragmas || n.init != nil { s, _, err := c.Conn.Prepare(`PRAGMA query_only`) if err != nil { return nil, err @@ -250,9 +248,9 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { return nil, err } } - if n.driver.term != nil { + if n.term != nil { err = c.Conn.Trace(sqlite3.TRACE_CLOSE, func(sqlite3.TraceEvent, any, any) error { - return n.driver.term(c.Conn) + return n.term(c.Conn) }) if err != nil { return nil, err @@ -288,6 +286,8 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { type Conn interface { Raw() *sqlite3.Conn driver.Conn + driver.ConnBeginTx + driver.ConnPrepareContext } type conn struct { @@ -301,10 +301,8 @@ type conn struct { var ( // Ensure these interfaces are implemented: - _ Conn = &conn{} - _ driver.ConnBeginTx = &conn{} - _ driver.ConnPrepareContext = &conn{} - _ driver.ExecerContext = &conn{} + _ Conn = &conn{} + _ driver.ExecerContext = &conn{} ) func (c *conn) Raw() *sqlite3.Conn {