mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Minimal database/sql driver.
This commit is contained in:
@@ -14,5 +14,7 @@ Roadmap:
|
||||
- [x] `:memory:` databases
|
||||
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
|
||||
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
|
||||
- [x] come up with a simple, nice API, enough for simple queries
|
||||
- [x] file locking, compatible with SQLite on Windows/Unix
|
||||
- [x] design a simple, nice API, enough for simple use cases
|
||||
- [x] minimal `database/sql` driver
|
||||
- [x] file locking, compatible with SQLite on Windows/Unix
|
||||
- [ ] shared memory, compatible with SQLite on Windows/Unix
|
||||
16
conn.go
16
conn.go
@@ -205,6 +205,22 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) LastInsertRowID() uint64 {
|
||||
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r[0]
|
||||
}
|
||||
|
||||
func (c *Conn) Changes() uint64 {
|
||||
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r[0]
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
|
||||
155
driver/driver.go
155
driver/driver.go
@@ -1,11 +1,12 @@
|
||||
//go:build todo
|
||||
|
||||
// Package driver provides a database/sql driver for SQLite.
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
@@ -24,21 +25,153 @@ func (sqlite) Open(name string) (driver.Conn, error) {
|
||||
return conn{c}, nil
|
||||
}
|
||||
|
||||
type conn struct{ *sqlite3.Conn }
|
||||
type stmt struct{ *sqlite3.Stmt }
|
||||
type conn struct{ conn *sqlite3.Conn }
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error)
|
||||
var (
|
||||
_ driver.Validator = conn{}
|
||||
_ driver.SessionResetter = conn{}
|
||||
)
|
||||
|
||||
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
s, _, err := c.Conn.Prepare(query)
|
||||
func (c conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c conn) IsValid() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c conn) ResetSession(ctx context.Context) error {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
err := c.conn.Exec(`BEGIN`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stmt{s}, nil
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int
|
||||
func (c conn) Commit() error {
|
||||
err := c.conn.Exec(`COMMIT`)
|
||||
if err != nil {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error)
|
||||
func (c conn) Rollback() error {
|
||||
return c.conn.Exec(`ROLLBACK`)
|
||||
}
|
||||
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error)
|
||||
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
s, _, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stmt{s, c.conn}, nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
return s.stmt.BindCount()
|
||||
}
|
||||
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
_, err := s.Query(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
int64(s.conn.LastInsertRowID()),
|
||||
int64(s.conn.Changes()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
var err error
|
||||
for i, arg := range args {
|
||||
switch a := arg.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(i+1, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(i+1, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(i+1, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(i+1, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(i+1, a)
|
||||
case time.Time:
|
||||
err = s.stmt.BindText(i+1, a.Format(time.RFC3339Nano))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rows{s.stmt}, nil
|
||||
}
|
||||
|
||||
type result struct{ lastInsertId, rowsAffected int64 }
|
||||
|
||||
func (r result) LastInsertId() (int64, error) {
|
||||
return r.lastInsertId, nil
|
||||
}
|
||||
|
||||
func (r result) RowsAffected() (int64, error) {
|
||||
return r.rowsAffected, nil
|
||||
}
|
||||
|
||||
type rows struct{ s *sqlite3.Stmt }
|
||||
|
||||
func (r rows) Close() error {
|
||||
return r.s.Reset()
|
||||
}
|
||||
|
||||
func (r rows) Columns() []string {
|
||||
count := r.s.ColumnCount()
|
||||
columns := make([]string, count)
|
||||
for i := range columns {
|
||||
columns[i] = r.s.ColumnName(i)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (r rows) Next(dest []driver.Value) error {
|
||||
if !r.s.Step() {
|
||||
err := r.s.Err()
|
||||
if err == nil {
|
||||
return io.EOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch r.s.ColumnType(i) {
|
||||
case sqlite3.INTEGER:
|
||||
dest[i] = r.s.ColumnInt64(i)
|
||||
case sqlite3.FLOAT:
|
||||
dest[i] = r.s.ColumnFloat(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = r.s.ColumnText(i)
|
||||
case sqlite3.BLOB:
|
||||
dest[i] = r.s.ColumnBlob(i, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return r.s.Err()
|
||||
}
|
||||
|
||||
153
driver/example_test.go
Normal file
153
driver/example_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package driver_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
func Example() {
|
||||
// Adapted from: https://go.dev/doc/tutorial/database-access
|
||||
|
||||
// Get a database handle.
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", "./recordings.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
defer os.Remove("./recordings.db")
|
||||
|
||||
err = setupDatabase()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
albums, err := albumsByArtist("John Coltrane")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Albums found: %v\n", albums)
|
||||
|
||||
// Hard-code ID 2 here to test the query.
|
||||
alb, err := albumByID(2)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Album found: %v\n", alb)
|
||||
|
||||
albID, err := addAlbum(Album{
|
||||
Title: "The Modern Sound of Betty Carter",
|
||||
Artist: "Betty Carter",
|
||||
Price: 49.99,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("ID of added album: %v\n", albID)
|
||||
|
||||
// Output:
|
||||
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
|
||||
// Album found: {2 Giant Steps John Coltrane 63.99}
|
||||
// ID of added album: 5
|
||||
}
|
||||
|
||||
type Album struct {
|
||||
ID int64
|
||||
Title string
|
||||
Artist string
|
||||
Price float32
|
||||
}
|
||||
|
||||
func setupDatabase() error {
|
||||
_, err := db.Exec(`DROP TABLE IF EXISTS album`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE album (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title VARCHAR(128) NOT NULL,
|
||||
artist VARCHAR(255) NOT NULL,
|
||||
price DECIMAL(5,2) NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO album
|
||||
(title, artist, price)
|
||||
VALUES
|
||||
('Blue Train', 'John Coltrane', 56.99),
|
||||
('Giant Steps', 'John Coltrane', 63.99),
|
||||
('Jeru', 'Gerry Mulligan', 17.99),
|
||||
('Sarah Vaughan', 'Sarah Vaughan', 34.98)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// albumsByArtist queries for albums that have the specified artist name.
|
||||
func albumsByArtist(name string) ([]Album, error) {
|
||||
// An albums slice to hold data from returned rows.
|
||||
var albums []Album
|
||||
|
||||
rows, err := db.Query("SELECT * FROM album WHERE artist = ?", name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("albumsByArtist %q: %v", name, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
// Loop through rows, using Scan to assign column data to struct fields.
|
||||
for rows.Next() {
|
||||
var alb Album
|
||||
if err := rows.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
|
||||
return nil, fmt.Errorf("albumsByArtist %q: %v", name, err)
|
||||
}
|
||||
albums = append(albums, alb)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("albumsByArtist %q: %v", name, err)
|
||||
}
|
||||
return albums, nil
|
||||
}
|
||||
|
||||
// albumByID queries for the album with the specified ID.
|
||||
func albumByID(id int64) (Album, error) {
|
||||
// An album to hold data from the returned row.
|
||||
var alb Album
|
||||
|
||||
row := db.QueryRow("SELECT * FROM album WHERE id = ?", id)
|
||||
if err := row.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return alb, fmt.Errorf("albumsById %d: no such album", id)
|
||||
}
|
||||
return alb, fmt.Errorf("albumsById %d: %v", id, err)
|
||||
}
|
||||
return alb, nil
|
||||
}
|
||||
|
||||
// addAlbum adds the specified album to the database,
|
||||
// returning the album ID of the new entry
|
||||
func addAlbum(alb Album) (int64, error) {
|
||||
result, err := db.Exec("INSERT INTO album (title, artist, price) VALUES (?, ?, ?)", alb.Title, alb.Artist, alb.Price)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("addAlbum: %v", err)
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("addAlbum: %v", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
30
stmt.go
30
stmt.go
@@ -199,6 +199,36 @@ func (s *Stmt) BindNull(param int) error {
|
||||
return s.c.error(r[0])
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_count.html
|
||||
func (s *Stmt) ColumnCount() int {
|
||||
r, err := s.c.api.columnCount.Call(s.c.ctx,
|
||||
uint64(s.handle))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(r[0])
|
||||
}
|
||||
|
||||
// ColumnName returns the name of the result column.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/column_name.html
|
||||
func (s *Stmt) ColumnName(col int) string {
|
||||
r, err := s.c.api.columnName.Call(s.c.ctx,
|
||||
uint64(s.handle), uint64(col))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
return s.c.mem.readString(ptr, 512)
|
||||
}
|
||||
|
||||
// ColumnType returns the initial [Datatype] of the result column.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user