From 28cb558d106b6ce03f02f8fe6f40db76d40a6a7c Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 17 Feb 2023 02:21:07 +0000 Subject: [PATCH] Minimal database/sql driver. --- README.md | 6 +- conn.go | 16 +++++ driver/driver.go | 155 ++++++++++++++++++++++++++++++++++++++--- driver/example_test.go | 153 ++++++++++++++++++++++++++++++++++++++++ stmt.go | 30 ++++++++ 5 files changed, 347 insertions(+), 13 deletions(-) create mode 100644 driver/example_test.go diff --git a/README.md b/README.md index ccc93b7..3bca6e9 100644 --- a/README.md +++ b/README.md @@ -14,5 +14,7 @@ Roadmap: - [x] `:memory:` databases - [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go - branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly -- [x] come up with a simple, nice API, enough for simple queries -- [x] file locking, compatible with SQLite on Windows/Unix \ No newline at end of file +- [x] design a simple, nice API, enough for simple use cases +- [x] minimal `database/sql` driver +- [x] file locking, compatible with SQLite on Windows/Unix +- [ ] shared memory, compatible with SQLite on Windows/Unix \ No newline at end of file diff --git a/conn.go b/conn.go index afdb85c..95f425d 100644 --- a/conn.go +++ b/conn.go @@ -205,6 +205,22 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str return } +func (c *Conn) LastInsertRowID() uint64 { + r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle)) + if err != nil { + panic(err) + } + return r[0] +} + +func (c *Conn) Changes() uint64 { + r, err := c.api.changes.Call(c.ctx, uint64(c.handle)) + if err != nil { + panic(err) + } + return r[0] +} + func (c *Conn) error(rc uint64, sql ...string) error { if rc == _OK { return nil diff --git a/driver/driver.go b/driver/driver.go index 1578975..9848d81 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -1,11 +1,12 @@ -//go:build todo - // Package driver provides a database/sql driver for SQLite. package driver import ( + "context" "database/sql" "database/sql/driver" + "io" + "time" "github.com/ncruces/go-sqlite3" ) @@ -24,21 +25,153 @@ func (sqlite) Open(name string) (driver.Conn, error) { return conn{c}, nil } -type conn struct{ *sqlite3.Conn } -type stmt struct{ *sqlite3.Stmt } +type conn struct{ conn *sqlite3.Conn } -func (c conn) Begin() (driver.Tx, error) +var ( + _ driver.Validator = conn{} + _ driver.SessionResetter = conn{} +) -func (c conn) Prepare(query string) (driver.Stmt, error) { - s, _, err := c.Conn.Prepare(query) +func (c conn) Close() error { + return c.conn.Close() +} + +func (c conn) IsValid() bool { + return false +} + +func (c conn) ResetSession(ctx context.Context) error { + return driver.ErrBadConn +} + +func (c conn) Begin() (driver.Tx, error) { + err := c.conn.Exec(`BEGIN`) if err != nil { return nil, err } - return stmt{s}, nil + return c, nil } -func (s stmt) NumInput() int +func (c conn) Commit() error { + err := c.conn.Exec(`COMMIT`) + if err != nil { + c.Rollback() + } + return err +} -func (s stmt) Exec(args []driver.Value) (driver.Result, error) +func (c conn) Rollback() error { + return c.conn.Exec(`ROLLBACK`) +} -func (s stmt) Query(args []driver.Value) (driver.Rows, error) +func (c conn) Prepare(query string) (driver.Stmt, error) { + s, _, err := c.conn.Prepare(query) + if err != nil { + return nil, err + } + return stmt{s, c.conn}, nil +} + +type stmt struct { + stmt *sqlite3.Stmt + conn *sqlite3.Conn +} + +func (s stmt) Close() error { + return s.stmt.Close() +} + +func (s stmt) NumInput() int { + return s.stmt.BindCount() +} + +func (s stmt) Exec(args []driver.Value) (driver.Result, error) { + _, err := s.Query(args) + if err != nil { + return nil, err + } + + err = s.stmt.Exec() + if err != nil { + return nil, err + } + + return result{ + int64(s.conn.LastInsertRowID()), + int64(s.conn.Changes()), + }, nil +} + +func (s stmt) Query(args []driver.Value) (driver.Rows, error) { + var err error + for i, arg := range args { + switch a := arg.(type) { + case bool: + err = s.stmt.BindBool(i+1, a) + case int64: + err = s.stmt.BindInt64(i+1, a) + case float64: + err = s.stmt.BindFloat(i+1, a) + case string: + err = s.stmt.BindText(i+1, a) + case []byte: + err = s.stmt.BindBlob(i+1, a) + case time.Time: + err = s.stmt.BindText(i+1, a.Format(time.RFC3339Nano)) + } + if err != nil { + return nil, err + } + } + return rows{s.stmt}, nil +} + +type result struct{ lastInsertId, rowsAffected int64 } + +func (r result) LastInsertId() (int64, error) { + return r.lastInsertId, nil +} + +func (r result) RowsAffected() (int64, error) { + return r.rowsAffected, nil +} + +type rows struct{ s *sqlite3.Stmt } + +func (r rows) Close() error { + return r.s.Reset() +} + +func (r rows) Columns() []string { + count := r.s.ColumnCount() + columns := make([]string, count) + for i := range columns { + columns[i] = r.s.ColumnName(i) + } + return columns +} + +func (r rows) Next(dest []driver.Value) error { + if !r.s.Step() { + err := r.s.Err() + if err == nil { + return io.EOF + } + return err + } + + for i := range dest { + switch r.s.ColumnType(i) { + case sqlite3.INTEGER: + dest[i] = r.s.ColumnInt64(i) + case sqlite3.FLOAT: + dest[i] = r.s.ColumnFloat(i) + case sqlite3.TEXT: + dest[i] = r.s.ColumnText(i) + case sqlite3.BLOB: + dest[i] = r.s.ColumnBlob(i, nil) + } + } + + return r.s.Err() +} diff --git a/driver/example_test.go b/driver/example_test.go new file mode 100644 index 0000000..35be1ba --- /dev/null +++ b/driver/example_test.go @@ -0,0 +1,153 @@ +package driver_test + +import ( + "database/sql" + "fmt" + "log" + "os" + + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) + +var db *sql.DB + +func Example() { + // Adapted from: https://go.dev/doc/tutorial/database-access + + // Get a database handle. + var err error + db, err = sql.Open("sqlite3", "./recordings.db") + if err != nil { + log.Fatal(err) + } + defer db.Close() + defer os.Remove("./recordings.db") + + err = setupDatabase() + if err != nil { + log.Fatal(err) + } + + albums, err := albumsByArtist("John Coltrane") + if err != nil { + log.Fatal(err) + } + fmt.Printf("Albums found: %v\n", albums) + + // Hard-code ID 2 here to test the query. + alb, err := albumByID(2) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Album found: %v\n", alb) + + albID, err := addAlbum(Album{ + Title: "The Modern Sound of Betty Carter", + Artist: "Betty Carter", + Price: 49.99, + }) + if err != nil { + log.Fatal(err) + } + fmt.Printf("ID of added album: %v\n", albID) + + // Output: + // Albums found: [{1 Blue Train John Coltrane 56.99} {2 Giant Steps John Coltrane 63.99}] + // Album found: {2 Giant Steps John Coltrane 63.99} + // ID of added album: 5 +} + +type Album struct { + ID int64 + Title string + Artist string + Price float32 +} + +func setupDatabase() error { + _, err := db.Exec(`DROP TABLE IF EXISTS album`) + if err != nil { + return err + } + + _, err = db.Exec(` + CREATE TABLE album ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title VARCHAR(128) NOT NULL, + artist VARCHAR(255) NOT NULL, + price DECIMAL(5,2) NOT NULL + ) + `) + if err != nil { + return err + } + + _, err = db.Exec(` + INSERT INTO album + (title, artist, price) + VALUES + ('Blue Train', 'John Coltrane', 56.99), + ('Giant Steps', 'John Coltrane', 63.99), + ('Jeru', 'Gerry Mulligan', 17.99), + ('Sarah Vaughan', 'Sarah Vaughan', 34.98) + `) + if err != nil { + return err + } + + return nil +} + +// albumsByArtist queries for albums that have the specified artist name. +func albumsByArtist(name string) ([]Album, error) { + // An albums slice to hold data from returned rows. + var albums []Album + + rows, err := db.Query("SELECT * FROM album WHERE artist = ?", name) + if err != nil { + return nil, fmt.Errorf("albumsByArtist %q: %v", name, err) + } + defer rows.Close() + // Loop through rows, using Scan to assign column data to struct fields. + for rows.Next() { + var alb Album + if err := rows.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil { + return nil, fmt.Errorf("albumsByArtist %q: %v", name, err) + } + albums = append(albums, alb) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("albumsByArtist %q: %v", name, err) + } + return albums, nil +} + +// albumByID queries for the album with the specified ID. +func albumByID(id int64) (Album, error) { + // An album to hold data from the returned row. + var alb Album + + row := db.QueryRow("SELECT * FROM album WHERE id = ?", id) + if err := row.Scan(&alb.ID, &alb.Title, &alb.Artist, &alb.Price); err != nil { + if err == sql.ErrNoRows { + return alb, fmt.Errorf("albumsById %d: no such album", id) + } + return alb, fmt.Errorf("albumsById %d: %v", id, err) + } + return alb, nil +} + +// addAlbum adds the specified album to the database, +// returning the album ID of the new entry +func addAlbum(alb Album) (int64, error) { + result, err := db.Exec("INSERT INTO album (title, artist, price) VALUES (?, ?, ?)", alb.Title, alb.Artist, alb.Price) + if err != nil { + return 0, fmt.Errorf("addAlbum: %v", err) + } + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("addAlbum: %v", err) + } + return id, nil +} diff --git a/stmt.go b/stmt.go index 98850dc..029ee1f 100644 --- a/stmt.go +++ b/stmt.go @@ -199,6 +199,36 @@ func (s *Stmt) BindNull(param int) error { return s.c.error(r[0]) } +// ColumnCount returns the number of columns in a result set. +// +// https://www.sqlite.org/c3ref/column_count.html +func (s *Stmt) ColumnCount() int { + r, err := s.c.api.columnCount.Call(s.c.ctx, + uint64(s.handle)) + if err != nil { + panic(err) + } + return int(r[0]) +} + +// ColumnName returns the name of the result column. +// The leftmost column of the result set has the index 0. +// +// https://www.sqlite.org/c3ref/column_name.html +func (s *Stmt) ColumnName(col int) string { + r, err := s.c.api.columnName.Call(s.c.ctx, + uint64(s.handle), uint64(col)) + if err != nil { + panic(err) + } + + ptr := uint32(r[0]) + if ptr == 0 { + return "" + } + return s.c.mem.readString(ptr, 512) +} + // ColumnType returns the initial [Datatype] of the result column. // The leftmost column of the result set has the index 0. //