diff --git a/driver/driver.go b/driver/driver.go index 4dc8182..e89adc7 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -46,9 +46,42 @@ func init() { type sqlite struct{} -func (sqlite) Open(name string) (_ driver.Conn, err error) { +func (sqlite) Open(name string) (driver.Conn, error) { + c, err := sqlite{}.OpenConnector(name) + if err != nil { + return nil, err + } + return c.Connect(context.Background()) +} + +func (sqlite) OpenConnector(name string) (driver.Connector, error) { + c := connector{name: name} + if strings.HasPrefix(name, "file:") { + if _, after, ok := strings.Cut(name, "?"); ok { + query, err := url.ParseQuery(after) + if err != nil { + return nil, err + } + c.txlock = query.Get("_txlock") + c.pragmas = len(query["_pragma"]) > 0 + } + } + return &c, nil +} + +type connector struct { + name string + txlock string + pragmas bool +} + +func (n *connector) Driver() driver.Driver { + return sqlite{} +} + +func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { var c conn - c.Conn, err = sqlite3.Open(name) + c.Conn, err = sqlite3.Open(n.name) if err != nil { return nil, err } @@ -58,25 +91,18 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) { } }() - var pragmas bool - c.txBegin = "BEGIN" - if strings.HasPrefix(name, "file:") { - if _, after, ok := strings.Cut(name, "?"); ok { - query, _ := url.ParseQuery(after) + old := c.Conn.SetInterrupt(ctx) + defer c.Conn.SetInterrupt(old) - switch s := query.Get("_txlock"); s { - case "": - c.txBegin = "BEGIN" - case "deferred", "immediate", "exclusive": - c.txBegin = "BEGIN " + s - default: - return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s) - } - - pragmas = len(query["_pragma"]) > 0 - } + switch n.txlock { + case "": + c.txBegin = "BEGIN" + case "deferred", "immediate", "exclusive": + c.txBegin = "BEGIN " + n.txlock + default: + return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock) } - if !pragmas { + if !n.pragmas { err = c.Conn.Exec(`PRAGMA busy_timeout=60000`) if err != nil { return nil, err @@ -256,11 +282,14 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { } func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - err := s.setupBindings(ctx, args) + err := s.setupBindings(args) if err != nil { return nil, err } + old := s.Conn.SetInterrupt(ctx) + defer s.Conn.SetInterrupt(old) + err = s.Stmt.Exec() if err != nil { return nil, err @@ -270,15 +299,14 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - err := s.setupBindings(ctx, args) + err := s.setupBindings(args) if err != nil { return nil, err } - return &rows{ctx, s.Stmt, s.Conn}, nil } -func (s *stmt) setupBindings(ctx context.Context, args []driver.NamedValue) error { +func (s *stmt) setupBindings(args []driver.NamedValue) error { err := s.Stmt.ClearBindings() if err != nil { return err diff --git a/stmt.go b/stmt.go index c26de44..b944e5d 100644 --- a/stmt.go +++ b/stmt.go @@ -61,12 +61,12 @@ func (s *Stmt) ClearBindings() error { func (s *Stmt) Step() bool { s.c.checkInterrupt() r := s.c.call(s.c.api.step, uint64(s.handle)) - if r == _ROW { + switch r { + case _ROW: return true - } - if r == _DONE { + case _DONE: s.err = nil - } else { + default: s.err = s.c.error(r) } return false