Reuse statement, API.

This commit is contained in:
Nuno Cruces
2023-12-04 12:37:53 +00:00
parent 8a0baedc10
commit cd40213898
12 changed files with 144 additions and 83 deletions

View File

@@ -172,9 +172,9 @@ func (t *table) newReader() *csv.Reader {
type cursor struct {
table *table
rowID int64
row []string
csv *csv.Reader
row []string
rowID int64
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {

View File

@@ -52,11 +52,11 @@ func (l lines) Open() (sqlite3.VTabCursor, error) {
}
type cursor struct {
reader bool
scanner *bufio.Scanner
closer io.Closer
rowID int64
eof bool
reader bool
}
func (c *cursor) Close() (err error) {

View File

@@ -26,12 +26,10 @@ func Register(db *sqlite3.Conn) {
sql = sql[1 : len-1]
}
table := &table{
db: db,
sql: sql,
}
err = table.declare()
table := &table{sql: sql}
err = table.declare(db)
if err != nil {
table.Close()
return nil, err
}
return table, nil
@@ -41,42 +39,40 @@ func Register(db *sqlite3.Conn) {
}
type table struct {
db *sqlite3.Conn
sql string
inputs int
outputs int
stmt *sqlite3.Stmt
sql string
inuse bool
}
func (t *table) declare() error {
stmt, tail, err := t.db.Prepare(t.sql)
func (t *table) declare(db *sqlite3.Conn) (err error) {
var tail string
t.stmt, tail, err = db.Prepare(t.sql)
if err != nil {
return err
}
defer stmt.Close()
if tail != "" {
return fmt.Errorf("statement: multiple statements")
}
if !stmt.ReadOnly() {
if !t.stmt.ReadOnly() {
return fmt.Errorf("statement: statement must be read only")
}
t.inputs = stmt.BindCount()
t.outputs = stmt.ColumnCount()
var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
for i := 0; i < t.outputs; i++ {
outputs := t.stmt.ColumnCount()
for i := 0; i < outputs; i++ {
str.WriteString(sep)
name := stmt.ColumnName(i)
name := t.stmt.ColumnName(i)
str.WriteString(sqlite3.QuoteIdentifier(name))
str.WriteByte(' ')
str.WriteString(stmt.ColumnDeclType(i))
str.WriteString(t.stmt.ColumnDeclType(i))
sep = ","
}
for i := 1; i <= t.inputs; i++ {
inputs := t.stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
name := t.stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
@@ -87,22 +83,24 @@ func (t *table) declare() error {
}
sep = ","
}
str.WriteByte(')')
return t.db.DeclareVtab(str.String())
return db.DeclareVtab(str.String())
}
func (t *table) Close() error {
return t.stmt.Close()
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.OrderByConsumed = false
idx.EstimatedCost = 1
idx.EstimatedRows = 1
idx.EstimatedCost = 1000
var argvIndex = 1
var needIndex bool
var listIndex []int
outputs := t.stmt.ColumnCount()
for i, cst := range idx.Constraint {
// Skip if this is a constraint on one of our output columns.
if cst.Column < t.outputs {
if cst.Column < outputs {
continue
}
@@ -114,7 +112,7 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
// The non-zero argvIdx values must be contiguous.
// If they're not, build a list and serialize it through IdxStr.
nextIndex := cst.Column - t.outputs + 1
nextIndex := cst.Column - outputs + 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
@@ -136,10 +134,15 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
stmt, _, err := t.db.Prepare(t.sql)
if err != nil {
return nil, err
func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
}
}
return &cursor{table: t, stmt: stmt}, nil
}
@@ -153,26 +156,29 @@ type cursor struct {
stmt *sqlite3.Stmt
arg []sqlite3.Value
rowID int64
done bool
}
func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
c.stmt.ClearBindings()
return c.stmt.Reset()
}
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.arg = arg
c.rowID = 0
if err := c.stmt.ClearBindings(); err != nil {
return err
}
c.stmt.ClearBindings()
if err := c.stmt.Reset(); err != nil {
return err
}
var list []int
if idxStr != "" {
err := json.Unmarshal([]byte(idxStr), &list)
buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr))
err := json.Unmarshal(buf, &list)
if err != nil {
return err
}
@@ -196,12 +202,11 @@ func (c *cursor) Next() error {
c.rowID++
return nil
}
c.done = true
return c.stmt.Err()
}
func (c *cursor) EOF() bool {
return c.done
return !c.stmt.Busy()
}
func (c *cursor) RowID() (int64, error) {
@@ -209,10 +214,11 @@ func (c *cursor) RowID() (int64, error) {
}
func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
if col < c.table.outputs {
switch outputs := c.stmt.ColumnCount(); {
case col < outputs:
ctx.ResultValue(c.stmt.ColumnValue(col))
} else if col-c.table.outputs < len(c.arg) {
ctx.ResultValue(c.arg[col-c.table.outputs])
case col-outputs < len(c.arg):
ctx.ResultValue(c.arg[col-outputs])
}
return nil
}