Minimal database/sql driver.

This commit is contained in:
Nuno Cruces
2023-02-17 02:21:07 +00:00
parent 23806b0db1
commit 28cb558d10
5 changed files with 347 additions and 13 deletions

View File

@@ -14,5 +14,7 @@ Roadmap:
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] come up with a simple, nice API, enough for simple queries
- [x] file locking, compatible with SQLite on Windows/Unix
- [x] design a simple, nice API, enough for simple use cases
- [x] minimal `database/sql` driver
- [x] file locking, compatible with SQLite on Windows/Unix
- [ ] shared memory, compatible with SQLite on Windows/Unix

16
conn.go
View File

@@ -205,6 +205,22 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
return
}
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
func (c *Conn) error(rc uint64, sql ...string) error {
if rc == _OK {
return nil

View File

@@ -1,11 +1,12 @@
//go:build todo
// Package driver provides a database/sql driver for SQLite.
package driver
import (
"context"
"database/sql"
"database/sql/driver"
"io"
"time"
"github.com/ncruces/go-sqlite3"
)
@@ -24,21 +25,153 @@ func (sqlite) Open(name string) (driver.Conn, error) {
return conn{c}, nil
}
type conn struct{ *sqlite3.Conn }
type stmt struct{ *sqlite3.Stmt }
type conn struct{ conn *sqlite3.Conn }
func (c conn) Begin() (driver.Tx, error)
var (
_ driver.Validator = conn{}
_ driver.SessionResetter = conn{}
)
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, _, err := c.Conn.Prepare(query)
func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() bool {
return false
}
func (c conn) ResetSession(ctx context.Context) error {
return driver.ErrBadConn
}
func (c conn) Begin() (driver.Tx, error) {
err := c.conn.Exec(`BEGIN`)
if err != nil {
return nil, err
}
return stmt{s}, nil
return c, nil
}
func (s stmt) NumInput() int
func (c conn) Commit() error {
err := c.conn.Exec(`COMMIT`)
if err != nil {
c.Rollback()
}
return err
}
func (s stmt) Exec(args []driver.Value) (driver.Result, error)
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
}
func (s stmt) Query(args []driver.Value) (driver.Rows, error)
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, _, err := c.conn.Prepare(query)
if err != nil {
return nil, err
}
return stmt{s, c.conn}, nil
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
func (s stmt) Close() error {
return s.stmt.Close()
}
func (s stmt) NumInput() int {
return s.stmt.BindCount()
}
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
_, err := s.Query(args)
if err != nil {
return nil, err
}
err = s.stmt.Exec()
if err != nil {
return nil, err
}
return result{
int64(s.conn.LastInsertRowID()),
int64(s.conn.Changes()),
}, nil
}
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
var err error
for i, arg := range args {
switch a := arg.(type) {
case bool:
err = s.stmt.BindBool(i+1, a)
case int64:
err = s.stmt.BindInt64(i+1, a)
case float64:
err = s.stmt.BindFloat(i+1, a)
case string:
err = s.stmt.BindText(i+1, a)
case []byte:
err = s.stmt.BindBlob(i+1, a)
case time.Time:
err = s.stmt.BindText(i+1, a.Format(time.RFC3339Nano))
}
if err != nil {
return nil, err
}
}
return rows{s.stmt}, nil
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {
return r.lastInsertId, nil
}
func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type rows struct{ s *sqlite3.Stmt }
func (r rows) Close() error {
return r.s.Reset()
}
func (r rows) Columns() []string {
count := r.s.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.s.ColumnName(i)
}
return columns
}
func (r rows) Next(dest []driver.Value) error {
if !r.s.Step() {
err := r.s.Err()
if err == nil {
return io.EOF
}
return err
}
for i := range dest {
switch r.s.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.s.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.s.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = r.s.ColumnText(i)
case sqlite3.BLOB:
dest[i] = r.s.ColumnBlob(i, nil)
}
}
return r.s.Err()
}

153
driver/example_test.go Normal file
View File

@@ -0,0 +1,153 @@
package driver_test
import (
"database/sql"
"fmt"
"log"
"os"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
var db *sql.DB
func Example() {
// Adapted from: https://go.dev/doc/tutorial/database-access
// Get a database handle.
var err error
db, err = sql.Open("sqlite3", "./recordings.db")
if err != nil {
log.Fatal(err)
}
defer db.Close()
defer os.Remove("./recordings.db")
err = setupDatabase()
if err != nil {
log.Fatal(err)
}
albums, err := albumsByArtist("John Coltrane")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Albums found: %v\n", albums)
// Hard-code ID 2 here to test the query.
alb, err := albumByID(2)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Album found: %v\n", alb)
albID, err := addAlbum(Album{
Title: "The Modern Sound of Betty Carter",
Artist: "Betty Carter",
Price: 49.99,
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("ID of added album: %v\n", albID)
// Output:
// Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}]
// Album found: {2 Giant Steps John Coltrane 63.99}
// ID of added album: 5
}
type Album struct {
ID int64
Title string
Artist string
Price float32
}
func setupDatabase() error {
_, err := db.Exec(`DROP TABLE IF EXISTS album`)
if err != nil {
return err
}
_, err = db.Exec(`
CREATE TABLE album (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title VARCHAR(128) NOT NULL,
artist VARCHAR(255) NOT NULL,
price DECIMAL(5,2) NOT NULL
)
`)
if err != nil {
return err
}
_, err = db.Exec(`
INSERT INTO album
(title, artist, price)
VALUES
('Blue Train', 'John Coltrane', 56.99),
('Giant Steps', 'John Coltrane', 63.99),
('Jeru', 'Gerry Mulligan', 17.99),
('Sarah Vaughan', 'Sarah Vaughan', 34.98)
`)
if err != nil {
return err
}
return nil
}
// albumsByArtist queries for albums that have the specified artist name.
func albumsByArtist(name string) ([]Album, error) {
// An albums slice to hold data from returned rows.
var albums []Album
rows, err := db.Query("SELECT * FROM album WHERE artist = ?", name)
if err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %v", name, err)
}
defer rows.Close()
// Loop through rows, using Scan to assign column data to struct fields.
for rows.Next() {
var alb Album
if err := rows.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %v", name, err)
}
albums = append(albums, alb)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("albumsByArtist %q: %v", name, err)
}
return albums, nil
}
// albumByID queries for the album with the specified ID.
func albumByID(id int64) (Album, error) {
// An album to hold data from the returned row.
var alb Album
row := db.QueryRow("SELECT * FROM album WHERE id = ?", id)
if err := row.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil {
if err == sql.ErrNoRows {
return alb, fmt.Errorf("albumsById %d: no such album", id)
}
return alb, fmt.Errorf("albumsById %d: %v", id, err)
}
return alb, nil
}
// addAlbum adds the specified album to the database,
// returning the album ID of the new entry
func addAlbum(alb Album) (int64, error) {
result, err := db.Exec("INSERT INTO album (title, artist, price) VALUES (?, ?, ?)", alb.Title, alb.Artist, alb.Price)
if err != nil {
return 0, fmt.Errorf("addAlbum: %v", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("addAlbum: %v", err)
}
return id, nil
}

30
stmt.go
View File

@@ -199,6 +199,36 @@ func (s *Stmt) BindNull(param int) error {
return s.c.error(r[0])
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
r, err := s.c.api.columnCount.Call(s.c.ctx,
uint64(s.handle))
if err != nil {
panic(err)
}
return int(r[0])
}
// ColumnName returns the name of the result column.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
r, err := s.c.api.columnName.Call(s.c.ctx,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
ptr := uint32(r[0])
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, 512)
}
// ColumnType returns the initial [Datatype] of the result column.
// The leftmost column of the result set has the index 0.
//