Driver connector.

This commit is contained in:
Nuno Cruces
2023-08-10 13:18:13 +01:00
parent f1e36e2581
commit 77f37893b9
2 changed files with 55 additions and 27 deletions

View File

@@ -46,9 +46,42 @@ func init() {
type sqlite struct{} type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) { func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite{}.OpenConnector(name)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
c := connector{name: name}
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
if err != nil {
return nil, err
}
c.txlock = query.Get("_txlock")
c.pragmas = len(query["_pragma"]) > 0
}
}
return &c, nil
}
type connector struct {
name string
txlock string
pragmas bool
}
func (n *connector) Driver() driver.Driver {
return sqlite{}
}
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
var c conn var c conn
c.Conn, err = sqlite3.Open(name) c.Conn, err = sqlite3.Open(n.name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,25 +91,18 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
} }
}() }()
var pragmas bool old := c.Conn.SetInterrupt(ctx)
c.txBegin = "BEGIN" defer c.Conn.SetInterrupt(old)
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, _ := url.ParseQuery(after)
switch s := query.Get("_txlock"); s { switch n.txlock {
case "": case "":
c.txBegin = "BEGIN" c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive": case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + s c.txBegin = "BEGIN " + n.txlock
default: default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s) return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
}
pragmas = len(query["_pragma"]) > 0
}
} }
if !pragmas { if !n.pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`) err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -256,11 +282,14 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
} }
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
err := s.setupBindings(ctx, args) err := s.setupBindings(args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
old := s.Conn.SetInterrupt(ctx)
defer s.Conn.SetInterrupt(old)
err = s.Stmt.Exec() err = s.Stmt.Exec()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -270,15 +299,14 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
} }
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.setupBindings(ctx, args) err := s.setupBindings(args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &rows{ctx, s.Stmt, s.Conn}, nil return &rows{ctx, s.Stmt, s.Conn}, nil
} }
func (s *stmt) setupBindings(ctx context.Context, args []driver.NamedValue) error { func (s *stmt) setupBindings(args []driver.NamedValue) error {
err := s.Stmt.ClearBindings() err := s.Stmt.ClearBindings()
if err != nil { if err != nil {
return err return err

View File

@@ -61,12 +61,12 @@ func (s *Stmt) ClearBindings() error {
func (s *Stmt) Step() bool { func (s *Stmt) Step() bool {
s.c.checkInterrupt() s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle)) r := s.c.call(s.c.api.step, uint64(s.handle))
if r == _ROW { switch r {
case _ROW:
return true return true
} case _DONE:
if r == _DONE {
s.err = nil s.err = nil
} else { default:
s.err = s.c.error(r) s.err = s.c.error(r)
} }
return false return false