mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Refactor extensions.
This commit is contained in:
36
ext/csv/arg.go
Normal file
36
ext/csv/arg.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package csv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/util/vtabutil"
|
||||
)
|
||||
|
||||
func uintArg(key, val string) (int, error) {
|
||||
i, err := strconv.ParseUint(val, 10, 15)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
|
||||
}
|
||||
return int(i), nil
|
||||
}
|
||||
|
||||
func boolArg(key, val string) (bool, error) {
|
||||
if val == "" {
|
||||
return true, nil
|
||||
}
|
||||
b, ok := util.ParseBool(val)
|
||||
if ok {
|
||||
return b, nil
|
||||
}
|
||||
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
|
||||
}
|
||||
|
||||
func runeArg(key, val string) (rune, error) {
|
||||
r, _, tail, err := strconv.UnquoteChar(vtabutil.Unquote(val), 0)
|
||||
if tail != "" || err != nil {
|
||||
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
package csv
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
func Test_uintParam(t *testing.T) {
|
||||
"github.com/ncruces/go-sqlite3/util/vtabutil"
|
||||
)
|
||||
|
||||
func Test_uintArg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
@@ -20,22 +24,22 @@ func Test_uintParam(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.arg, func(t *testing.T) {
|
||||
key, val := getParam(tt.arg)
|
||||
key, val := vtabutil.NamedArg(tt.arg)
|
||||
if key != tt.key {
|
||||
t.Errorf("getParam() %v, want err %v", key, tt.key)
|
||||
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
|
||||
}
|
||||
got, err := uintParam(key, val)
|
||||
got, err := uintArg(key, val)
|
||||
if (err != nil) != tt.err {
|
||||
t.Fatalf("uintParam() error = %v, want err %v", err, tt.err)
|
||||
t.Fatalf("uintArg() error = %v, want err %v", err, tt.err)
|
||||
}
|
||||
if got != tt.val {
|
||||
t.Errorf("uintParam() = %v, want %v", got, tt.val)
|
||||
t.Errorf("uintArg() = %v, want %v", got, tt.val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_boolParam(t *testing.T) {
|
||||
func Test_boolArg(t *testing.T) {
|
||||
tests := []struct {
|
||||
arg string
|
||||
key string
|
||||
@@ -56,22 +60,22 @@ func Test_boolParam(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.arg, func(t *testing.T) {
|
||||
key, val := getParam(tt.arg)
|
||||
key, val := vtabutil.NamedArg(tt.arg)
|
||||
if key != tt.key {
|
||||
t.Errorf("getParam() %v, want err %v", key, tt.key)
|
||||
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
|
||||
}
|
||||
got, err := boolParam(key, val)
|
||||
got, err := boolArg(key, val)
|
||||
if (err != nil) != tt.err {
|
||||
t.Fatalf("boolParam() error = %v, want err %v", err, tt.err)
|
||||
t.Fatalf("boolArg() error = %v, want err %v", err, tt.err)
|
||||
}
|
||||
if got != tt.val {
|
||||
t.Errorf("boolParam() = %v, want %v", got, tt.val)
|
||||
t.Errorf("boolArg() = %v, want %v", got, tt.val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_runeParam(t *testing.T) {
|
||||
func Test_runeArg(t *testing.T) {
|
||||
tests := []struct {
|
||||
arg string
|
||||
key string
|
||||
@@ -88,16 +92,16 @@ func Test_runeParam(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.arg, func(t *testing.T) {
|
||||
key, val := getParam(tt.arg)
|
||||
key, val := vtabutil.NamedArg(tt.arg)
|
||||
if key != tt.key {
|
||||
t.Errorf("getParam() %v, want err %v", key, tt.key)
|
||||
t.Errorf("NamedArg() %v, want err %v", key, tt.key)
|
||||
}
|
||||
got, err := runeParam(key, val)
|
||||
got, err := runeArg(key, val)
|
||||
if (err != nil) != tt.err {
|
||||
t.Fatalf("runeParam() error = %v, want err %v", err, tt.err)
|
||||
t.Fatalf("runeArg() error = %v, want err %v", err, tt.err)
|
||||
}
|
||||
if got != tt.val {
|
||||
t.Errorf("runeParam() = %v, want %v", got, tt.val)
|
||||
t.Errorf("runeArg() = %v, want %v", got, tt.val)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -15,13 +15,14 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/util/fsutil"
|
||||
"github.com/ncruces/go-sqlite3/util/vtabutil"
|
||||
)
|
||||
|
||||
// Register registers the CSV virtual table.
|
||||
// If a filename is specified, [os.Open] is used to open the file.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
RegisterFS(db, util.OSFS{})
|
||||
RegisterFS(db, fsutil.OSFS{})
|
||||
}
|
||||
|
||||
// RegisterFS registers the CSV virtual table.
|
||||
@@ -40,23 +41,23 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
|
||||
)
|
||||
|
||||
for _, arg := range arg {
|
||||
key, val := getParam(arg)
|
||||
key, val := vtabutil.NamedArg(arg)
|
||||
if _, ok := done[key]; ok {
|
||||
return nil, fmt.Errorf("csv: more than one %q parameter", key)
|
||||
}
|
||||
switch key {
|
||||
case "filename":
|
||||
filename = unquoteParam(val)
|
||||
filename = vtabutil.Unquote(val)
|
||||
case "data":
|
||||
data = unquoteParam(val)
|
||||
data = vtabutil.Unquote(val)
|
||||
case "schema":
|
||||
schema = unquoteParam(val)
|
||||
schema = vtabutil.Unquote(val)
|
||||
case "header":
|
||||
header, err = boolParam(key, val)
|
||||
header, err = boolArg(key, val)
|
||||
case "columns":
|
||||
columns, err = uintParam(key, val)
|
||||
columns, err = uintArg(key, val)
|
||||
case "comma":
|
||||
comma, err = runeParam(key, val)
|
||||
comma, err = runeArg(key, val)
|
||||
default:
|
||||
return nil, fmt.Errorf("csv: unknown %q parameter", key)
|
||||
}
|
||||
@@ -81,8 +82,8 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
|
||||
if schema == "" {
|
||||
var row []string
|
||||
if header || columns < 0 {
|
||||
csv, close, err := table.newReader()
|
||||
defer close.Close()
|
||||
csv, c, err := table.newReader()
|
||||
defer c.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -133,13 +134,11 @@ func (t *table) Integrity(schema, table string, flags int) error {
|
||||
if flags&1 != 0 {
|
||||
return nil
|
||||
}
|
||||
csv, close, err := t.newReader()
|
||||
csv, c, err := t.newReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if close != nil {
|
||||
defer close.Close()
|
||||
}
|
||||
defer c.Close()
|
||||
_, err = csv.ReadAll()
|
||||
return err
|
||||
}
|
||||
@@ -176,20 +175,28 @@ func (t *table) newReader() (*csv.Reader, io.Closer, error) {
|
||||
}
|
||||
|
||||
type cursor struct {
|
||||
table *table
|
||||
close io.Closer
|
||||
csv *csv.Reader
|
||||
row []string
|
||||
rowID int64
|
||||
table *table
|
||||
closer io.Closer
|
||||
csv *csv.Reader
|
||||
row []string
|
||||
rowID int64
|
||||
}
|
||||
|
||||
func (c *cursor) Close() error {
|
||||
return c.close.Close()
|
||||
func (c *cursor) Close() (err error) {
|
||||
if c.closer != nil {
|
||||
err = c.closer.Close()
|
||||
c.closer = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
|
||||
var err error
|
||||
c.csv, c.close, err = c.table.newReader()
|
||||
err := c.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.csv, c.closer, err = c.table.newReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package csv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getParam(arg string) (key, val string) {
|
||||
key, val, _ = strings.Cut(arg, "=")
|
||||
key = strings.TrimSpace(key)
|
||||
val = strings.TrimSpace(val)
|
||||
return
|
||||
}
|
||||
|
||||
func uintParam(key, val string) (int, error) {
|
||||
i, err := strconv.ParseUint(val, 10, 15)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
|
||||
}
|
||||
return int(i), nil
|
||||
}
|
||||
|
||||
func boolParam(key, val string) (bool, error) {
|
||||
if val == "" || val == "1" ||
|
||||
strings.EqualFold(val, "true") ||
|
||||
strings.EqualFold(val, "yes") ||
|
||||
strings.EqualFold(val, "on") {
|
||||
return true, nil
|
||||
}
|
||||
if val == "0" ||
|
||||
strings.EqualFold(val, "false") ||
|
||||
strings.EqualFold(val, "no") ||
|
||||
strings.EqualFold(val, "off") {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
|
||||
}
|
||||
|
||||
func runeParam(key, val string) (rune, error) {
|
||||
r, _, tail, err := strconv.UnquoteChar(unquoteParam(val), 0)
|
||||
if tail != "" || err != nil {
|
||||
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func unquoteParam(val string) string {
|
||||
if len(val) < 2 {
|
||||
return val
|
||||
}
|
||||
if val[0] != val[len(val)-1] {
|
||||
return val
|
||||
}
|
||||
var old, new string
|
||||
switch val[0] {
|
||||
default:
|
||||
return val
|
||||
case '"':
|
||||
old, new = `""`, `"`
|
||||
case '\'':
|
||||
old, new = `''`, `'`
|
||||
}
|
||||
return strings.ReplaceAll(val[1:len(val)-1], old, new)
|
||||
}
|
||||
Reference in New Issue
Block a user