diff --git a/conn.go b/conn.go index cf7b46f..5dfaf9e 100644 --- a/conn.go +++ b/conn.go @@ -66,32 +66,36 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { type connKey struct{} -func newConn(ctx context.Context, filename string, flags OpenFlag) (conn *Conn, err error) { - err = ctx.Err() +func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) { + err := ctx.Err() if err != nil { return nil, err } - sqlite, err := instantiateSQLite() + + c := &Conn{interrupt: ctx} + c.sqlite, err = instantiateSQLite() if err != nil { return nil, err } defer func() { - if err != nil { - conn.Close() - conn = nil - sqlite.close() + if res == nil { + c.Close() + c.sqlite.close() + } else { + c.interrupt = context.Background() } }() - c := &Conn{sqlite: sqlite, interrupt: ctx} c.ctx = context.WithValue(c.ctx, connKey{}, c) c.arena = c.newArena(1024) c.handle, err = c.openDB(filename, flags) if err == nil { err = initExtensions(c) } - c.interrupt = context.Background() - return c, err + if err != nil { + return nil, err + } + return c, nil } func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { diff --git a/driver/driver.go b/driver/driver.go index f2568e8..88c4c50 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -202,7 +202,7 @@ func (n *connector) Driver() driver.Driver { return n.driver } -func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { +func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { c := &conn{ txLock: n.txLock, tmRead: n.tmRead, @@ -214,8 +214,8 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { return nil, err } defer func() { - if err != nil { - c.Conn.Close() + if res == nil { + c.Close() } }() @@ -239,6 +239,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { if err != nil { return nil, err } + defer s.Close() if s.Step() && s.ColumnBool(0) { c.readOnly = '1' } else { diff --git a/ext/bloom/bloom.go b/ext/bloom/bloom.go index ec8bbd5..e807bf9 100644 --- a/ext/bloom/bloom.go +++ b/ext/bloom/bloom.go @@ -34,7 +34,7 @@ type bloom struct { } func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err error) { - t := bloom{ + b := bloom{ db: db, schema: schema, storage: table + "_storage", @@ -54,30 +54,30 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, } if len(arg) > 1 { - t.prob, err = strconv.ParseFloat(arg[1], 64) + b.prob, err = strconv.ParseFloat(arg[1], 64) if err != nil { return nil, err } - if t.prob <= 0 || t.prob >= 1 { + if b.prob <= 0 || b.prob >= 1 { return nil, util.ErrorString("bloom: probability must be in the range (0,1)") } } else { - t.prob = 0.01 + b.prob = 0.01 } if len(arg) > 2 { - t.hashes, err = strconv.Atoi(arg[2]) + b.hashes, err = strconv.Atoi(arg[2]) if err != nil { return nil, err } - if t.hashes <= 0 { + if b.hashes <= 0 { return nil, util.ErrorString("bloom: number of hash functions must be positive") } } else { - t.hashes = max(1, numHashes(t.prob)) + b.hashes = max(1, numHashes(b.prob)) } - t.bytes = numBytes(nelem, t.prob) + b.bytes = numBytes(nelem, b.prob) err = db.DeclareVTab( `CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`) @@ -87,7 +87,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err = db.Exec(fmt.Sprintf( `CREATE TABLE %s.%s (data BLOB, p REAL, n INTEGER, m INTEGER, k INTEGER)`, - sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage))) + sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage))) if err != nil { return nil, err } @@ -98,17 +98,17 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err = db.Exec(fmt.Sprintf( `INSERT INTO %s.%s (rowid, data, p, n, m, k) VALUES (1, zeroblob(%d), %f, %d, %d, %d)`, - sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage), - t.bytes, t.prob, nelem, 8*t.bytes, t.hashes)) + sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage), + b.bytes, b.prob, nelem, 8*b.bytes, b.hashes)) if err != nil { - t.Destroy() + b.Destroy() return nil, err } - return &t, nil + return &b, nil } func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err error) { - t := bloom{ + b := bloom{ db: db, schema: schema, storage: table + "_storage", @@ -122,7 +122,7 @@ func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom load, _, err := db.Prepare(fmt.Sprintf( `SELECT m/8, p, k FROM %s.%s WHERE rowid = 1`, - sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage))) + sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage))) if err != nil { return nil, err } @@ -135,10 +135,10 @@ func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom return nil, sqlite3.CORRUPT_VTAB } - t.bytes = load.ColumnInt64(0) - t.prob = load.ColumnFloat(1) - t.hashes = load.ColumnInt(2) - return &t, nil + b.bytes = load.ColumnInt64(0) + b.prob = load.ColumnFloat(1) + b.hashes = load.ColumnInt(2) + return &b, nil } func (b *bloom) Destroy() error { diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 097380e..83feaf9 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -30,7 +30,7 @@ func Register(db *sqlite3.Conn) error { // RegisterFS registers the CSV virtual table. // If a filename is specified, fsys is used to open the file. func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { - declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { + declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) { var ( filename string data string @@ -76,7 +76,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { return nil, util.ErrorString(`csv: must specify either "filename" or "data" but not both`) } - table := &table{ + t := &table{ fsys: fsys, name: filename, data: data, @@ -88,7 +88,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { if schema == "" { var row []string if header || columns < 0 { - csv, c, err := table.newReader() + csv, c, err := t.newReader() defer c.Close() if err != nil { return nil, err @@ -100,22 +100,20 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { } schema = getSchema(header, columns, row) } else { - defer func() { - if err == nil { - table.typs, err = getColumnAffinities(schema) - } - }() + t.typs, err = getColumnAffinities(schema) + if err != nil { + return nil, err + } } err = db.DeclareVTab(schema) + if err == nil { + err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + } if err != nil { return nil, err } - err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY) - if err != nil { - return nil, err - } - return table, nil + return t, nil } return sqlite3.CreateModule(db, "csv", declare, declare) diff --git a/ext/fileio/fileio.go b/ext/fileio/fileio.go index c46b2b9..234abee 100644 --- a/ext/fileio/fileio.go +++ b/ext/fileio/fileio.go @@ -31,7 +31,9 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode), sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) { err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`) - db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + if err == nil { + err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + } return fsdir{fsys}, err })) } diff --git a/ext/lines/lines.go b/ext/lines/lines.go index ac5abd1..b38bed7 100644 --- a/ext/lines/lines.go +++ b/ext/lines/lines.go @@ -39,13 +39,17 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { sqlite3.CreateModule(db, "lines", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`) - db.VTabConfig(sqlite3.VTAB_INNOCUOUS) + if err == nil { + err = db.VTabConfig(sqlite3.VTAB_INNOCUOUS) + } return lines{}, err }), sqlite3.CreateModule(db, "lines_read", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`) - db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + if err == nil { + err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + } return lines{fsys}, err })) } diff --git a/ext/pivot/pivot.go b/ext/pivot/pivot.go index 3c35f8b..015c2d2 100644 --- a/ext/pivot/pivot.go +++ b/ext/pivot/pivot.go @@ -25,15 +25,15 @@ type table struct { cols []*sqlite3.Value } -func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { +func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) { if len(arg) != 3 { return nil, fmt.Errorf("pivot: wrong number of arguments") } - table := &table{db: db} + t := &table{db: db} defer func() { - if err != nil { - table.Close() + if res == nil { + t.Close() } }() @@ -42,17 +42,17 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err create.WriteString("CREATE TABLE x(") // Row key query. - table.scan = "SELECT * FROM\n" + arg[0] - stmt, _, err := db.Prepare(table.scan) + t.scan = "SELECT * FROM\n" + arg[0] + stmt, _, err := db.Prepare(t.scan) if err != nil { return nil, err } defer stmt.Close() - table.keys = make([]string, stmt.ColumnCount()) - for i := range table.keys { + t.keys = make([]string, stmt.ColumnCount()) + for i := range t.keys { name := sqlite3.QuoteIdentifier(stmt.ColumnName(i)) - table.keys[i] = name + t.keys[i] = name create.WriteString(sep) create.WriteString(name) sep = "," @@ -70,15 +70,15 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err } for stmt.Step() { name := sqlite3.QuoteIdentifier(stmt.ColumnText(1)) - table.cols = append(table.cols, stmt.ColumnValue(0).Dup()) + t.cols = append(t.cols, stmt.ColumnValue(0).Dup()) create.WriteString(",") create.WriteString(name) } stmt.Close() // Pivot cell query. - table.cell = "SELECT * FROM\n" + arg[2] - stmt, _, err = db.Prepare(table.cell) + t.cell = "SELECT * FROM\n" + arg[2] + stmt, _, err = db.Prepare(t.cell) if err != nil { return nil, err } @@ -86,8 +86,8 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err if stmt.ColumnCount() != 1 { return nil, util.ErrorString("pivot: cell query expects 1 result columns") } - if stmt.BindCount() != len(table.keys)+1 { - return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1) + if stmt.BindCount() != len(t.keys)+1 { + return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(t.keys)+1) } create.WriteByte(')') @@ -95,7 +95,7 @@ func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err err if err != nil { return nil, err } - return table, nil + return t, nil } func (t *table) Close() error { diff --git a/go.work.sum b/go.work.sum index 085f015..52265b5 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,10 +1,13 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=