diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 144c1d7..b554145 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -12,21 +12,21 @@ import ( "fmt" "io" "io/fs" - "os" "strings" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" ) // Register registers the CSV virtual table. -// If a filename is specified, `os.Open` is used to open the file. +// If a filename is specified, [os.Open] is used to open the file. func Register(db *sqlite3.Conn) { - RegisterOpen(db, osfs{}) + RegisterFS(db, util.OSFS{}) } -// RegisterOpen registers the CSV virtual table. +// RegisterFS registers the CSV virtual table. // If a filename is specified, fsys is used to open the file. -func RegisterOpen(db *sqlite3.Conn, fsys fs.FS) { +func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { var ( filename string @@ -108,12 +108,6 @@ func RegisterOpen(db *sqlite3.Conn, fsys fs.FS) { sqlite3.CreateModule(db, "csv", declare, declare) } -type osfs struct{} - -func (osfs) Open(name string) (fs.File, error) { - return os.Open(name) -} - type table struct { fsys fs.FS name string diff --git a/ext/lines/lines.go b/ext/lines/lines.go index d0a1435..17d611b 100644 --- a/ext/lines/lines.go +++ b/ext/lines/lines.go @@ -15,30 +15,42 @@ import ( "bytes" "fmt" "io" - "os" + "io/fs" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" ) // Register registers the lines and lines_read virtual tables. // The lines virtual table reads from a database blob or text. // The lines_read virtual table reads from a file or an [io.Reader]. +// If a filename is specified, [os.Open] is used to open the file. func Register(db *sqlite3.Conn) { + RegisterFS(db, util.OSFS{}) +} + +// RegisterFS registers the lines and lines_read virtual tables. +// The lines virtual table reads from a database blob or text. +// The lines_read virtual table reads from a file or an [io.Reader]. +// If a filename is specified, fsys is used to open the file. +func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { sqlite3.CreateModule[lines](db, "lines", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { err := db.DeclareVtab(`CREATE TABLE x(line TEXT, data HIDDEN)`) db.VtabConfig(sqlite3.VTAB_INNOCUOUS) - return false, err + return lines{}, err }) sqlite3.CreateModule[lines](db, "lines_read", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { err := db.DeclareVtab(`CREATE TABLE x(line TEXT, data HIDDEN)`) db.VtabConfig(sqlite3.VTAB_DIRECTONLY) - return true, err + return lines{fsys}, err }) } -type lines bool +type lines struct { + fsys fs.FS +} func (l lines) BestIndex(idx *sqlite3.IndexInfo) error { for i, cst := range idx.Constraint { @@ -56,8 +68,8 @@ func (l lines) BestIndex(idx *sqlite3.IndexInfo) error { } func (l lines) Open() (sqlite3.VTabCursor, error) { - if l { - return &reader{}, nil + if l.fsys != nil { + return &reader{fsys: l.fsys}, nil } else { return &buffer{}, nil } @@ -85,6 +97,7 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error { } type reader struct { + fsys fs.FS reader *bufio.Reader closer io.Closer cursor @@ -111,7 +124,7 @@ func (c *reader) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { r = p } case sqlite3.TEXT: - f, err := os.Open(arg[0].Text()) + f, err := c.fsys.Open(arg[0].Text()) if err != nil { return err } diff --git a/internal/util/osfs.go b/internal/util/osfs.go new file mode 100644 index 0000000..5bce1ce --- /dev/null +++ b/internal/util/osfs.go @@ -0,0 +1,24 @@ +package util + +import ( + "io/fs" + "os" +) + +type OSFS struct{} + +func (OSFS) Open(name string) (fs.File, error) { + return os.Open(name) +} + +func (OSFS) Stat(name string) (fs.FileInfo, error) { + return os.Stat(name) +} + +func (OSFS) ReadDir(name string) ([]fs.DirEntry, error) { + return os.ReadDir(name) +} + +func (OSFS) ReadFile(name string) ([]byte, error) { + return os.ReadFile(name) +}