Files
sqlite3/ext/csv/csv.go

212 lines
4.1 KiB
Go
Raw Normal View History

2023-11-23 03:28:56 +00:00
// Package csv provides a CSV virtual table.
//
// The CSV virtual table reads RFC 4180 formatted comma-separated values,
// and returns that content as if it were rows and columns of an SQL table.
//
// https://sqlite.org/csv.html
package csv
import (
"encoding/csv"
"fmt"
"io"
"math"
2023-11-23 09:54:18 +00:00
"os"
2023-11-23 03:28:56 +00:00
"strings"
"github.com/ncruces/go-sqlite3"
)
// Register registers the CSV virtual table.
2023-11-23 09:54:18 +00:00
// If a filename is specified, `os.Open` is used to read it from disk.
func Register(db *sqlite3.Conn) {
RegisterOpen(db, func(name string) (io.ReaderAt, error) {
return os.Open(name)
})
}
// RegisterOpen registers the CSV virtual table.
2023-11-23 03:28:56 +00:00
// If a filename is specified, open is used to open the file.
2023-11-23 09:54:18 +00:00
func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) {
2023-11-29 00:46:27 +00:00
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
2023-11-23 03:28:56 +00:00
var (
filename string
data string
schema string
header bool
columns int = -1
comma rune = ','
2023-11-23 09:54:18 +00:00
done = map[string]struct{}{}
2023-11-23 03:28:56 +00:00
)
2023-11-29 00:46:27 +00:00
for _, arg := range arg {
2023-11-23 03:28:56 +00:00
key, val := getParam(arg)
if _, ok := done[key]; ok {
return nil, fmt.Errorf("csv: more than one %q parameter", key)
}
switch key {
case "filename":
filename = unquoteParam(val)
case "data":
data = unquoteParam(val)
case "schema":
schema = unquoteParam(val)
case "header":
header, err = boolParam(key, val)
case "columns":
columns, err = uintParam(key, val)
case "comma":
comma, err = runeParam(key, val)
default:
return nil, fmt.Errorf("csv: unknown %q parameter", key)
}
if err != nil {
return nil, err
}
done[key] = struct{}{}
}
if (filename == "") == (data == "") {
return nil, fmt.Errorf(`csv: must specify either "filename" or "data" but not both`)
}
var r io.ReaderAt
if filename != "" {
r, err = open(filename)
} else {
r = strings.NewReader(data)
}
if err != nil {
return nil, err
}
table := &table{
r: r,
comma: comma,
header: header,
2023-11-27 23:35:43 +00:00
bom: -1,
2023-11-23 03:28:56 +00:00
}
2023-11-23 09:54:18 +00:00
defer func() {
if err != nil {
table.Close()
}
}()
2023-11-23 03:28:56 +00:00
2023-11-29 00:46:27 +00:00
if schema == "" {
var row []string
if header || columns < 0 {
row, err = table.newReader().Read()
if err != nil {
return nil, err
}
2023-11-23 03:28:56 +00:00
}
schema = getSchema(header, columns, row)
}
err = db.DeclareVtab(schema)
2023-11-23 09:54:18 +00:00
if err != nil {
return nil, err
}
err = db.VtabConfig(sqlite3.VTAB_DIRECTONLY)
if err != nil {
return nil, err
}
return table, nil
2023-11-23 03:28:56 +00:00
}
sqlite3.CreateModule(db, "csv", declare, declare)
}
type table struct {
r io.ReaderAt
comma rune
header bool
2023-11-27 23:35:43 +00:00
bom int8
2023-11-23 03:28:56 +00:00
}
func (t *table) Close() error {
if c, ok := t.r.(io.Closer); ok {
err := c.Close()
t.r = nil
return err
}
return nil
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.EstimatedCost = 1e6
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
return &cursor{table: t}, nil
}
2023-11-23 09:54:18 +00:00
func (t *table) Rename(new string) error {
return nil
}
func (t *table) Integrity(schema, table string, flags int) (err error) {
if flags&1 == 0 {
_, err = t.newReader().ReadAll()
}
return err
}
2023-11-23 03:28:56 +00:00
func (t *table) newReader() *csv.Reader {
2023-11-27 23:35:43 +00:00
if t.bom < 0 {
var bom [3]byte
t.r.ReadAt(bom[:], 0)
if string(bom[:]) == "\xEF\xBB\xBF" {
t.bom = 3
} else {
t.bom = 0
}
}
csv := csv.NewReader(io.NewSectionReader(t.r, int64(t.bom), math.MaxInt64))
2023-11-23 03:28:56 +00:00
csv.ReuseRecord = true
csv.Comma = t.comma
return csv
}
type cursor struct {
table *table
rowID int64
row []string
csv *csv.Reader
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.csv = c.table.newReader()
if c.table.header {
c.Next() // skip header
}
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() (err error) {
c.rowID++
c.row, err = c.csv.Read()
if err != io.EOF {
return err
}
return nil
}
func (c *cursor) EOF() bool {
return c.row == nil
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
if col < len(c.row) {
ctx.ResultText(c.row[col])
}
return nil
}