From 06eaf41c4f9b70e91c7f9bc580eb4f6cab79587a Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 16 Sep 2024 12:05:37 +0100 Subject: [PATCH] Fix #151. --- driver/driver.go | 28 +++++++++++++++++++++++++++- tests/driver_test.go | 14 ++++++++++---- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index c6c758a..25fadca 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -578,6 +578,7 @@ type rows struct { *stmt names []string types []string + nulls []bool } func (r *rows) Close() error { @@ -596,6 +597,22 @@ func (r *rows) Columns() []string { return r.names } +func (r *rows) loadTypes() { + if r.nulls == nil { + count := r.Stmt.ColumnCount() + r.nulls = make([]bool, count) + r.types = make([]string, count) + for i := range r.nulls { + if col := r.Stmt.ColumnOriginName(i); col != "" { + r.types[i], _, r.nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata( + r.Stmt.ColumnDatabaseName(i), + r.Stmt.ColumnTableName(i), + col) + } + } + } +} + func (r *rows) declType(index int) string { if r.types == nil { count := r.Stmt.ColumnCount() @@ -608,7 +625,8 @@ func (r *rows) declType(index int) string { } func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - decltype := r.declType(index) + r.loadTypes() + decltype := r.types[index] if len := len(decltype); len > 0 && decltype[len-1] == ')' { if i := strings.LastIndexByte(decltype, '('); i >= 0 { decltype = decltype[:i] @@ -617,6 +635,14 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { return strings.TrimSpace(decltype) } +func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { + r.loadTypes() + if r.nulls[index] { + return false, true + } + return true, false +} + func (r *rows) Next(dest []driver.Value) error { old := r.Stmt.Conn().SetInterrupt(r.ctx) defer r.Stmt.Conn().SetInterrupt(old) diff --git a/tests/driver_test.go b/tests/driver_test.go index 777159f..54944a0 100644 --- a/tests/driver_test.go +++ b/tests/driver_test.go @@ -33,7 +33,7 @@ func TestDriver(t *testing.T) { defer conn.Close() res, err := conn.ExecContext(ctx, - `CREATE TABLE users (id INT, name VARCHAR(10))`) + `CREATE TABLE users (id INTEGER PRIMARY KEY NOT NULL, name VARCHAR(10))`) if err != nil { t.Fatal(err) } @@ -82,11 +82,17 @@ func TestDriver(t *testing.T) { if err != nil { t.Fatal(err) } - if got := typs[0].DatabaseTypeName(); got != "INT" { - t.Errorf("got %s, want INT", got) + if got := typs[0].DatabaseTypeName(); got != "INTEGER" { + t.Errorf("got %s, want INTEGER", got) } if got := typs[1].DatabaseTypeName(); got != "VARCHAR" { - t.Errorf("got %s, want INT", got) + t.Errorf("got %s, want VARCHAR", got) + } + if got, ok := typs[0].Nullable(); got || !ok { + t.Errorf("got %v/%v, want false/true", got, ok) + } + if got, ok := typs[1].Nullable(); !got || ok { + t.Errorf("got %v/%v, want true/false", got, ok) } row := 0