Refactor extensions.

This commit is contained in:
Nuno Cruces
2024-01-03 00:54:30 +00:00
parent fab70ddbec
commit ae850191c8
23 changed files with 491 additions and 256 deletions

36
ext/csv/arg.go Normal file
View File

@@ -0,0 +1,36 @@
package csv
import (
"fmt"
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
func uintArg(key, val string) (int, error) {
i, err := strconv.ParseUint(val, 10, 15)
if err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return int(i), nil
}
func boolArg(key, val string) (bool, error) {
if val == "" {
return true, nil
}
b, ok := util.ParseBool(val)
if ok {
return b, nil
}
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
func runeArg(key, val string) (rune, error) {
r, _, tail, err := strconv.UnquoteChar(vtabutil.Unquote(val), 0)
if tail != "" || err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return r, nil
}

View File

@@ -1,8 +1,12 @@
package csv
import "testing"
import (
"testing"
func Test_uintParam(t *testing.T) {
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
func Test_uintArg(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -20,22 +24,22 @@ func Test_uintParam(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
key, val := vtabutil.NamedArg(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
}
got, err := uintParam(key, val)
got, err := uintArg(key, val)
if (err != nil) != tt.err {
t.Fatalf("uintParam() error = %v, want err %v", err, tt.err)
t.Fatalf("uintArg() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("uintParam() = %v, want %v", got, tt.val)
t.Errorf("uintArg() = %v, want %v", got, tt.val)
}
})
}
}
func Test_boolParam(t *testing.T) {
func Test_boolArg(t *testing.T) {
tests := []struct {
arg string
key string
@@ -56,22 +60,22 @@ func Test_boolParam(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
key, val := vtabutil.NamedArg(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
}
got, err := boolParam(key, val)
got, err := boolArg(key, val)
if (err != nil) != tt.err {
t.Fatalf("boolParam() error = %v, want err %v", err, tt.err)
t.Fatalf("boolArg() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("boolParam() = %v, want %v", got, tt.val)
t.Errorf("boolArg() = %v, want %v", got, tt.val)
}
})
}
}
func Test_runeParam(t *testing.T) {
func Test_runeArg(t *testing.T) {
tests := []struct {
arg string
key string
@@ -88,16 +92,16 @@ func Test_runeParam(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
key, val := vtabutil.NamedArg(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
}
got, err := runeParam(key, val)
got, err := runeArg(key, val)
if (err != nil) != tt.err {
t.Fatalf("runeParam() error = %v, want err %v", err, tt.err)
t.Fatalf("runeArg() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("runeParam() = %v, want %v", got, tt.val)
t.Errorf("runeArg() = %v, want %v", got, tt.val)
}
})
}

View File

@@ -15,13 +15,14 @@ import (
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/fsutil"
"github.com/ncruces/go-sqlite3/util/vtabutil"
)
// Register registers the CSV virtual table.
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, util.OSFS{})
RegisterFS(db, fsutil.OSFS{})
}
// RegisterFS registers the CSV virtual table.
@@ -40,23 +41,23 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
)
for _, arg := range arg {
key, val := getParam(arg)
key, val := vtabutil.NamedArg(arg)
if _, ok := done[key]; ok {
return nil, fmt.Errorf("csv: more than one %q parameter", key)
}
switch key {
case "filename":
filename = unquoteParam(val)
filename = vtabutil.Unquote(val)
case "data":
data = unquoteParam(val)
data = vtabutil.Unquote(val)
case "schema":
schema = unquoteParam(val)
schema = vtabutil.Unquote(val)
case "header":
header, err = boolParam(key, val)
header, err = boolArg(key, val)
case "columns":
columns, err = uintParam(key, val)
columns, err = uintArg(key, val)
case "comma":
comma, err = runeParam(key, val)
comma, err = runeArg(key, val)
default:
return nil, fmt.Errorf("csv: unknown %q parameter", key)
}
@@ -81,8 +82,8 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
if schema == "" {
var row []string
if header || columns < 0 {
csv, close, err := table.newReader()
defer close.Close()
csv, c, err := table.newReader()
defer c.Close()
if err != nil {
return nil, err
}
@@ -133,13 +134,11 @@ func (t *table) Integrity(schema, table string, flags int) error {
if flags&1 != 0 {
return nil
}
csv, close, err := t.newReader()
csv, c, err := t.newReader()
if err != nil {
return err
}
if close != nil {
defer close.Close()
}
defer c.Close()
_, err = csv.ReadAll()
return err
}
@@ -176,20 +175,28 @@ func (t *table) newReader() (*csv.Reader, io.Closer, error) {
}
type cursor struct {
table *table
close io.Closer
csv *csv.Reader
row []string
rowID int64
table *table
closer io.Closer
csv *csv.Reader
row []string
rowID int64
}
func (c *cursor) Close() error {
return c.close.Close()
func (c *cursor) Close() (err error) {
if c.closer != nil {
err = c.closer.Close()
c.closer = nil
}
return err
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
var err error
c.csv, c.close, err = c.table.newReader()
err := c.Close()
if err != nil {
return err
}
c.csv, c.closer, err = c.table.newReader()
if err != nil {
return err
}

View File

@@ -1,65 +0,0 @@
package csv
import (
"fmt"
"strconv"
"strings"
)
func getParam(arg string) (key, val string) {
key, val, _ = strings.Cut(arg, "=")
key = strings.TrimSpace(key)
val = strings.TrimSpace(val)
return
}
func uintParam(key, val string) (int, error) {
i, err := strconv.ParseUint(val, 10, 15)
if err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return int(i), nil
}
func boolParam(key, val string) (bool, error) {
if val == "" || val == "1" ||
strings.EqualFold(val, "true") ||
strings.EqualFold(val, "yes") ||
strings.EqualFold(val, "on") {
return true, nil
}
if val == "0" ||
strings.EqualFold(val, "false") ||
strings.EqualFold(val, "no") ||
strings.EqualFold(val, "off") {
return false, nil
}
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
func runeParam(key, val string) (rune, error) {
r, _, tail, err := strconv.UnquoteChar(unquoteParam(val), 0)
if tail != "" || err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return r, nil
}
func unquoteParam(val string) string {
if len(val) < 2 {
return val
}
if val[0] != val[len(val)-1] {
return val
}
var old, new string
switch val[0] {
default:
return val
case '"':
old, new = `""`, `"`
case '\'':
old, new = `''`, `'`
}
return strings.ReplaceAll(val[1:len(val)-1], old, new)
}

View File

@@ -13,13 +13,13 @@ import (
)
// Register registers SQL functions readfile, writefile, lsmode,
// and the eponymous virtual table fsdir.
// and the table-valued function fsdir.
func Register(db *sqlite3.Conn) {
RegisterFS(db, nil)
}
// Register registers SQL functions readfile, lsmode,
// and the eponymous virtual table fsdir;
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
db.CreateFunction("lsmode", 1, 0, lsmode)
@@ -27,8 +27,8 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
if fsys == nil {
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, module, schema, table string, arg ...string) (fsdir, error) {
err := db.DeclareVtab(`CREATE TABLE x(name,mode,mtime,data,path HIDDEN,dir HIDDEN)`)
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVtab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VtabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
})

View File

@@ -48,11 +48,11 @@ func (d fsdir) BestIndex(idx *sqlite3.IndexInfo) error {
}
func (d fsdir) Open() (sqlite3.VTabCursor, error) {
return &cursor{fsys: d.fsys}, nil
return &cursor{fsdir: d}, nil
}
type cursor struct {
fsys fs.FS
fsdir
curr entry
next chan entry
done chan struct{}

View File

@@ -10,6 +10,7 @@ import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/fsutil"
)
func writefile(ctx sqlite3.Context, arg ...sqlite3.Value) {
@@ -22,7 +23,7 @@ func writefile(ctx sqlite3.Context, arg ...sqlite3.Value) {
var mode fs.FileMode
if len(arg) > 2 {
mode = fixMode(fs.FileMode(arg[2].Int()))
mode = fsutil.FileModeFromValue(arg[2])
}
n, err := createFileAndDir(file, mode, arg[1])
@@ -88,40 +89,6 @@ func createFile(path string, mode fs.FileMode, data sqlite3.Value) (int, error)
return 0, fmt.Errorf("invalid mode: %v", mode)
}
func fixMode(mode fs.FileMode) fs.FileMode {
const (
S_IFMT fs.FileMode = 0170000
S_IFIFO fs.FileMode = 0010000
S_IFCHR fs.FileMode = 0020000
S_IFDIR fs.FileMode = 0040000
S_IFBLK fs.FileMode = 0060000
S_IFREG fs.FileMode = 0100000
S_IFLNK fs.FileMode = 0120000
S_IFSOCK fs.FileMode = 0140000
)
switch mode & S_IFMT {
case S_IFDIR:
mode |= fs.ModeDir
case S_IFLNK:
mode |= fs.ModeSymlink
case S_IFBLK:
mode |= fs.ModeDevice
case S_IFCHR:
mode |= fs.ModeCharDevice | fs.ModeDevice
case S_IFIFO:
mode |= fs.ModeNamedPipe
case S_IFSOCK:
mode |= fs.ModeSocket
case S_IFREG, 0:
//
default:
mode |= fs.ModeIrregular
}
return mode &^ S_IFMT
}
func fixPerm(mode fs.FileMode, def fs.FileMode) fs.FileMode {
if mode.Perm() == 0 {
return def

View File

@@ -56,7 +56,7 @@ func Test_writefile(t *testing.T) {
var mode fs.FileMode
var mtime time.Time
var data sql.NullString
err := rows.Scan(&name, &mode, sqlite3.TimeFormatUnixFrac.Scanner(&mtime), &data)
err := rows.Scan(&name, &mode, &mtime, &data)
if err != nil {
t.Fatal(err)
}
@@ -90,26 +90,3 @@ func Test_writefile(t *testing.T) {
t.Log(err)
}
}
func Test_fixMode(t *testing.T) {
tests := []struct {
mode fs.FileMode
want fs.FileMode
}{
{0010754, 0754 | fs.ModeNamedPipe},
{0020754, 0754 | fs.ModeCharDevice | fs.ModeDevice},
{0040754, 0754 | fs.ModeDir},
{0060754, 0754 | fs.ModeDevice},
{0100754, 0754},
{0120754, 0754 | fs.ModeSymlink},
{0140754, 0754 | fs.ModeSocket},
{0170754, 0754 | fs.ModeIrregular},
}
for _, tt := range tests {
t.Run(tt.mode.String(), func(t *testing.T) {
if got := fixMode(tt.mode); got != tt.want {
t.Errorf("fixMode() = %o, want %o", got, tt.want)
}
})
}
}

View File

@@ -18,20 +18,20 @@ import (
"io/fs"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/fsutil"
)
// Register registers the lines and lines_read virtual tables.
// The lines virtual table reads from a database blob or text.
// The lines_read virtual table reads from a file or an [io.Reader].
// Register registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, util.OSFS{})
RegisterFS(db, fsutil.OSFS{})
}
// RegisterFS registers the lines and lines_read virtual tables.
// The lines virtual table reads from a database blob or text.
// The lines_read virtual table reads from a file or an [io.Reader].
// RegisterFS registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
sqlite3.CreateModule[lines](db, "lines", nil,

View File

@@ -1,24 +0,0 @@
package util
import (
"io/fs"
"os"
)
type OSFS struct{}
func (OSFS) Open(name string) (fs.File, error) {
return os.Open(name)
}
func (OSFS) Stat(name string) (fs.FileInfo, error) {
return os.Stat(name)
}
func (OSFS) ReadDir(name string) ([]fs.DirEntry, error) {
return os.ReadDir(name)
}
func (OSFS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}

View File

@@ -0,0 +1,15 @@
package util_test
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func TestUnwrapPointer(t *testing.T) {
p := util.Pointer[float64]{Value: math.Pi}
if got := util.UnwrapPointer(p); got != math.Pi {
t.Errorf("want π, got %v", got)
}
}

10
time.go
View File

@@ -344,7 +344,11 @@ type timeScanner struct {
TimeFormat
}
func (s timeScanner) Scan(src any) (err error) {
*s.Time, err = s.Decode(src)
return
func (s timeScanner) Scan(src any) error {
var ok bool
var err error
if *s.Time, ok = src.(time.Time); !ok {
*s.Time, err = s.Decode(src)
}
return err
}

95
util/fsutil/mode.go Normal file
View File

@@ -0,0 +1,95 @@
package fsutil
import (
"io/fs"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ParseFileMode parses a file mode as returned by
// [fs.FileMode.String].
func ParseFileMode(str string) (fs.FileMode, error) {
var mode fs.FileMode
err := util.ErrorString("invalid mode: " + str)
if len(str) < 10 {
return 0, err
}
for i, c := range []byte("dalTLDpSugct?") {
if str[0] == c {
if len(str) < 10 {
return 0, err
}
mode |= 1 << uint(32-1-i)
str = str[1:]
}
}
if mode == 0 {
if str[0] != '-' {
return 0, err
}
str = str[1:]
}
if len(str) != 9 {
return 0, err
}
for i, c := range []byte("rwxrwxrwx") {
if str[i] == c {
mode |= 1 << uint(9-1-i)
}
if str[i] != '-' {
return 0, err
}
}
return mode, nil
}
// FileModeFromUnix converts a POSIX mode_t to a file mode.
func FileModeFromUnix(mode fs.FileMode) fs.FileMode {
const (
S_IFMT fs.FileMode = 0170000
S_IFIFO fs.FileMode = 0010000
S_IFCHR fs.FileMode = 0020000
S_IFDIR fs.FileMode = 0040000
S_IFBLK fs.FileMode = 0060000
S_IFREG fs.FileMode = 0100000
S_IFLNK fs.FileMode = 0120000
S_IFSOCK fs.FileMode = 0140000
)
switch mode & S_IFMT {
case S_IFDIR:
mode |= fs.ModeDir
case S_IFLNK:
mode |= fs.ModeSymlink
case S_IFBLK:
mode |= fs.ModeDevice
case S_IFCHR:
mode |= fs.ModeCharDevice | fs.ModeDevice
case S_IFIFO:
mode |= fs.ModeNamedPipe
case S_IFSOCK:
mode |= fs.ModeSocket
case S_IFREG, 0:
//
default:
mode |= fs.ModeIrregular
}
return mode &^ S_IFMT
}
// FileModeFromValue calls [FileModeFromUnix] for numeric values,
// and [ParseFileMode] for textual values.
func FileModeFromValue(val sqlite3.Value) fs.FileMode {
if n := val.Int64(); n != 0 {
return FileModeFromUnix(fs.FileMode(n))
}
mode, _ := ParseFileMode(val.Text())
return mode
}

54
util/fsutil/mode_test.go Normal file
View File

@@ -0,0 +1,54 @@
package fsutil
import (
"io/fs"
"testing"
)
func TestFileModeFromUnix(t *testing.T) {
tests := []struct {
mode fs.FileMode
want fs.FileMode
}{
{0010754, 0754 | fs.ModeNamedPipe},
{0020754, 0754 | fs.ModeCharDevice | fs.ModeDevice},
{0040754, 0754 | fs.ModeDir},
{0060754, 0754 | fs.ModeDevice},
{0100754, 0754},
{0120754, 0754 | fs.ModeSymlink},
{0140754, 0754 | fs.ModeSocket},
{0170754, 0754 | fs.ModeIrregular},
}
for _, tt := range tests {
t.Run(tt.mode.String(), func(t *testing.T) {
if got := FileModeFromUnix(tt.mode); got != tt.want {
t.Errorf("fixMode() = %o, want %o", got, tt.want)
}
})
}
}
func FuzzParseFileMode(f *testing.F) {
f.Add("---------")
f.Add("rwxrwxrwx")
f.Add("----------")
f.Add("-rwxrwxrwx")
f.Add("b")
f.Add("b---------")
f.Add("drwxrwxrwx")
f.Add("dalTLDpSugct?")
f.Add("dalTLDpSugct?---------")
f.Add("dalTLDpSugct?rwxrwxrwx")
f.Add("dalTLDpSugct?----------")
f.Fuzz(func(t *testing.T, str string) {
mode, err := ParseFileMode(str)
if err != nil {
return
}
got := mode.String()
if got != str {
t.Errorf("was %q, got %q (%o)", str, got, mode)
}
})
}

34
util/fsutil/osfs.go Normal file
View File

@@ -0,0 +1,34 @@
// Package fsutil implements file system utility functions.
package fsutil
import (
"io/fs"
"os"
)
// OSFS implements [fs.FS], [fs.StatFS], and [fs.ReadFileFS]
// using package [os].
//
// This filesystem does not respect [fs.ValidPath] rules,
// and fails [testing/fstest.TestFS]!
//
// Still, it can be a useful tool to unify implementations
// that can access either the [os] filesystem or an [fs.FS].
// It's OK to use this to open files, but you should avoid
// opening directories, resolving paths, or walking the file system.
type OSFS struct{}
// Open implements [fs.FS].
func (OSFS) Open(name string) (fs.File, error) {
return os.Open(name)
}
// ReadFileFS implements [fs.StatFS].
func (OSFS) Stat(name string) (fs.FileInfo, error) {
return os.Stat(name)
}
// ReadFile implements [fs.ReadFileFS].
func (OSFS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}

60
util/ioutil/seek.go Normal file
View File

@@ -0,0 +1,60 @@
package ioutil
import (
"io"
"sync"
)
// SeekingReaderAt implements [io.ReaderAt]
// through an underlying [io.ReadSeeker].
type SeekingReaderAt struct {
l sync.Mutex
r io.ReadSeeker
}
// NewSeekingReaderAt creates a new SeekingReaderAt.
// The SeekingReaderAt takes ownership of r
// and will modify its seek offset,
// so callers should not use r after this call.
func NewSeekingReaderAt(r io.ReadSeeker) *SeekingReaderAt {
return &SeekingReaderAt{r: r}
}
// ReadAt implements [io.ReaderAt].
func (s *SeekingReaderAt) ReadAt(p []byte, off int64) (n int, _ error) {
s.l.Lock()
defer s.l.Unlock()
_, err := s.r.Seek(off, io.SeekStart)
if err != nil {
return 0, err
}
for len(p) > 0 {
i, err := s.r.Read(p)
p = p[i:]
n += i
if err != nil {
return n, err
}
}
return n, nil
}
// Size implements [SizeReaderAt].
func (s *SeekingReaderAt) Size() (int64, error) {
s.l.Lock()
defer s.l.Unlock()
return s.r.Seek(0, io.SeekEnd)
}
// ReadAt implements [io.Closer].
func (s *SeekingReaderAt) Close() error {
s.l.Lock()
defer s.l.Unlock()
if c, ok := s.r.(io.Closer); ok {
s.r = nil
return c.Close()
}
return nil
}

28
util/ioutil/seek_test.go Normal file
View File

@@ -0,0 +1,28 @@
package ioutil
import (
"strings"
"testing"
)
func TestNewSeekingReaderAt(t *testing.T) {
reader := NewSeekingReaderAt(strings.NewReader("abc"))
defer reader.Close()
n, err := reader.Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
var buf [3]byte
r, err := reader.ReadAt(buf[:], 0)
if err != nil {
t.Fatal(err)
}
if r != 3 {
t.Errorf("got %d", r)
}
}

49
util/ioutil/size.go Normal file
View File

@@ -0,0 +1,49 @@
// Package ioutil implements I/O utility functions.
package ioutil
import (
"io"
"io/fs"
"github.com/ncruces/go-sqlite3"
)
// A SizeReaderAt is a ReaderAt with a Size method.
// Use [NewSizeReaderAt] to adapt different Size interfaces.
type SizeReaderAt interface {
Size() (int64, error)
io.ReaderAt
}
// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt
// that implements one of:
// - Size() (int64, error)
// - Size() int64
// - Len() int
// - Stat() (fs.FileInfo, error)
// - Seek(offset int64, whence int) (int64, error)
func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt {
return sizer{r}
}
type sizer struct{ io.ReaderAt }
func (s sizer) Size() (int64, error) {
switch s := s.ReaderAt.(type) {
case interface{ Size() (int64, error) }:
return s.Size()
case interface{ Size() int64 }:
return s.Size(), nil
case interface{ Len() int }:
return int64(s.Len()), nil
case interface{ Stat() (fs.FileInfo, error) }:
fi, err := s.Stat()
if err != nil {
return 0, err
}
return fi.Size(), nil
case io.Seeker:
return s.Seek(0, io.SeekEnd)
}
return 0, sqlite3.IOERR_SEEK
}

View File

@@ -1,4 +1,4 @@
package readervfs
package ioutil
import (
"io"

34
util/vtabutil/arg.go Normal file
View File

@@ -0,0 +1,34 @@
// Package ioutil implements virtual table utility functions.
package vtabutil
import "strings"
// NamedArg splits an named arg into a key and value,
// around an equals sign.
// Spaces are trimmed around both key and value.
func NamedArg(arg string) (key, val string) {
key, val, _ = strings.Cut(arg, "=")
key = strings.TrimSpace(key)
val = strings.TrimSpace(val)
return
}
// Unquote unquotes a string.
func Unquote(val string) string {
if len(val) < 2 {
return val
}
if val[0] != val[len(val)-1] {
return val
}
var old, new string
switch val[0] {
default:
return val
case '"':
old, new = `""`, `"`
case '\'':
old, new = `''`, `'`
}
return strings.ReplaceAll(val[1:len(val)-1], old, new)
}

View File

@@ -9,11 +9,9 @@
package readervfs
import (
"io"
"io/fs"
"sync"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/ioutil"
"github.com/ncruces/go-sqlite3/vfs"
)
@@ -24,13 +22,13 @@ func init() {
var (
readerMtx sync.RWMutex
// +checklocks:readerMtx
readerDBs = map[string]SizeReaderAt{}
readerDBs = map[string]ioutil.SizeReaderAt{}
)
// Create creates an immutable database from reader.
// The caller should ensure that data from reader does not mutate,
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
func Create(name string, reader SizeReaderAt) {
func Create(name string, reader ioutil.SizeReaderAt) {
readerMtx.Lock()
defer readerMtx.Unlock()
readerDBs[name] = reader
@@ -42,43 +40,3 @@ func Delete(name string) {
defer readerMtx.Unlock()
delete(readerDBs, name)
}
// A SizeReaderAt is a ReaderAt with a Size method.
// Use [NewSizeReaderAt] to adapt different Size interfaces.
type SizeReaderAt interface {
Size() (int64, error)
io.ReaderAt
}
// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt
// that implements one of:
// - Size() (int64, error)
// - Size() int64
// - Len() int
// - Stat() (fs.FileInfo, error)
// - Seek(offset int64, whence int) (int64, error)
func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt {
return sizer{r}
}
type sizer struct{ io.ReaderAt }
func (s sizer) Size() (int64, error) {
switch s := s.ReaderAt.(type) {
case interface{ Size() (int64, error) }:
return s.Size()
case interface{ Size() int64 }:
return s.Size(), nil
case interface{ Len() int }:
return int64(s.Len()), nil
case interface{ Stat() (fs.FileInfo, error) }:
fi, err := s.Stat()
if err != nil {
return 0, err
}
return fi.Size(), nil
case io.Seeker:
return s.Seek(0, io.SeekEnd)
}
return 0, sqlite3.IOERR_SEEK
}

View File

@@ -10,6 +10,7 @@ import (
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/util/ioutil"
"github.com/ncruces/go-sqlite3/vfs/readervfs"
"github.com/psanford/httpreadat"
)
@@ -65,7 +66,7 @@ func Example_http() {
}
func Example_embed() {
readervfs.Create("test.db", readervfs.NewSizeReaderAt(strings.NewReader(testDB)))
readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB)))
defer readervfs.Delete("test.db")
db, err := sql.Open("sqlite3", "file:test.db?vfs=reader")

View File

@@ -2,6 +2,7 @@ package readervfs
import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/ioutil"
"github.com/ncruces/go-sqlite3/vfs"
)
@@ -31,7 +32,7 @@ func (readerVFS) FullPathname(name string) (string, error) {
return name, nil
}
type readerFile struct{ SizeReaderAt }
type readerFile struct{ ioutil.SizeReaderAt }
func (readerFile) Close() error {
return nil