diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 4910f25..144c1d7 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -7,10 +7,11 @@ package csv import ( + "bufio" "encoding/csv" "fmt" "io" - "math" + "io/fs" "os" "strings" @@ -18,16 +19,14 @@ import ( ) // Register registers the CSV virtual table. -// If a filename is specified, `os.Open` is used to read it from disk. +// If a filename is specified, `os.Open` is used to open the file. func Register(db *sqlite3.Conn) { - RegisterOpen(db, func(name string) (io.ReaderAt, error) { - return os.Open(name) - }) + RegisterOpen(db, osfs{}) } // RegisterOpen registers the CSV virtual table. -// If a filename is specified, open is used to open the file. -func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) { +// If a filename is specified, fsys is used to open the file. +func RegisterOpen(db *sqlite3.Conn, fsys fs.FS) { declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { var ( filename string @@ -71,32 +70,23 @@ func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) return nil, fmt.Errorf(`csv: must specify either "filename" or "data" but not both`) } - var r io.ReaderAt - if filename != "" { - r, err = open(filename) - } else { - r = strings.NewReader(data) - } - if err != nil { - return nil, err - } - table := &table{ - r: r, + fsys: fsys, + name: filename, + data: data, comma: comma, header: header, - bom: -1, } - defer func() { - if err != nil { - table.Close() - } - }() if schema == "" { var row []string if header || columns < 0 { - row, err = table.newReader().Read() + csv, close, err := table.newReader() + defer close.Close() + if err != nil { + return nil, err + } + row, err = csv.Read() if err != nil { return nil, err } @@ -118,20 +108,18 @@ func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) sqlite3.CreateModule(db, "csv", declare, declare) } -type table struct { - r io.ReaderAt - comma rune - header bool - bom int8 +type osfs struct{} + +func (osfs) Open(name string) (fs.File, error) { + return os.Open(name) } -func (t *table) Close() error { - if c, ok := t.r.(io.Closer); ok { - err := c.Close() - t.r = nil - return err - } - return nil +type table struct { + fsys fs.FS + name string + data string + comma rune + header bool } func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { @@ -147,38 +135,70 @@ func (t *table) Rename(new string) error { return nil } -func (t *table) Integrity(schema, table string, flags int) (err error) { - if flags&1 == 0 { - _, err = t.newReader().ReadAll() +func (t *table) Integrity(schema, table string, flags int) error { + if flags&1 != 0 { + return nil } + csv, close, err := t.newReader() + if err != nil { + return err + } + if close != nil { + defer close.Close() + } + _, err = csv.ReadAll() return err } -func (t *table) newReader() *csv.Reader { - if t.bom < 0 { - var bom [3]byte - t.r.ReadAt(bom[:], 0) - if string(bom[:]) == "\xEF\xBB\xBF" { - t.bom = 3 - } else { - t.bom = 0 +func (t *table) newReader() (*csv.Reader, io.Closer, error) { + var r io.Reader + var c io.Closer + if t.name != "" { + f, err := t.fsys.Open(t.name) + if err != nil { + return nil, f, err } + + buf := bufio.NewReader(f) + bom, err := buf.Peek(3) + if err != nil { + return nil, f, err + } + if string(bom) == "\xEF\xBB\xBF" { + buf.Discard(3) + } + + r = buf + c = f + } else { + r = strings.NewReader(t.data) + c = io.NopCloser(r) } - csv := csv.NewReader(io.NewSectionReader(t.r, int64(t.bom), math.MaxInt64)) + + csv := csv.NewReader(r) csv.ReuseRecord = true csv.Comma = t.comma - return csv + return csv, c, nil } type cursor struct { table *table + close io.Closer csv *csv.Reader row []string rowID int64 } +func (c *cursor) Close() error { + return c.close.Close() +} + func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { - c.csv = c.table.newReader() + var err error + c.csv, c.close, err = c.table.newReader() + if err != nil { + return err + } if c.table.header { c.Next() // skip header } diff --git a/ext/csv/csv_test.go b/ext/csv/csv_test.go index 5431203..7ecdab2 100644 --- a/ext/csv/csv_test.go +++ b/ext/csv/csv_test.go @@ -61,7 +61,7 @@ func TestRegister(t *testing.T) { csv.Register(db) - const data = "\xEF\xBB\xBF" + ` + const data = ` "Rob" "Pike" rob "Ken" Thompson ken Robert "Griesemer" "gri"` @@ -84,8 +84,8 @@ Robert "Griesemer" "gri"` if !stmt.Step() { t.Fatal("no rows") } - if got := stmt.ColumnText(1); got != "Pike" { - t.Errorf("got %q want Pike", got) + if got := stmt.ColumnText(0); got != "Rob" { + t.Errorf("got %q want Rob", got) } if stmt.Step() { t.Fatal("more rows") @@ -98,12 +98,17 @@ Robert "Griesemer" "gri"` err = db.Exec(`PRAGMA integrity_check`) if err != nil { - t.Fatal(err) + t.Error(err) + } + + err = db.Exec(`PRAGMA quick_check`) + if err != nil { + t.Error(err) } err = db.Exec(`DROP TABLE temp.csv`) if err != nil { - log.Fatal(err) + t.Error(err) } } diff --git a/ext/csv/testdata/eurofxref.csv b/ext/csv/testdata/eurofxref.csv index 053ce97..6706a93 100644 --- a/ext/csv/testdata/eurofxref.csv +++ b/ext/csv/testdata/eurofxref.csv @@ -1,4 +1,4 @@ -Date,USD,JPY,BGN,CYP,CZK,DKK,EEK,GBP,HUF,LTL,LVL,MTL,PLN,ROL,RON,SEK,SIT,SKK,CHF,ISK,NOK,HRK,RUB,TRL,TRY,AUD,BRL,CAD,CNY,HKD,IDR,ILS,INR,KRW,MXN,MYR,NZD,PHP,SGD,THB,ZAR, +Date,USD,JPY,BGN,CYP,CZK,DKK,EEK,GBP,HUF,LTL,LVL,MTL,PLN,ROL,RON,SEK,SIT,SKK,CHF,ISK,NOK,HRK,RUB,TRL,TRY,AUD,BRL,CAD,CNY,HKD,IDR,ILS,INR,KRW,MXN,MYR,NZD,PHP,SGD,THB,ZAR, 2022-12-30,1.0666,140.66,1.9558,N/A,24.116,7.4365,N/A,0.88693,400.87,N/A,N/A,N/A,4.6808,N/A,4.9495,11.1218,N/A,N/A,0.9847,151.5,10.5138,7.5365,N/A,N/A,19.9649,1.5693,5.6386,1.444,7.3582,8.3163,16519.82,3.7554,88.171,1344.09,20.856,4.6984,1.6798,59.32,1.43,36.835,18.0986, 2022-12-29,1.0649,142.24,1.9558,N/A,24.191,7.4365,N/A,0.88549,399.6,N/A,N/A,N/A,4.6855,N/A,4.9493,11.158,N/A,N/A,0.984,152.5,10.55,7.5365,N/A,N/A,19.934,1.5859,5.5351,1.4475,7.4151,8.2994,16680.38,3.7575,88.2295,1350.18,20.651,4.7106,1.6887,59.367,1.436,36.877,18.1967, 2022-12-28,1.064,142.21,1.9558,N/A,24.252,7.4365,N/A,0.88058,403.3,N/A,N/A,N/A,4.7008,N/A,4.946,11.1038,N/A,N/A,0.9863,151.9,10.4495,7.5365,N/A,N/A,19.9144,1.566,5.6109,1.4361,7.4224,8.2931,16765.93,3.7526,88.0943,1348.59,20.6856,4.7055,1.6772,59.613,1.4323,36.953,18.289, diff --git a/ext/fileio/fileio.go b/ext/fileio/fileio.go index bb7523a..5066011 100644 --- a/ext/fileio/fileio.go +++ b/ext/fileio/fileio.go @@ -20,7 +20,7 @@ func Register(db *sqlite3.Conn) { // Register registers SQL functions readfile, lsmode, // and the eponymous virtual table fsdir; -// fs will be used to read files and list directories. +// fsys will be used to read files and list directories. func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { db.CreateFunction("lsmode", 1, 0, lsmode) db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)) diff --git a/ext/lines/lines.go b/ext/lines/lines.go index 41a8578..d0a1435 100644 --- a/ext/lines/lines.go +++ b/ext/lines/lines.go @@ -15,7 +15,6 @@ import ( "bytes" "fmt" "io" - "math" "os" "github.com/ncruces/go-sqlite3" @@ -23,7 +22,7 @@ import ( // Register registers the lines and lines_read virtual tables. // The lines virtual table reads from a database blob or text. -// The lines_read virtual table reads from a file or an [io.ReaderAt]. +// The lines_read virtual table reads from a file or an [io.Reader]. func Register(db *sqlite3.Conn) { sqlite3.CreateModule[lines](db, "lines", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { @@ -57,81 +56,125 @@ func (l lines) BestIndex(idx *sqlite3.IndexInfo) error { } func (l lines) Open() (sqlite3.VTabCursor, error) { - return &cursor{reader: bool(l)}, nil + if l { + return &reader{}, nil + } else { + return &buffer{}, nil + } } type cursor struct { - scanner *bufio.Scanner - closer io.Closer - rowID int64 - eof bool - reader bool -} - -func (c *cursor) Close() (err error) { - if c.closer != nil { - err = c.closer.Close() - c.closer = nil - } - return err + line []byte + rowID int64 + eof bool } func (c *cursor) EOF() bool { return c.eof } -func (c *cursor) Next() error { - c.rowID++ - c.eof = !c.scanner.Scan() - return c.scanner.Err() -} - func (c *cursor) RowID() (int64, error) { return c.rowID, nil } func (c *cursor) Column(ctx *sqlite3.Context, n int) error { if n == 0 { - ctx.ResultRawText(c.scanner.Bytes()) + ctx.ResultRawText(c.line) } return nil } -func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { +type reader struct { + reader *bufio.Reader + closer io.Closer + cursor +} + +func (c *reader) Close() (err error) { + if c.closer != nil { + err = c.closer.Close() + c.closer = nil + } + return err +} + +func (c *reader) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { if err := c.Close(); err != nil { return err } var r io.Reader - data := arg[0] - typ := data.Type() - if c.reader { - switch typ { - case sqlite3.NULL: - if p, ok := data.Pointer().(io.ReaderAt); ok { - r = io.NewSectionReader(p, 0, math.MaxInt64) - } - case sqlite3.TEXT: - f, err := os.Open(data.Text()) - if err != nil { - return err - } - c.closer = f - r = f + typ := arg[0].Type() + switch typ { + case sqlite3.NULL: + if p, ok := arg[0].Pointer().(io.Reader); ok { + r = p } - } else { - switch typ { - case sqlite3.TEXT: - r = bytes.NewReader(data.RawText()) - case sqlite3.BLOB: - r = bytes.NewReader(data.RawBlob()) + case sqlite3.TEXT: + f, err := os.Open(arg[0].Text()) + if err != nil { + return err } + r = f } - if r == nil { return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ) } - c.scanner = bufio.NewScanner(r) + + c.reader = bufio.NewReader(r) + c.closer, _ = r.(io.Closer) c.rowID = 0 return c.Next() } + +func (c *reader) Next() (err error) { + c.line = c.line[:0] + for more := true; more; { + var line []byte + line, more, err = c.reader.ReadLine() + c.line = append(c.line, line...) + } + if err == io.EOF { + c.eof = true + err = nil + } + c.rowID++ + return err +} + +type buffer struct { + data []byte + cursor +} + +func (c *buffer) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { + typ := arg[0].Type() + switch typ { + case sqlite3.TEXT: + c.data = arg[0].RawText() + case sqlite3.BLOB: + c.data = arg[0].RawBlob() + default: + return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ) + } + + c.rowID = 0 + return c.Next() +} + +func (c *buffer) Next() error { + i := bytes.IndexByte(c.data, '\n') + j := i + 1 + switch { + case i < 0: + i = len(c.data) + j = i + case i > 0 && c.data[i-1] == '\r': + i-- + } + c.eof = len(c.data) == 0 + c.line = c.data[:i] + c.data = c.data[j:] + c.rowID++ + return nil +} diff --git a/ext/lines/lines_test.go b/ext/lines/lines_test.go index 310b037..1431717 100644 --- a/ext/lines/lines_test.go +++ b/ext/lines/lines_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "net/http" "os" "strings" "testing" @@ -25,12 +26,11 @@ func Example() { } defer db.Close() - // https://storage.googleapis.com/quickdraw_dataset/full/simplified/calendar.ndjson - f, err := os.Open("calendar.ndjson") + res, err := http.Get("https://storage.googleapis.com/quickdraw_dataset/full/simplified/calendar.ndjson") if err != nil { log.Fatal(err) } - defer f.Close() + defer res.Body.Close() rows, err := db.Query(` SELECT @@ -40,7 +40,7 @@ func Example() { GROUP BY 1 ORDER BY 2 DESC LIMIT 5`, - sqlite3.Pointer(f)) + sqlite3.Pointer(res.Body)) if err != nil { log.Fatal(err) } @@ -58,7 +58,7 @@ func Example() { if err := rows.Err(); err != nil { log.Fatal(err) } - // Sample output: + // Output: // US: 141001 // GB: 22560 // CA: 11759 @@ -78,7 +78,7 @@ func Test_lines(t *testing.T) { } defer db.Close() - const data = "line 1\nline 2\nline 3" + const data = "line 1\nline 2\r\nline 3\n" rows, err := db.Query(`SELECT rowid, line FROM lines(?)`, data) if err != nil { @@ -93,6 +93,9 @@ func Test_lines(t *testing.T) { if err != nil { t.Fatal(err) } + if want := fmt.Sprintf("line %d", id); line != want { + t.Errorf("got %q, want %q", line, want) + } } } @@ -135,7 +138,7 @@ func Test_lines_read(t *testing.T) { } defer db.Close() - const data = "line 1\nline 2\nline 3" + const data = "line 1\nline 2\r\nline 3\n" rows, err := db.Query(`SELECT rowid, line FROM lines_read(?)`, sqlite3.Pointer(strings.NewReader(data))) @@ -151,6 +154,9 @@ func Test_lines_read(t *testing.T) { if err != nil { t.Fatal(err) } + if want := fmt.Sprintf("line %d", id); line != want { + t.Errorf("got %q, want %q", line, want) + } } }