// Mirror of https://github.com/ncruces/go-sqlite3.git
// (synced 2026-01-12 05:59:14 +00:00; 285 lines, 5.6 KiB, Go).

// Package driver provides a database/sql driver for SQLite.
|
|
package driver
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/ncruces/go-sqlite3"
|
|
)
|
|
|
|
func init() {
|
|
sql.Register("sqlite3", sqlite{})
|
|
}
|
|
|
|
// sqlite is the stateless driver value that is registered with
// database/sql; its Open method creates connections.
type sqlite struct{}
|
|
|
|
func (sqlite) Open(name string) (driver.Conn, error) {
|
|
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// If the database is not in WAL mode,
|
|
// use normal locking mode.
|
|
journal, err := pragma(c, "journal_mode")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if journal != "wal" {
|
|
pragma(c, "locking_mode=normal")
|
|
}
|
|
return conn{c}, nil
|
|
}
|
|
|
|
// conn wraps a *sqlite3.Conn. Its methods serve both as the driver
// connection and, via Begin, as the transaction handle.
type conn struct {
	conn *sqlite3.Conn
}
|
|
|
|
var (
|
|
// Ensure these interfaces are implemented:
|
|
_ driver.Validator = conn{}
|
|
_ driver.ExecerContext = conn{}
|
|
// _ driver.ConnBeginTx = conn{}
|
|
// _ driver.SessionResetter = conn{}
|
|
)
|
|
|
|
func (c conn) Close() error {
|
|
return c.conn.Close()
|
|
}
|
|
|
|
func (c conn) IsValid() bool {
|
|
// Pool only normal locking mode connections.
|
|
mode, _ := pragma(c.conn, "locking_mode")
|
|
return mode == "normal"
|
|
}
|
|
|
|
func (c conn) Begin() (driver.Tx, error) {
|
|
err := c.conn.Exec(`BEGIN`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func (c conn) Commit() error {
|
|
err := c.conn.Exec(`COMMIT`)
|
|
if err != nil {
|
|
c.Rollback()
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (c conn) Rollback() error {
|
|
return c.conn.Exec(`ROLLBACK`)
|
|
}
|
|
|
|
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
|
s, tail, err := c.conn.Prepare(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if tail != "" {
|
|
// Check if the tail contains any SQL.
|
|
s, _, err := c.conn.Prepare(tail)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if s != nil {
|
|
s.Close()
|
|
return nil, tailErr
|
|
}
|
|
}
|
|
return stmt{s, c.conn}, nil
|
|
}
|
|
|
|
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
|
if len(args) != 0 {
|
|
// Slow path.
|
|
return nil, driver.ErrSkip
|
|
}
|
|
|
|
ch := c.conn.SetInterrupt(ctx.Done())
|
|
defer c.conn.SetInterrupt(ch)
|
|
|
|
err := c.conn.Exec(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result{
|
|
int64(c.conn.LastInsertRowID()),
|
|
int64(c.conn.Changes()),
|
|
}, nil
|
|
}
|
|
|
|
func pragma(c *sqlite3.Conn, pragma string) (string, error) {
|
|
stmt, _, err := c.Prepare(`PRAGMA ` + pragma)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer stmt.Close()
|
|
if stmt.Step() {
|
|
return stmt.ColumnText(0), nil
|
|
}
|
|
return "", stmt.Err()
|
|
}
|
|
|
|
type stmt struct {
|
|
stmt *sqlite3.Stmt
|
|
conn *sqlite3.Conn
|
|
}
|
|
|
|
var (
|
|
// Ensure these interfaces are implemented:
|
|
_ driver.StmtExecContext = stmt{}
|
|
_ driver.StmtQueryContext = stmt{}
|
|
)
|
|
|
|
func (s stmt) Close() error {
|
|
return s.stmt.Close()
|
|
}
|
|
|
|
func (s stmt) NumInput() int {
|
|
return s.stmt.BindCount()
|
|
}
|
|
|
|
// Deprecated: use ExecContext instead.
|
|
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
return s.ExecContext(context.Background(), namedValues(args))
|
|
}
|
|
|
|
// Deprecated: use QueryContext instead.
|
|
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
return s.QueryContext(context.Background(), namedValues(args))
|
|
}
|
|
|
|
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
|
_, err := s.QueryContext(ctx, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = s.stmt.Exec()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result{
|
|
int64(s.conn.LastInsertRowID()),
|
|
int64(s.conn.Changes()),
|
|
}, nil
|
|
}
|
|
|
|
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
err := s.stmt.ClearBindings()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var ids [3]int
|
|
for _, arg := range args {
|
|
ids := ids[:0]
|
|
if arg.Name == "" {
|
|
ids = append(ids, arg.Ordinal)
|
|
} else {
|
|
for _, prefix := range []string{":", "@", "$"} {
|
|
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
|
|
ids = append(ids, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, id := range ids {
|
|
switch a := arg.Value.(type) {
|
|
case bool:
|
|
err = s.stmt.BindBool(id, a)
|
|
case int64:
|
|
err = s.stmt.BindInt64(id, a)
|
|
case float64:
|
|
err = s.stmt.BindFloat(id, a)
|
|
case string:
|
|
err = s.stmt.BindText(id, a)
|
|
case []byte:
|
|
err = s.stmt.BindBlob(id, a)
|
|
case time.Time:
|
|
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
|
|
case nil:
|
|
err = s.stmt.BindNull(id)
|
|
default:
|
|
panic(assertErr)
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return rows{ctx, s.stmt, s.conn}, nil
|
|
}
|
|
|
|
// result holds the two values reported after executing a statement.
// Field order matters: values are built with positional literals.
type result struct {
	lastInsertId int64
	rowsAffected int64
}
|
|
|
|
func (r result) LastInsertId() (int64, error) {
|
|
return r.lastInsertId, nil
|
|
}
|
|
|
|
func (r result) RowsAffected() (int64, error) {
|
|
return r.rowsAffected, nil
|
|
}
|
|
|
|
type rows struct {
|
|
ctx context.Context
|
|
stmt *sqlite3.Stmt
|
|
conn *sqlite3.Conn
|
|
}
|
|
|
|
func (r rows) Close() error {
|
|
return r.stmt.Reset()
|
|
}
|
|
|
|
func (r rows) Columns() []string {
|
|
count := r.stmt.ColumnCount()
|
|
columns := make([]string, count)
|
|
for i := range columns {
|
|
columns[i] = r.stmt.ColumnName(i)
|
|
}
|
|
return columns
|
|
}
|
|
|
|
func (r rows) Next(dest []driver.Value) error {
|
|
ch := r.conn.SetInterrupt(r.ctx.Done())
|
|
defer r.conn.SetInterrupt(ch)
|
|
|
|
if !r.stmt.Step() {
|
|
if err := r.stmt.Err(); err != nil {
|
|
return err
|
|
}
|
|
return io.EOF
|
|
}
|
|
|
|
for i := range dest {
|
|
switch r.stmt.ColumnType(i) {
|
|
case sqlite3.INTEGER:
|
|
dest[i] = r.stmt.ColumnInt64(i)
|
|
case sqlite3.FLOAT:
|
|
dest[i] = r.stmt.ColumnFloat(i)
|
|
case sqlite3.TEXT:
|
|
dest[i] = maybeDate(r.stmt.ColumnText(i))
|
|
case sqlite3.BLOB:
|
|
buf, _ := dest[i].([]byte)
|
|
dest[i] = r.stmt.ColumnBlob(i, buf)
|
|
case sqlite3.NULL:
|
|
if buf, ok := dest[i].([]byte); ok {
|
|
dest[i] = buf[0:0]
|
|
} else {
|
|
dest[i] = nil
|
|
}
|
|
default:
|
|
panic(assertErr)
|
|
}
|
|
}
|
|
|
|
return r.stmt.Err()
|
|
}
|