From 3f8b480ba0ed92b9c17797761666efc16a313ea6 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 14 Dec 2023 14:02:41 +0000 Subject: [PATCH] Optimize declared types. --- driver/driver.go | 50 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index eb29891..efd7503 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -377,7 +377,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv if err != nil { return nil, err } - return &rows{s, ctx}, nil + return &rows{ctx: ctx, stmt: s}, nil } func (s *stmt) setupBindings(args []driver.NamedValue) error { @@ -479,8 +479,10 @@ func (r resultRowsAffected) RowsAffected() (int64, error) { } type rows struct { - *stmt ctx context.Context + *stmt + names []string + types []string } func (r *rows) Close() error { @@ -489,22 +491,35 @@ func (r *rows) Close() error { } func (r *rows) Columns() []string { - count := r.Stmt.ColumnCount() - columns := make([]string, count) - for i := range columns { - columns[i] = r.Stmt.ColumnName(i) + if r.names == nil { + count := r.Stmt.ColumnCount() + r.names = make([]string, count) + for i := range r.names { + r.names[i] = r.Stmt.ColumnName(i) + } } - return columns + return r.names +} + +func (r *rows) declType(index int) string { + if r.types == nil { + count := r.Stmt.ColumnCount() + r.types = make([]string, count) + for i := range r.types { + r.types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i)) + } + } + return r.types[index] } func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - decltype := r.Stmt.ColumnDeclType(index) + decltype := r.declType(index) if len := len(decltype); len > 0 && decltype[len-1] == ')' { if i := strings.LastIndexByte(decltype, '('); i >= 0 { decltype = decltype[:i] } } - return strings.ToUpper(strings.TrimSpace(decltype)) + return strings.TrimSpace(decltype) } func (r *rows) Next(dest []driver.Value) error { @@ -519,11 +534,12 @@ func (r *rows) Next(dest []driver.Value) error { } for i := range dest { - if t, ok := r.decodeTime(i); ok { - dest[i] = t + t := r.Stmt.ColumnType(i) + if tm, ok := r.decodeTime(i, t); ok { + dest[i] = tm continue } - switch r.Stmt.ColumnType(i) { + switch t { case sqlite3.INTEGER: dest[i] = r.Stmt.ColumnInt64(i) case sqlite3.FLOAT: @@ -542,21 +558,21 @@ func (r *rows) Next(dest []driver.Value) error { return r.Stmt.Err() } -func (s *stmt) decodeTime(i int) (_ time.Time, _ bool) { - if s.tmRead == sqlite3.TimeFormatDefault { +func (r *rows) decodeTime(i int, typ sqlite3.Datatype) (_ time.Time, _ bool) { + if r.tmRead == sqlite3.TimeFormatDefault { return } - switch s.Stmt.ColumnType(i) { + switch typ { case sqlite3.INTEGER, sqlite3.FLOAT, sqlite3.TEXT: // maybe default: return } - switch strings.ToUpper(s.Stmt.ColumnDeclType(i)) { + switch r.declType(i) { case "DATE", "TIME", "DATETIME", "TIMESTAMP": // maybe default: return } - return s.Stmt.ColumnTime(i, s.tmRead), s.Stmt.Err() == nil + return r.Stmt.ColumnTime(i, r.tmRead), r.Stmt.Err() == nil }