Files
sqlite3/driver/driver.go

254 lines
5.0 KiB
Go
Raw Normal View History

2023-02-14 18:21:18 +00:00
// Package driver provides a database/sql driver for SQLite.
package driver
import (
2023-02-18 02:16:11 +00:00
"context"
2023-02-14 18:21:18 +00:00
"database/sql"
"database/sql/driver"
2023-02-17 02:21:07 +00:00
"io"
"time"
2023-02-14 18:21:18 +00:00
"github.com/ncruces/go-sqlite3"
)
// init registers this driver with database/sql under the name "sqlite3",
// so it can be used through sql.Open("sqlite3", name).
func init() {
	var drv sqlite
	sql.Register("sqlite3", drv)
}
type sqlite struct{}
func (sqlite) Open(name string) (driver.Conn, error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI)
if err != nil {
return nil, err
}
2023-02-17 16:19:55 +00:00
// If the database is not in WAL mode,
// use normal locking mode.
journal, err := pragma(c, "journal_mode")
if err != nil {
return nil, err
}
if journal != "wal" {
pragma(c, "locking_mode=normal")
}
2023-02-14 18:21:18 +00:00
return conn{c}, nil
}
2023-02-17 02:21:07 +00:00
type conn struct{ conn *sqlite3.Conn }
var (
2023-02-17 16:19:55 +00:00
// Ensure these interfaces are implemented:
_ driver.Validator = conn{}
// _ driver.SessionResetter = conn{}
// _ driver.ExecerContext = conn{}
// _ driver.ConnBeginTx = conn{}
2023-02-17 02:21:07 +00:00
)
// Close closes the underlying SQLite connection.
func (c conn) Close() error {
	err := c.conn.Close()
	return err
}
func (c conn) IsValid() bool {
2023-02-17 16:19:55 +00:00
// Pool only normal locking mode connections.
mode, _ := pragma(c.conn, "locking_mode")
return mode == "normal"
2023-02-17 02:21:07 +00:00
}
// Begin starts a transaction with a plain BEGIN statement. The connection
// itself is returned as the driver.Tx (see Commit and Rollback).
func (c conn) Begin() (driver.Tx, error) {
	if err := c.conn.Exec(`BEGIN`); err != nil {
		return nil, err
	}
	return c, nil
}
// Commit executes COMMIT. If that fails, a ROLLBACK is attempted as a best
// effort (its own error is discarded) and the COMMIT error is returned.
func (c conn) Commit() error {
	err := c.conn.Exec(`COMMIT`)
	if err == nil {
		return nil
	}
	c.Rollback()
	return err
}
2023-02-14 18:21:18 +00:00
2023-02-17 02:21:07 +00:00
// Rollback aborts the current transaction by executing ROLLBACK.
func (c conn) Rollback() error {
	err := c.conn.Exec(`ROLLBACK`)
	return err
}
2023-02-14 18:21:18 +00:00
func (c conn) Prepare(query string) (driver.Stmt, error) {
2023-02-17 02:21:07 +00:00
s, _, err := c.conn.Prepare(query)
if err != nil {
return nil, err
}
return stmt{s, c.conn}, nil
}
2023-02-17 16:19:55 +00:00
// pragma runs "PRAGMA " + cmd on c and returns the first column of the
// first result row as text. When the pragma yields no row, it returns ""
// and whatever error stepping reported (nil if none); when preparation
// fails, it returns "" and that error.
func pragma(c *sqlite3.Conn, cmd string) (string, error) {
	s, _, err := c.Prepare(`PRAGMA ` + cmd)
	if err != nil {
		return "", err
	}
	defer s.Close()

	if !s.Step() {
		return "", s.Err()
	}
	return s.ColumnText(0), nil
}
2023-02-17 02:21:07 +00:00
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
2023-02-17 16:19:55 +00:00
var (
// Ensure these interfaces are implemented:
2023-02-18 02:16:11 +00:00
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
2023-02-17 16:19:55 +00:00
)
2023-02-17 02:21:07 +00:00
// Close closes the underlying prepared statement.
func (s stmt) Close() error {
	err := s.stmt.Close()
	return err
}
// NumInput reports the statement's parameter count, which database/sql
// uses to sanity-check the number of arguments it is given.
func (s stmt) NumInput() int {
	n := s.stmt.BindCount()
	return n
}
2023-02-18 02:16:11 +00:00
// Deprecated: use ExecContext instead.
2023-02-17 02:21:07 +00:00
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
2023-02-18 02:16:11 +00:00
return s.ExecContext(context.Background(), namedValues(args))
}
// Query implements the legacy driver.Stmt method by converting args to
// named values and delegating to QueryContext with a background context.
//
// Deprecated: use QueryContext instead.
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
	named := namedValues(args)
	return s.QueryContext(context.Background(), named)
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
_, err := s.QueryContext(ctx, args)
2023-02-14 18:21:18 +00:00
if err != nil {
return nil, err
}
2023-02-17 02:21:07 +00:00
err = s.stmt.Exec()
if err != nil {
return nil, err
}
return result{
int64(s.conn.LastInsertRowID()),
int64(s.conn.Changes()),
}, nil
}
2023-02-18 02:16:11 +00:00
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.stmt.ClearBindings()
if err != nil {
return nil, err
}
var ids [3]int
for _, arg := range args {
ids := ids[:0]
if arg.Name == "" {
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range []string{":", "@", "$"} {
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
}
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
case float64:
err = s.stmt.BindFloat(id, a)
case string:
err = s.stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
}
2023-02-17 02:21:07 +00:00
}
if err != nil {
return nil, err
}
}
2023-02-18 02:16:11 +00:00
return rows{ctx, s.stmt, s.conn}, nil
2023-02-17 02:21:07 +00:00
}
// result is the driver.Result of an exec: the connection's last insert
// rowid and change count, captured right after the statement ran.
type result struct {
	lastInsertId int64
	rowsAffected int64
}

// LastInsertId returns the captured rowid; it never fails.
func (r result) LastInsertId() (int64, error) {
	id := r.lastInsertId
	return id, nil
}

// RowsAffected returns the captured change count; it never fails.
func (r result) RowsAffected() (int64, error) {
	n := r.rowsAffected
	return n, nil
}
2023-02-18 02:16:11 +00:00
// rows iterates a statement's results. The context is kept in the struct
// because driver.Rows.Next takes no context argument; Next uses it to
// interrupt stepping on cancellation.
type rows struct {
	ctx  context.Context
	stmt *sqlite3.Stmt
	conn *sqlite3.Conn
}
2023-02-17 02:21:07 +00:00
func (r rows) Close() error {
2023-02-18 02:16:11 +00:00
return r.stmt.Reset()
2023-02-17 02:21:07 +00:00
}
func (r rows) Columns() []string {
2023-02-18 02:16:11 +00:00
count := r.stmt.ColumnCount()
2023-02-17 02:21:07 +00:00
columns := make([]string, count)
for i := range columns {
2023-02-18 02:16:11 +00:00
columns[i] = r.stmt.ColumnName(i)
2023-02-17 02:21:07 +00:00
}
return columns
2023-02-14 18:21:18 +00:00
}
2023-02-17 02:21:07 +00:00
func (r rows) Next(dest []driver.Value) error {
2023-02-18 02:16:11 +00:00
ch := r.conn.SetInterrupt(r.ctx.Done())
defer r.conn.SetInterrupt(ch)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
2023-02-17 12:30:07 +00:00
return err
2023-02-17 02:21:07 +00:00
}
2023-02-17 12:30:07 +00:00
return io.EOF
2023-02-17 02:21:07 +00:00
}
2023-02-14 18:21:18 +00:00
2023-02-17 02:21:07 +00:00
for i := range dest {
2023-02-18 02:16:11 +00:00
switch r.stmt.ColumnType(i) {
2023-02-17 02:21:07 +00:00
case sqlite3.INTEGER:
2023-02-18 02:16:11 +00:00
dest[i] = r.stmt.ColumnInt64(i)
2023-02-17 02:21:07 +00:00
case sqlite3.FLOAT:
2023-02-18 02:16:11 +00:00
dest[i] = r.stmt.ColumnFloat(i)
2023-02-17 02:21:07 +00:00
case sqlite3.TEXT:
2023-02-18 02:16:11 +00:00
dest[i] = maybeDate(r.stmt.ColumnText(i))
2023-02-17 02:21:07 +00:00
case sqlite3.BLOB:
2023-02-17 12:30:07 +00:00
buf, _ := dest[i].([]byte)
2023-02-18 02:16:11 +00:00
dest[i] = r.stmt.ColumnBlob(i, buf)
2023-02-17 10:40:43 +00:00
case sqlite3.NULL:
2023-02-17 12:30:07 +00:00
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
} else {
dest[i] = nil
}
2023-02-17 10:40:43 +00:00
default:
panic(assertErr)
2023-02-17 02:21:07 +00:00
}
}
2023-02-14 18:21:18 +00:00
2023-02-18 02:16:11 +00:00
return r.stmt.Err()
2023-02-17 02:21:07 +00:00
}