diff --git a/ext/csv/arg.go b/ext/csv/arg.go new file mode 100644 index 0000000..247dfef --- /dev/null +++ b/ext/csv/arg.go @@ -0,0 +1,36 @@ +package csv + +import ( + "fmt" + "strconv" + + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/vtabutil" +) + +func uintArg(key, val string) (int, error) { + i, err := strconv.ParseUint(val, 10, 15) + if err != nil { + return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val) + } + return int(i), nil +} + +func boolArg(key, val string) (bool, error) { + if val == "" { + return true, nil + } + b, ok := util.ParseBool(val) + if ok { + return b, nil + } + return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val) +} + +func runeArg(key, val string) (rune, error) { + r, _, tail, err := strconv.UnquoteChar(vtabutil.Unquote(val), 0) + if tail != "" || err != nil { + return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val) + } + return r, nil +} diff --git a/ext/csv/params_test.go b/ext/csv/arg_test.go similarity index 65% rename from ext/csv/params_test.go rename to ext/csv/arg_test.go index d7f0584..50d4c47 100644 --- a/ext/csv/params_test.go +++ b/ext/csv/arg_test.go @@ -1,8 +1,12 @@ package csv -import "testing" +import ( + "testing" -func Test_uintParam(t *testing.T) { + "github.com/ncruces/go-sqlite3/util/vtabutil" +) + +func Test_uintArg(t *testing.T) { t.Parallel() tests := []struct { @@ -20,22 +24,22 @@ func Test_uintParam(t *testing.T) { } for _, tt := range tests { t.Run(tt.arg, func(t *testing.T) { - key, val := getParam(tt.arg) + key, val := vtabutil.NamedArg(tt.arg) if key != tt.key { - t.Errorf("getParam() %v, want err %v", key, tt.key) + t.Errorf("NamedArg() %v, want err %v", key, tt.key) } - got, err := uintParam(key, val) + got, err := uintArg(key, val) if (err != nil) != tt.err { - t.Fatalf("uintParam() error = %v, want err %v", err, tt.err) + t.Fatalf("uintArg() error = %v, want err %v", err, tt.err) } if got != tt.val { - t.Errorf("uintParam() = %v, want %v", got, tt.val) + t.Errorf("uintArg() = %v, want %v", got, tt.val) } }) } } -func Test_boolParam(t *testing.T) { +func Test_boolArg(t *testing.T) { tests := []struct { arg string key string @@ -56,22 +60,22 @@ func Test_boolParam(t *testing.T) { } for _, tt := range tests { t.Run(tt.arg, func(t *testing.T) { - key, val := getParam(tt.arg) + key, val := vtabutil.NamedArg(tt.arg) if key != tt.key { - t.Errorf("getParam() %v, want err %v", key, tt.key) + t.Errorf("NamedArg() %v, want err %v", key, tt.key) } - got, err := boolParam(key, val) + got, err := boolArg(key, val) if (err != nil) != tt.err { - t.Fatalf("boolParam() error = %v, want err %v", err, tt.err) + t.Fatalf("boolArg() error = %v, want err %v", err, tt.err) } if got != tt.val { - t.Errorf("boolParam() = %v, want %v", got, tt.val) + t.Errorf("boolArg() = %v, want %v", got, tt.val) } }) } } -func Test_runeParam(t *testing.T) { +func Test_runeArg(t *testing.T) { tests := []struct { arg string key string @@ -88,16 +92,16 @@ func Test_runeParam(t *testing.T) { } for _, tt := range tests { t.Run(tt.arg, func(t *testing.T) { - key, val := getParam(tt.arg) + key, val := vtabutil.NamedArg(tt.arg) if key != tt.key { - t.Errorf("getParam() %v, want err %v", key, tt.key) + t.Errorf("NamedArg() %v, want err %v", key, tt.key) } - got, err := runeParam(key, val) + got, err := runeArg(key, val) if (err != nil) != tt.err { - t.Fatalf("runeParam() error = %v, want err %v", err, tt.err) + t.Fatalf("runeArg() error = %v, want err %v", err, tt.err) } if got != tt.val { - t.Errorf("runeParam() = %v, want %v", got, tt.val) + t.Errorf("runeArg() = %v, want %v", got, tt.val) } }) } diff --git a/ext/csv/csv.go b/ext/csv/csv.go index b554145..a8a368b 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -15,13 +15,14 @@ import ( "strings" "github.com/ncruces/go-sqlite3" - "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/fsutil" + "github.com/ncruces/go-sqlite3/util/vtabutil" ) // Register registers the CSV virtual table. // If a filename is specified, [os.Open] is used to open the file. func Register(db *sqlite3.Conn) { - RegisterFS(db, util.OSFS{}) + RegisterFS(db, fsutil.OSFS{}) } // RegisterFS registers the CSV virtual table. @@ -40,23 +41,23 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { ) for _, arg := range arg { - key, val := getParam(arg) + key, val := vtabutil.NamedArg(arg) if _, ok := done[key]; ok { return nil, fmt.Errorf("csv: more than one %q parameter", key) } switch key { case "filename": - filename = unquoteParam(val) + filename = vtabutil.Unquote(val) case "data": - data = unquoteParam(val) + data = vtabutil.Unquote(val) case "schema": - schema = unquoteParam(val) + schema = vtabutil.Unquote(val) case "header": - header, err = boolParam(key, val) + header, err = boolArg(key, val) case "columns": - columns, err = uintParam(key, val) + columns, err = uintArg(key, val) case "comma": - comma, err = runeParam(key, val) + comma, err = runeArg(key, val) default: return nil, fmt.Errorf("csv: unknown %q parameter", key) } @@ -81,8 +82,8 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { if schema == "" { var row []string if header || columns < 0 { - csv, close, err := table.newReader() - defer close.Close() + csv, c, err := table.newReader() + defer c.Close() if err != nil { return nil, err } @@ -133,13 +134,11 @@ func (t *table) Integrity(schema, table string, flags int) error { if flags&1 != 0 { return nil } - csv, close, err := t.newReader() + csv, c, err := t.newReader() if err != nil { return err } - if close != nil { - defer close.Close() - } + defer c.Close() _, err = csv.ReadAll() return err } @@ -176,20 +175,28 @@ func (t *table) newReader() (*csv.Reader, io.Closer, error) { } type cursor struct { - table *table - close io.Closer - csv *csv.Reader - row []string - rowID int64 + table *table + closer io.Closer + csv *csv.Reader + row []string + rowID int64 } -func (c *cursor) Close() error { - return c.close.Close() +func (c *cursor) Close() (err error) { + if c.closer != nil { + err = c.closer.Close() + c.closer = nil + } + return err } func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { - var err error - c.csv, c.close, err = c.table.newReader() + err := c.Close() + if err != nil { + return err + } + + c.csv, c.closer, err = c.table.newReader() if err != nil { return err } diff --git a/ext/csv/params.go b/ext/csv/params.go deleted file mode 100644 index 7dbbd72..0000000 --- a/ext/csv/params.go +++ /dev/null @@ -1,65 +0,0 @@ -package csv - -import ( - "fmt" - "strconv" - "strings" -) - -func getParam(arg string) (key, val string) { - key, val, _ = strings.Cut(arg, "=") - key = strings.TrimSpace(key) - val = strings.TrimSpace(val) - return -} - -func uintParam(key, val string) (int, error) { - i, err := strconv.ParseUint(val, 10, 15) - if err != nil { - return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val) - } - return int(i), nil -} - -func boolParam(key, val string) (bool, error) { - if val == "" || val == "1" || - strings.EqualFold(val, "true") || - strings.EqualFold(val, "yes") || - strings.EqualFold(val, "on") { - return true, nil - } - if val == "0" || - strings.EqualFold(val, "false") || - strings.EqualFold(val, "no") || - strings.EqualFold(val, "off") { - return false, nil - } - return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val) -} - -func runeParam(key, val string) (rune, error) { - r, _, tail, err := strconv.UnquoteChar(unquoteParam(val), 0) - if tail != "" || err != nil { - return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val) - } - return r, nil -} - -func unquoteParam(val string) string { - if len(val) < 2 { - return val - } - if val[0] != val[len(val)-1] { - return val - } - var old, new string - switch val[0] { - default: - return val - case '"': - old, new = `""`, `"` - case '\'': - old, new = `''`, `'` - } - return strings.ReplaceAll(val[1:len(val)-1], old, new) -} diff --git a/ext/fileio/fileio.go b/ext/fileio/fileio.go index 5066011..6774506 100644 --- a/ext/fileio/fileio.go +++ b/ext/fileio/fileio.go @@ -13,13 +13,13 @@ import ( ) // Register registers SQL functions readfile, writefile, lsmode, -// and the eponymous virtual table fsdir. +// and the table-valued function fsdir. func Register(db *sqlite3.Conn) { RegisterFS(db, nil) } // Register registers SQL functions readfile, lsmode, -// and the eponymous virtual table fsdir; +// and the table-valued function fsdir; // fsys will be used to read files and list directories. func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { db.CreateFunction("lsmode", 1, 0, lsmode) @@ -27,8 +27,8 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { if fsys == nil { db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile) } - sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, module, schema, table string, arg ...string) (fsdir, error) { - err := db.DeclareVtab(`CREATE TABLE x(name,mode,mtime,data,path HIDDEN,dir HIDDEN)`) + sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) { + err := db.DeclareVtab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`) db.VtabConfig(sqlite3.VTAB_DIRECTONLY) return fsdir{fsys}, err }) diff --git a/ext/fileio/fsdir.go b/ext/fileio/fsdir.go index 4228652..ac9ebaf 100644 --- a/ext/fileio/fsdir.go +++ b/ext/fileio/fsdir.go @@ -48,11 +48,11 @@ func (d fsdir) BestIndex(idx *sqlite3.IndexInfo) error { } func (d fsdir) Open() (sqlite3.VTabCursor, error) { - return &cursor{fsys: d.fsys}, nil + return &cursor{fsdir: d}, nil } type cursor struct { - fsys fs.FS + fsdir curr entry next chan entry done chan struct{} diff --git a/ext/fileio/write.go b/ext/fileio/write.go index 3aaae6a..4c963a5 100644 --- a/ext/fileio/write.go +++ b/ext/fileio/write.go @@ -10,6 +10,7 @@ import ( "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/fsutil" ) func writefile(ctx sqlite3.Context, arg ...sqlite3.Value) { @@ -22,7 +23,7 @@ func writefile(ctx sqlite3.Context, arg ...sqlite3.Value) { var mode fs.FileMode if len(arg) > 2 { - mode = fixMode(fs.FileMode(arg[2].Int())) + mode = fsutil.FileModeFromValue(arg[2]) } n, err := createFileAndDir(file, mode, arg[1]) @@ -88,40 +89,6 @@ func createFile(path string, mode fs.FileMode, data sqlite3.Value) (int, error) return 0, fmt.Errorf("invalid mode: %v", mode) } -func fixMode(mode fs.FileMode) fs.FileMode { - const ( - S_IFMT fs.FileMode = 0170000 - S_IFIFO fs.FileMode = 0010000 - S_IFCHR fs.FileMode = 0020000 - S_IFDIR fs.FileMode = 0040000 - S_IFBLK fs.FileMode = 0060000 - S_IFREG fs.FileMode = 0100000 - S_IFLNK fs.FileMode = 0120000 - S_IFSOCK fs.FileMode = 0140000 - ) - - switch mode & S_IFMT { - case S_IFDIR: - mode |= fs.ModeDir - case S_IFLNK: - mode |= fs.ModeSymlink - case S_IFBLK: - mode |= fs.ModeDevice - case S_IFCHR: - mode |= fs.ModeCharDevice | fs.ModeDevice - case S_IFIFO: - mode |= fs.ModeNamedPipe - case S_IFSOCK: - mode |= fs.ModeSocket - case S_IFREG, 0: - // - default: - mode |= fs.ModeIrregular - } - - return mode &^ S_IFMT -} - func fixPerm(mode fs.FileMode, def fs.FileMode) fs.FileMode { if mode.Perm() == 0 { return def diff --git a/ext/fileio/write_test.go b/ext/fileio/write_test.go index c81a205..09769f3 100644 --- a/ext/fileio/write_test.go +++ b/ext/fileio/write_test.go @@ -56,7 +56,7 @@ func Test_writefile(t *testing.T) { var mode fs.FileMode var mtime time.Time var data sql.NullString - err := rows.Scan(&name, &mode, sqlite3.TimeFormatUnixFrac.Scanner(&mtime), &data) + err := rows.Scan(&name, &mode, &mtime, &data) if err != nil { t.Fatal(err) } @@ -90,26 +90,3 @@ func Test_writefile(t *testing.T) { t.Log(err) } } - -func Test_fixMode(t *testing.T) { - tests := []struct { - mode fs.FileMode - want fs.FileMode - }{ - {0010754, 0754 | fs.ModeNamedPipe}, - {0020754, 0754 | fs.ModeCharDevice | fs.ModeDevice}, - {0040754, 0754 | fs.ModeDir}, - {0060754, 0754 | fs.ModeDevice}, - {0100754, 0754}, - {0120754, 0754 | fs.ModeSymlink}, - {0140754, 0754 | fs.ModeSocket}, - {0170754, 0754 | fs.ModeIrregular}, - } - for _, tt := range tests { - t.Run(tt.mode.String(), func(t *testing.T) { - if got := fixMode(tt.mode); got != tt.want { - t.Errorf("fixMode() = %o, want %o", got, tt.want) - } - }) - } -} diff --git a/ext/lines/lines.go b/ext/lines/lines.go index 17d611b..d7171a6 100644 --- a/ext/lines/lines.go +++ b/ext/lines/lines.go @@ -18,20 +18,20 @@ import ( "io/fs" "github.com/ncruces/go-sqlite3" - "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/fsutil" ) -// Register registers the lines and lines_read virtual tables. -// The lines virtual table reads from a database blob or text. -// The lines_read virtual table reads from a file or an [io.Reader]. +// Register registers the lines and lines_read table-valued functions. +// The lines function reads from a database blob or text. +// The lines_read function reads from a file or an [io.Reader]. // If a filename is specified, [os.Open] is used to open the file. func Register(db *sqlite3.Conn) { - RegisterFS(db, util.OSFS{}) + RegisterFS(db, fsutil.OSFS{}) } -// RegisterFS registers the lines and lines_read virtual tables. -// The lines virtual table reads from a database blob or text. -// The lines_read virtual table reads from a file or an [io.Reader]. +// RegisterFS registers the lines and lines_read table-valued functions. +// The lines function reads from a database blob or text. +// The lines_read function reads from a file or an [io.Reader]. // If a filename is specified, fsys is used to open the file. func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { sqlite3.CreateModule[lines](db, "lines", nil, diff --git a/internal/util/osfs.go b/internal/util/osfs.go deleted file mode 100644 index 5bce1ce..0000000 --- a/internal/util/osfs.go +++ /dev/null @@ -1,24 +0,0 @@ -package util - -import ( - "io/fs" - "os" -) - -type OSFS struct{} - -func (OSFS) Open(name string) (fs.File, error) { - return os.Open(name) -} - -func (OSFS) Stat(name string) (fs.FileInfo, error) { - return os.Stat(name) -} - -func (OSFS) ReadDir(name string) ([]fs.DirEntry, error) { - return os.ReadDir(name) -} - -func (OSFS) ReadFile(name string) ([]byte, error) { - return os.ReadFile(name) -} diff --git a/internal/util/pointer_test.go b/internal/util/pointer_test.go new file mode 100644 index 0000000..d33d4f3 --- /dev/null +++ b/internal/util/pointer_test.go @@ -0,0 +1,15 @@ +package util_test + +import ( + "math" + "testing" + + "github.com/ncruces/go-sqlite3/internal/util" +) + +func TestUnwrapPointer(t *testing.T) { + p := util.Pointer[float64]{Value: math.Pi} + if got := util.UnwrapPointer(p); got != math.Pi { + t.Errorf("want π, got %v", got) + } +} diff --git a/time.go b/time.go index 47ac672..a14870e 100644 --- a/time.go +++ b/time.go @@ -344,7 +344,11 @@ type timeScanner struct { TimeFormat } -func (s timeScanner) Scan(src any) (err error) { - *s.Time, err = s.Decode(src) - return +func (s timeScanner) Scan(src any) error { + var ok bool + var err error + if *s.Time, ok = src.(time.Time); !ok { + *s.Time, err = s.Decode(src) + } + return err } diff --git a/util/fsutil/mode.go b/util/fsutil/mode.go new file mode 100644 index 0000000..c38fef5 --- /dev/null +++ b/util/fsutil/mode.go @@ -0,0 +1,95 @@ +package fsutil + +import ( + "io/fs" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" +) + +// ParseFileMode parses a file mode as returned by +// [fs.FileMode.String]. +func ParseFileMode(str string) (fs.FileMode, error) { + var mode fs.FileMode + err := util.ErrorString("invalid mode: " + str) + + if len(str) < 10 { + return 0, err + } + + for i, c := range []byte("dalTLDpSugct?") { + if str[0] == c { + if len(str) < 10 { + return 0, err + } + mode |= 1 << uint(32-1-i) + str = str[1:] + } + } + + if mode == 0 { + if str[0] != '-' { + return 0, err + } + str = str[1:] + } + if len(str) != 9 { + return 0, err + } + + for i, c := range []byte("rwxrwxrwx") { + if str[i] == c { + mode |= 1 << uint(9-1-i) + } + if str[i] != '-' { + return 0, err + } + } + + return mode, nil +} + +// FileModeFromUnix converts a POSIX mode_t to a file mode. +func FileModeFromUnix(mode fs.FileMode) fs.FileMode { + const ( + S_IFMT fs.FileMode = 0170000 + S_IFIFO fs.FileMode = 0010000 + S_IFCHR fs.FileMode = 0020000 + S_IFDIR fs.FileMode = 0040000 + S_IFBLK fs.FileMode = 0060000 + S_IFREG fs.FileMode = 0100000 + S_IFLNK fs.FileMode = 0120000 + S_IFSOCK fs.FileMode = 0140000 + ) + + switch mode & S_IFMT { + case S_IFDIR: + mode |= fs.ModeDir + case S_IFLNK: + mode |= fs.ModeSymlink + case S_IFBLK: + mode |= fs.ModeDevice + case S_IFCHR: + mode |= fs.ModeCharDevice | fs.ModeDevice + case S_IFIFO: + mode |= fs.ModeNamedPipe + case S_IFSOCK: + mode |= fs.ModeSocket + case S_IFREG, 0: + // + default: + mode |= fs.ModeIrregular + } + + return mode &^ S_IFMT +} + +// FileModeFromValue calls [FileModeFromUnix] for numeric values, +// and [ParseFileMode] for textual values. +func FileModeFromValue(val sqlite3.Value) fs.FileMode { + if n := val.Int64(); n != 0 { + return FileModeFromUnix(fs.FileMode(n)) + } + mode, _ := ParseFileMode(val.Text()) + return mode +} diff --git a/util/fsutil/mode_test.go b/util/fsutil/mode_test.go new file mode 100644 index 0000000..bec7ca7 --- /dev/null +++ b/util/fsutil/mode_test.go @@ -0,0 +1,54 @@ +package fsutil + +import ( + "io/fs" + "testing" +) + +func TestFileModeFromUnix(t *testing.T) { + tests := []struct { + mode fs.FileMode + want fs.FileMode + }{ + {0010754, 0754 | fs.ModeNamedPipe}, + {0020754, 0754 | fs.ModeCharDevice | fs.ModeDevice}, + {0040754, 0754 | fs.ModeDir}, + {0060754, 0754 | fs.ModeDevice}, + {0100754, 0754}, + {0120754, 0754 | fs.ModeSymlink}, + {0140754, 0754 | fs.ModeSocket}, + {0170754, 0754 | fs.ModeIrregular}, + } + for _, tt := range tests { + t.Run(tt.mode.String(), func(t *testing.T) { + if got := FileModeFromUnix(tt.mode); got != tt.want { + t.Errorf("fixMode() = %o, want %o", got, tt.want) + } + }) + } +} + +func FuzzParseFileMode(f *testing.F) { + f.Add("---------") + f.Add("rwxrwxrwx") + f.Add("----------") + f.Add("-rwxrwxrwx") + f.Add("b") + f.Add("b---------") + f.Add("drwxrwxrwx") + f.Add("dalTLDpSugct?") + f.Add("dalTLDpSugct?---------") + f.Add("dalTLDpSugct?rwxrwxrwx") + f.Add("dalTLDpSugct?----------") + + f.Fuzz(func(t *testing.T, str string) { + mode, err := ParseFileMode(str) + if err != nil { + return + } + got := mode.String() + if got != str { + t.Errorf("was %q, got %q (%o)", str, got, mode) + } + }) +} diff --git a/util/fsutil/osfs.go b/util/fsutil/osfs.go new file mode 100644 index 0000000..b807d1d --- /dev/null +++ b/util/fsutil/osfs.go @@ -0,0 +1,34 @@ +// Package fsutil implements file system utility functions. +package fsutil + +import ( + "io/fs" + "os" +) + +// OSFS implements [fs.FS], [fs.StatFS], and [fs.ReadFileFS] +// using package [os]. +// +// This filesystem does not respect [fs.ValidPath] rules, +// and fails [testing/fstest.TestFS]! +// +// Still, it can be a useful tool to unify implementations +// that can access either the [os] filesystem or an [fs.FS]. +// It's OK to use this to open files, but you should avoid +// opening directories, resolving paths, or walking the file system. +type OSFS struct{} + +// Open implements [fs.FS]. +func (OSFS) Open(name string) (fs.File, error) { + return os.Open(name) +} + +// ReadFileFS implements [fs.StatFS]. +func (OSFS) Stat(name string) (fs.FileInfo, error) { + return os.Stat(name) +} + +// ReadFile implements [fs.ReadFileFS]. +func (OSFS) ReadFile(name string) ([]byte, error) { + return os.ReadFile(name) +} diff --git a/util/ioutil/seek.go b/util/ioutil/seek.go new file mode 100644 index 0000000..dc86ea1 --- /dev/null +++ b/util/ioutil/seek.go @@ -0,0 +1,60 @@ +package ioutil + +import ( + "io" + "sync" +) + +// SeekingReaderAt implements [io.ReaderAt] +// through an underlying [io.ReadSeeker]. +type SeekingReaderAt struct { + l sync.Mutex + r io.ReadSeeker +} + +// NewSeekingReaderAt creates a new SeekingReaderAt. +// The SeekingReaderAt takes ownership of r +// and will modify its seek offset, +// so callers should not use r after this call. +func NewSeekingReaderAt(r io.ReadSeeker) *SeekingReaderAt { + return &SeekingReaderAt{r: r} +} + +// ReadAt implements [io.ReaderAt]. +func (s *SeekingReaderAt) ReadAt(p []byte, off int64) (n int, _ error) { + s.l.Lock() + defer s.l.Unlock() + + _, err := s.r.Seek(off, io.SeekStart) + if err != nil { + return 0, err + } + + for len(p) > 0 { + i, err := s.r.Read(p) + p = p[i:] + n += i + if err != nil { + return n, err + } + } + return n, nil +} + +// Size implements [SizeReaderAt]. +func (s *SeekingReaderAt) Size() (int64, error) { + s.l.Lock() + defer s.l.Unlock() + return s.r.Seek(0, io.SeekEnd) +} + +// ReadAt implements [io.Closer]. +func (s *SeekingReaderAt) Close() error { + s.l.Lock() + defer s.l.Unlock() + if c, ok := s.r.(io.Closer); ok { + s.r = nil + return c.Close() + } + return nil +} diff --git a/util/ioutil/seek_test.go b/util/ioutil/seek_test.go new file mode 100644 index 0000000..e2716bb --- /dev/null +++ b/util/ioutil/seek_test.go @@ -0,0 +1,28 @@ +package ioutil + +import ( + "strings" + "testing" +) + +func TestNewSeekingReaderAt(t *testing.T) { + reader := NewSeekingReaderAt(strings.NewReader("abc")) + defer reader.Close() + + n, err := reader.Size() + if err != nil { + t.Fatal(err) + } + if n != 3 { + t.Errorf("got %d", n) + } + + var buf [3]byte + r, err := reader.ReadAt(buf[:], 0) + if err != nil { + t.Fatal(err) + } + if r != 3 { + t.Errorf("got %d", r) + } +} diff --git a/util/ioutil/size.go b/util/ioutil/size.go new file mode 100644 index 0000000..8c40beb --- /dev/null +++ b/util/ioutil/size.go @@ -0,0 +1,49 @@ +// Package ioutil implements I/O utility functions. +package ioutil + +import ( + "io" + "io/fs" + + "github.com/ncruces/go-sqlite3" +) + +// A SizeReaderAt is a ReaderAt with a Size method. +// Use [NewSizeReaderAt] to adapt different Size interfaces. +type SizeReaderAt interface { + Size() (int64, error) + io.ReaderAt +} + +// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt +// that implements one of: +// - Size() (int64, error) +// - Size() int64 +// - Len() int +// - Stat() (fs.FileInfo, error) +// - Seek(offset int64, whence int) (int64, error) +func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt { + return sizer{r} +} + +type sizer struct{ io.ReaderAt } + +func (s sizer) Size() (int64, error) { + switch s := s.ReaderAt.(type) { + case interface{ Size() (int64, error) }: + return s.Size() + case interface{ Size() int64 }: + return s.Size(), nil + case interface{ Len() int }: + return int64(s.Len()), nil + case interface{ Stat() (fs.FileInfo, error) }: + fi, err := s.Stat() + if err != nil { + return 0, err + } + return fi.Size(), nil + case io.Seeker: + return s.Seek(0, io.SeekEnd) + } + return 0, sqlite3.IOERR_SEEK +} diff --git a/vfs/readervfs/reader_test.go b/util/ioutil/size_test.go similarity index 98% rename from vfs/readervfs/reader_test.go rename to util/ioutil/size_test.go index 4508a35..dbaf209 100644 --- a/vfs/readervfs/reader_test.go +++ b/util/ioutil/size_test.go @@ -1,4 +1,4 @@ -package readervfs +package ioutil import ( "io" diff --git a/util/vtabutil/arg.go b/util/vtabutil/arg.go new file mode 100644 index 0000000..15e2b74 --- /dev/null +++ b/util/vtabutil/arg.go @@ -0,0 +1,34 @@ +// Package ioutil implements virtual table utility functions. +package vtabutil + +import "strings" + +// NamedArg splits an named arg into a key and value, +// around an equals sign. +// Spaces are trimmed around both key and value. +func NamedArg(arg string) (key, val string) { + key, val, _ = strings.Cut(arg, "=") + key = strings.TrimSpace(key) + val = strings.TrimSpace(val) + return +} + +// Unquote unquotes a string. +func Unquote(val string) string { + if len(val) < 2 { + return val + } + if val[0] != val[len(val)-1] { + return val + } + var old, new string + switch val[0] { + default: + return val + case '"': + old, new = `""`, `"` + case '\'': + old, new = `''`, `'` + } + return strings.ReplaceAll(val[1:len(val)-1], old, new) +} diff --git a/vfs/readervfs/api.go b/vfs/readervfs/api.go index 8cc38ef..ff223d8 100644 --- a/vfs/readervfs/api.go +++ b/vfs/readervfs/api.go @@ -9,11 +9,9 @@ package readervfs import ( - "io" - "io/fs" "sync" - "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/util/ioutil" "github.com/ncruces/go-sqlite3/vfs" ) @@ -24,13 +22,13 @@ func init() { var ( readerMtx sync.RWMutex // +checklocks:readerMtx - readerDBs = map[string]SizeReaderAt{} + readerDBs = map[string]ioutil.SizeReaderAt{} ) // Create creates an immutable database from reader. // The caller should ensure that data from reader does not mutate, // otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors. -func Create(name string, reader SizeReaderAt) { +func Create(name string, reader ioutil.SizeReaderAt) { readerMtx.Lock() defer readerMtx.Unlock() readerDBs[name] = reader @@ -42,43 +40,3 @@ func Delete(name string) { defer readerMtx.Unlock() delete(readerDBs, name) } - -// A SizeReaderAt is a ReaderAt with a Size method. -// Use [NewSizeReaderAt] to adapt different Size interfaces. -type SizeReaderAt interface { - Size() (int64, error) - io.ReaderAt -} - -// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt -// that implements one of: -// - Size() (int64, error) -// - Size() int64 -// - Len() int -// - Stat() (fs.FileInfo, error) -// - Seek(offset int64, whence int) (int64, error) -func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt { - return sizer{r} -} - -type sizer struct{ io.ReaderAt } - -func (s sizer) Size() (int64, error) { - switch s := s.ReaderAt.(type) { - case interface{ Size() (int64, error) }: - return s.Size() - case interface{ Size() int64 }: - return s.Size(), nil - case interface{ Len() int }: - return int64(s.Len()), nil - case interface{ Stat() (fs.FileInfo, error) }: - fi, err := s.Stat() - if err != nil { - return 0, err - } - return fi.Size(), nil - case io.Seeker: - return s.Seek(0, io.SeekEnd) - } - return 0, sqlite3.IOERR_SEEK -} diff --git a/vfs/readervfs/example_test.go b/vfs/readervfs/example_test.go index e4bc623..79f2045 100644 --- a/vfs/readervfs/example_test.go +++ b/vfs/readervfs/example_test.go @@ -10,6 +10,7 @@ import ( _ "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/util/ioutil" "github.com/ncruces/go-sqlite3/vfs/readervfs" "github.com/psanford/httpreadat" ) @@ -65,7 +66,7 @@ func Example_http() { } func Example_embed() { - readervfs.Create("test.db", readervfs.NewSizeReaderAt(strings.NewReader(testDB))) + readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB))) defer readervfs.Delete("test.db") db, err := sql.Open("sqlite3", "file:test.db?vfs=reader") diff --git a/vfs/readervfs/reader.go b/vfs/readervfs/reader.go index 3e47eb6..15b9471 100644 --- a/vfs/readervfs/reader.go +++ b/vfs/readervfs/reader.go @@ -2,6 +2,7 @@ package readervfs import ( "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/util/ioutil" "github.com/ncruces/go-sqlite3/vfs" ) @@ -31,7 +32,7 @@ func (readerVFS) FullPathname(name string) (string, error) { return name, nil } -type readerFile struct{ SizeReaderAt } +type readerFile struct{ ioutil.SizeReaderAt } func (readerFile) Close() error { return nil