diff --git a/README.md b/README.md index 1da5d16..8f0b869 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ and uses [wazero](https://wazero.io/) to provide `cgo`-free SQLite bindings. simplifies [incremental BLOB I/O](https://sqlite.org/c3ref/blob_open.html). - [`github.com/ncruces/go-sqlite3/ext/csv`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/csv) reads [comma-separated values](https://sqlite.org/csv.html). +- [`github.com/ncruces/go-sqlite3/ext/fileio`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/fileio) + reads and writes files. - [`github.com/ncruces/go-sqlite3/ext/lines`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/lines) reads files [line-by-line](https://github.com/asg017/sqlite-lines). - [`github.com/ncruces/go-sqlite3/ext/pivot`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/pivot) diff --git a/ext/array/array.go b/ext/array/array.go index a43c24d..f10c508 100644 --- a/ext/array/array.go +++ b/ext/array/array.go @@ -131,5 +131,5 @@ func indexable(v reflect.Value) (reflect.Value, error) { return v, nil } } - return v, fmt.Errorf("array: unsupported argument:%.0w %v", sqlite3.MISMATCH, v.Type()) + return v, fmt.Errorf("array: unsupported argument:%.0w %v", sqlite3.MISMATCH, v) } diff --git a/ext/array/array_test.go b/ext/array/array_test.go index 51b406d..a4cb556 100644 --- a/ext/array/array_test.go +++ b/ext/array/array_test.go @@ -92,3 +92,29 @@ func Test_cursor_Column(t *testing.T) { log.Fatal(err) } } + +func Test_array_errors(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + array.Register(db) + + err = db.Exec(`SELECT * FROM array()`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + err = db.Exec(`SELECT * FROM array(?)`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } +} diff --git a/ext/fileio/fileio.go b/ext/fileio/fileio.go index 311dc9d..78b0330 100644 --- a/ext/fileio/fileio.go +++ b/ext/fileio/fileio.go @@ -21,16 +21,16 @@ func Register(db *sqlite3.Conn) { // Register registers SQL functions readfile, lsmode, // and the eponymous virtual table fsdir; // fs will be used to read files and list directories. -func RegisterFS(db *sqlite3.Conn, fs fs.FS) { +func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { db.CreateFunction("lsmode", 1, 0, lsmode) - db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fs)) - if fs == nil { + db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)) + if fsys == nil { db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile) } sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, module, schema, table string, arg ...string) (fsdir, error) { err := db.DeclareVtab(`CREATE TABLE x(name,mode,mtime,data,path HIDDEN,dir HIDDEN)`) db.VtabConfig(sqlite3.VTAB_DIRECTONLY) - return fsdir{fs}, err + return fsdir{fsys}, err }) } @@ -38,13 +38,13 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) { ctx.ResultText(fs.FileMode(arg[0].Int()).String()) } -func readfile(f fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) { +func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) { return func(ctx sqlite3.Context, arg ...sqlite3.Value) { var err error var data []byte - if f != nil { - data, err = fs.ReadFile(f, arg[0].Text()) + if fsys != nil { + data, err = fs.ReadFile(fsys, arg[0].Text()) } else { data, err = os.ReadFile(arg[0].Text()) } diff --git a/ext/fileio/fileio_test.go b/ext/fileio/fileio_test.go index 9b50049..2b777a5 100644 --- a/ext/fileio/fileio_test.go +++ b/ext/fileio/fileio_test.go @@ -1,6 +1,9 @@ package fileio_test import ( + "bytes" + "database/sql" + "io/fs" "os" "testing" @@ -44,3 +47,34 @@ func Test_lsmode(t *testing.T) { t.Logf("got %s", mode) } } + +func Test_readfile(t *testing.T) { + t.Parallel() + + for _, fsys := range []fs.FS{nil, os.DirFS(".")} { + t.Run("", func(t *testing.T) { + db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { + fileio.RegisterFS(c, fsys) + return nil + }) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + rows, err := db.Query(`SELECT readfile('fileio_test.go')`) + if err != nil { + t.Fatal(err) + } + + if rows.Next() { + var data sql.RawBytes + rows.Scan(&data) + + if !bytes.HasPrefix(data, []byte("package fileio_test")) { + t.Errorf("got %s", data[:min(64, len(data))]) + } + } + }) + } +} diff --git a/ext/fileio/fsdir.go b/ext/fileio/fsdir.go index d8798fc..bd78907 100644 --- a/ext/fileio/fsdir.go +++ b/ext/fileio/fsdir.go @@ -1,7 +1,6 @@ package fileio import ( - "fmt" "io/fs" "os" "path" @@ -11,7 +10,7 @@ import ( "github.com/ncruces/go-sqlite3" ) -type fsdir struct{ fs.FS } +type fsdir struct{ fsys fs.FS } func (d fsdir) BestIndex(idx *sqlite3.IndexInfo) error { var root, base bool @@ -37,21 +36,23 @@ func (d fsdir) BestIndex(idx *sqlite3.IndexInfo) error { base = true } } - if root { - idx.EstimatedCost = 100 + if !root { + return sqlite3.CONSTRAINT } if base { idx.EstimatedCost = 10 + } else { + idx.EstimatedCost = 100 } return nil } func (d fsdir) Open() (sqlite3.VTabCursor, error) { - return &cursor{fs: d.FS}, nil + return &cursor{fsys: d.fsys}, nil } type cursor struct { - fs fs.FS + fsys fs.FS base string rowID int64 eof bool @@ -81,14 +82,11 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { if err := c.Close(); err != nil { return err } - if len(arg) == 0 { - return fmt.Errorf("fsdir: wrong number of arguments") - } root := arg[0].Text() if len(arg) > 1 { base := arg[1].Text() - if c.fs != nil { + if c.fsys != nil { root = path.Join(base, root) base = path.Clean(base) + "/" } else { @@ -147,8 +145,8 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error { case typ.IsRegular(): var data []byte var err error - if c.fs != nil { - data, err = fs.ReadFile(c.fs, c.curr.path) + if c.fsys != nil { + data, err = fs.ReadFile(c.fsys, c.curr.path) } else { data, err = os.ReadFile(c.curr.path) } @@ -157,7 +155,7 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error { } ctx.ResultBlob(data) - case typ&fs.ModeSymlink != 0 && c.fs == nil: + case typ&fs.ModeSymlink != 0 && c.fsys == nil: t, err := os.Readlink(c.curr.path) if err != nil { return err @@ -169,26 +167,12 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error { } func (c *cursor) WalkDir(path string) { - var err error + defer close(c.next) - defer func() { - if p := recover(); p != nil { - if perr, ok := p.(error); ok { - err = fmt.Errorf("panic: %w", perr) - } else { - err = fmt.Errorf("panic: %v", p) - } - } - if err != nil { - c.next <- entry{err: err} - } - close(c.next) - }() - - if c.fs != nil { - err = fs.WalkDir(c.fs, path, c.WalkDirFunc) + if c.fsys != nil { + fs.WalkDir(c.fsys, path, c.WalkDirFunc) } else { - err = filepath.WalkDir(path, c.WalkDirFunc) + filepath.WalkDir(path, c.WalkDirFunc) } } diff --git a/ext/fileio/fsdir_test.go b/ext/fileio/fsdir_test.go index 5e37bf7..1df5623 100644 --- a/ext/fileio/fsdir_test.go +++ b/ext/fileio/fsdir_test.go @@ -1,7 +1,10 @@ package fileio_test import ( + "bytes" + "database/sql" "io/fs" + "os" "testing" "time" @@ -14,34 +17,62 @@ import ( func Test_fsdir(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - fileio.Register(c) - return nil - }) + for _, fsys := range []fs.FS{nil, os.DirFS(".")} { + t.Run("", func(t *testing.T) { + db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { + fileio.RegisterFS(c, fsys) + return nil + }) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + rows, err := db.Query(`SELECT * FROM fsdir('.', '.') LIMIT 4`) + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + var name string + var mode fs.FileMode + var mtime time.Time + var data sql.RawBytes + err := rows.Scan(&name, &mode, sqlite3.TimeFormatUnixFrac.Scanner(&mtime), &data) + if err != nil { + t.Fatal(err) + } + if mode.Perm() == 0 { + t.Errorf("got: %v", mode) + } + if mtime.Before(time.Unix(0, 0)) { + t.Errorf("got: %v", mtime) + } + if name == "fsdir_test.go" { + if !bytes.HasPrefix(data, []byte("package fileio_test")) { + t.Errorf("got: %s", data[:min(64, len(data))]) + } + } + } + }) + } +} + +func Test_fsdir_errors(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") if err != nil { t.Fatal(err) } defer db.Close() - rows, err := db.Query(`SELECT name, mode, mtime FROM fsdir('.')`) - if err != nil { - t.Fatal(err) - } + fileio.Register(db) - for rows.Next() { - var name string - var mode fs.FileMode - var mtime time.Time - err := rows.Scan(&name, &mode, sqlite3.TimeFormatUnixFrac.Scanner(&mtime)) - if err != nil { - t.Fatal(err) - } - if mode.Perm() == 0 { - t.Errorf("mode %v", mode) - } - if mtime.Before(time.Unix(0, 0)) { - t.Errorf("mtime %v", mtime) - } - t.Log(name) + err = db.Exec(`SELECT name FROM fsdir()`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) } } diff --git a/ext/fileio/write_test.go b/ext/fileio/write_test.go new file mode 100644 index 0000000..f39d304 --- /dev/null +++ b/ext/fileio/write_test.go @@ -0,0 +1,115 @@ +package fileio + +import ( + "database/sql" + "io/fs" + "path/filepath" + "testing" + "time" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) + +func Test_writefile(t *testing.T) { + t.Parallel() + + db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { + Register(c) + return nil + }) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + dir := t.TempDir() + link := filepath.Join(dir, "link") + file := filepath.Join(dir, "test.txt") + nest := filepath.Join(dir, "tmp", "test.txt") + sock := filepath.Join(dir, "sock") + twosday := time.Date(2022, 2, 22, 22, 22, 22, 0, time.UTC) + + _, err = db.Exec(`SELECT writefile(?, 'Hello world!')`, file) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec(`SELECT writefile(?, ?, ?)`, link, "test.txt", fs.ModeSymlink) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec(`SELECT writefile(?, ?, ?, ?)`, dir, nil, 0040700, twosday.Unix()) + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query(`SELECT * FROM fsdir('.', ?)`, dir) + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + var name string + var mode fs.FileMode + var mtime time.Time + var data sql.NullString + err := rows.Scan(&name, &mode, sqlite3.TimeFormatUnixFrac.Scanner(&mtime), &data) + if err != nil { + t.Fatal(err) + } + if mode.IsDir() && mtime != twosday { + t.Errorf("got: %v", mtime) + } + if mode.IsRegular() && data.String != "Hello world!" { + t.Errorf("got: %v", data) + } + if mode&fs.ModeSymlink != 0 && data.String != "test.txt" { + t.Errorf("got: %v", data) + } + } + + _, err = db.Exec(`SELECT writefile(?, 'Hello world!')`, nest) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec(`SELECT writefile(?, ?, ?)`, sock, nil, fs.ModeSocket) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } + + _, err = db.Exec(`SELECT writefile()`) + if err == nil { + t.Fatal("want error") + } else { + t.Log(err) + } +} + +func Test_fixMode(t *testing.T) { + tests := []struct { + mode fs.FileMode + want fs.FileMode + }{ + {0010754, 0754 | fs.ModeNamedPipe}, + {0020754, 0754 | fs.ModeCharDevice | fs.ModeDevice}, + {0040754, 0754 | fs.ModeDir}, + {0060754, 0754 | fs.ModeDevice}, + {0100754, 0754}, + {0120754, 0754 | fs.ModeSymlink}, + {0140754, 0754 | fs.ModeSocket}, + {0170754, 0754 | fs.ModeIrregular}, + } + for _, tt := range tests { + t.Run(tt.mode.String(), func(t *testing.T) { + if got := fixMode(tt.mode); got != tt.want { + t.Errorf("fixMode() = %o, want %o", got, tt.want) + } + }) + } +}