From 3b4df71a945a92d7b214475eee5931c09fd72e2a Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 21 Feb 2023 04:30:24 +0000 Subject: [PATCH] Time handling. --- .github/workflows/go.yml | 2 +- error.go | 1 + stmt.go | 46 +++++- stmt_test.go | 61 ++++++++ time.go | 292 +++++++++++++++++++++++++++++++++++++++ time_test.go | 114 +++++++++++++++ 6 files changed, 514 insertions(+), 2 deletions(-) create mode 100644 time.go create mode 100644 time_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index a3ae712..fd52b95 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -36,5 +36,5 @@ jobs: uses: ncruces/go-coverage-report@main if: | matrix.os == 'ubuntu-latest' && - github.event_name == 'push' + github.event_name == 'push' continue-on-error: true diff --git a/error.go b/error.go index fb4eb1d..8372af7 100644 --- a/error.go +++ b/error.go @@ -72,6 +72,7 @@ const ( noNulErr = errorString("sqlite3: missing NUL terminator") noGlobalErr = errorString("sqlite3: could not find global: ") noFuncErr = errorString("sqlite3: could not find function: ") + timeErr = errorString("sqlite3: invalid time value") ) func assertErr() errorString { diff --git a/stmt.go b/stmt.go index 419c295..bfb534a 100644 --- a/stmt.go +++ b/stmt.go @@ -2,6 +2,7 @@ package sqlite3 import ( "math" + "time" ) // Stmt is a prepared statement object. @@ -234,6 +235,24 @@ func (s *Stmt) BindNull(param int) error { return s.c.error(r[0]) } +// BindTime binds a [time.Time] to the prepared statement. +// The leftmost SQL parameter has an index of 1. +// +// https://www.sqlite.org/c3ref/bind_blob.html +func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error { + switch v := format.Encode(value).(type) { + case string: + s.BindText(param, v) + case int64: + s.BindInt64(param, v) + case float64: + s.BindFloat(param, v) + default: + panic(assertErr()) + } + return nil +} + // ColumnCount returns the number of columns in a result set. // // https://www.sqlite.org/c3ref/column_count.html @@ -259,7 +278,7 @@ func (s *Stmt) ColumnName(col int) string { ptr := uint32(r[0]) if ptr == 0 { - return "" + panic(oomErr) } return s.c.mem.readString(ptr, _MAX_STRING) } @@ -325,6 +344,31 @@ func (s *Stmt) ColumnFloat(col int) float64 { return math.Float64frombits(r[0]) } +// ColumnTime returns the value of the result column as a [time.Time]. +// The leftmost column of the result set has the index 0. +// +// https://www.sqlite.org/c3ref/column_blob.html +func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time { + var v any + switch s.ColumnType(col) { + case INTEGER: + v = s.ColumnInt64(col) + case FLOAT: + v = s.ColumnFloat(col) + case TEXT, BLOB: + v = s.ColumnText(col) + case NULL: + return time.Time{} + default: + panic(assertErr()) + } + t, err := format.Decode(v) + if err != nil { + s.err = err + } + return t +} + // ColumnText returns the value of the result column as a string. // The leftmost column of the result set has the index 0. // diff --git a/stmt_test.go b/stmt_test.go index 8d8cb91..45c8946 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -3,6 +3,7 @@ package sqlite3 import ( "math" "testing" + "time" ) func TestStmt(t *testing.T) { @@ -398,3 +399,63 @@ func TestStmt_BindName(t *testing.T) { } } } + +func TestStmt_Time(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare(`SELECT ?, ?, ?, datetime(), unixepoch(), julianday(), NULL, 'abc'`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) + err = stmt.BindTime(1, reference, TimeFormat4) + if err != nil { + t.Fatal(err) + } + err = stmt.BindTime(2, reference, TimeFormatUnixMilli) + if err != nil { + t.Fatal(err) + } + err = stmt.BindTime(3, reference, TimeFormatJulianDay) + if err != nil { + t.Fatal(err) + } + + if now := time.Now(); stmt.Step() { + if got := stmt.ColumnTime(0, TimeFormatAuto); !reference.Equal(got) { + t.Errorf("got %v, want %v", got, reference) + } + if got := stmt.ColumnTime(1, TimeFormatAuto); !reference.Equal(got) { + t.Errorf("got %v, want %v", got, reference) + } + if got := stmt.ColumnTime(2, TimeFormatAuto); reference.Sub(got) > time.Millisecond { + t.Errorf("got %v, want %v", got, reference) + } + + if got := stmt.ColumnTime(3, TimeFormatAuto); now.Sub(got) > time.Second { + t.Errorf("got %v, want %v", got, now) + } + if got := stmt.ColumnTime(4, TimeFormatAuto); now.Sub(got) > time.Second { + t.Errorf("got %v, want %v", got, now) + } + if got := stmt.ColumnTime(5, TimeFormatAuto); now.Sub(got) > time.Millisecond { + t.Errorf("got %v, want %v", got, now) + } + + if got := stmt.ColumnTime(6, TimeFormatAuto); got != (time.Time{}) { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnTime(7, TimeFormatAuto); got != (time.Time{}) { + t.Errorf("got %v, want zero", got) + } + if stmt.Err() == nil { + t.Errorf("want error") + } + } +} diff --git a/time.go b/time.go new file mode 100644 index 0000000..8f12a9e --- /dev/null +++ b/time.go @@ -0,0 +1,292 @@ +package sqlite3 + +import ( + "math" + "strconv" + "strings" + "time" + + "github.com/ncruces/julianday" +) + +// TimeFormat specifies how to encode/decode time values. +// +// https://www.sqlite.org/lang_datefunc.html +type TimeFormat string + +const ( + TimeFormatDefault TimeFormat = "" // time.RFC3339Nano + + // Text formats + TimeFormat1 TimeFormat = "2006-01-02" + TimeFormat2 TimeFormat = "2006-01-02 15:04" + TimeFormat3 TimeFormat = "2006-01-02 15:04:05" + TimeFormat4 TimeFormat = "2006-01-02 15:04:05.000" + TimeFormat5 TimeFormat = "2006-01-02T15:04" + TimeFormat6 TimeFormat = "2006-01-02T15:04:05" + TimeFormat7 TimeFormat = "2006-01-02T15:04:05.000" + TimeFormat8 TimeFormat = "15:04" + TimeFormat9 TimeFormat = "15:04:05" + TimeFormat10 TimeFormat = "15:04:05.000" + + TimeFormat2TZ = TimeFormat2 + "Z07:00" + TimeFormat3TZ = TimeFormat3 + "Z07:00" + TimeFormat4TZ = TimeFormat4 + "Z07:00" + TimeFormat5TZ = TimeFormat5 + "Z07:00" + TimeFormat6TZ = TimeFormat6 + "Z07:00" + TimeFormat7TZ = TimeFormat7 + "Z07:00" + TimeFormat8TZ = TimeFormat8 + "Z07:00" + TimeFormat9TZ = TimeFormat9 + "Z07:00" + TimeFormat10TZ = TimeFormat10 + "Z07:00" + + // Numeric formats + TimeFormatJulianDay TimeFormat = "julianday" + TimeFormatUnix TimeFormat = "unixepoch" + TimeFormatUnixFrac TimeFormat = "unixepoch_frac" + TimeFormatUnixMilli TimeFormat = "unixepoch_milli" + TimeFormatUnixMicro TimeFormat = "unixepoch_micro" + TimeFormatUnixNano TimeFormat = "unixepoch_nano" + + // Auto + TimeFormatAuto TimeFormat = "auto" +) + +// Encode encodes a time value using this format. +// +// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano], +// preserving timezone, with nanosecond accuracy. +// +// https://www.sqlite.org/lang_datefunc.html +func (f TimeFormat) Encode(t time.Time) any { + switch f { + // Numeric formats + case TimeFormatJulianDay: + return julianday.Float(t) + case TimeFormatUnix: + return t.Unix() + case TimeFormatUnixFrac: + return float64(t.Unix()) + float64(t.Nanosecond())/1_000_000_000 + case TimeFormatUnixMilli: + return t.UnixMilli() + case TimeFormatUnixMicro: + return t.UnixMicro() + case TimeFormatUnixNano: + return t.UnixNano() + // Special formats + case TimeFormatDefault, TimeFormatAuto: + f = time.RFC3339Nano + } + // SQLite assumes UTC if unspecified. + if !strings.Contains(string(f), "Z07") && !strings.Contains(string(f), "-07") { + t = t.UTC() + } + return t.Format(string(f)) +} + +// Decode decodes a time value using this format. +// +// https://www.sqlite.org/lang_datefunc.html +func (f TimeFormat) Decode(v any) (time.Time, error) { + switch f { + // Numeric formats + case TimeFormatJulianDay: + switch v := v.(type) { + case string: + return julianday.Parse(v) + case float64: + return julianday.FloatTime(v), nil + case int64: + return julianday.Time(v, 0), nil + default: + return time.Time{}, timeErr + } + + case TimeFormatUnix, TimeFormatUnixFrac: + if s, ok := v.(string); ok { + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return time.Time{}, err + } + v = f + } + switch v := v.(type) { + case float64: + sec, frac := math.Modf(v) + nsec := math.Floor(frac * 1_000_000_000) + return time.Unix(int64(sec), int64(nsec)), nil + case int64: + return time.Unix(v, 0), nil + default: + return time.Time{}, timeErr + } + + case TimeFormatUnixMilli: + if s, ok := v.(string); ok { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return time.Time{}, err + } + v = i + } + switch v := v.(type) { + case float64: + return time.UnixMilli(int64(v)), nil + case int64: + return time.UnixMilli(int64(v)), nil + default: + return time.Time{}, timeErr + } + + case TimeFormatUnixMicro: + if s, ok := v.(string); ok { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return time.Time{}, err + } + v = i + } + switch v := v.(type) { + case float64: + return time.UnixMicro(int64(v)), nil + case int64: + return time.UnixMicro(int64(v)), nil + default: + return time.Time{}, timeErr + } + + case TimeFormatUnixNano: + if s, ok := v.(string); ok { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return time.Time{}, timeErr + } + v = i + } + switch v := v.(type) { + case float64: + return time.Unix(0, int64(v)), nil + case int64: + return time.Unix(0, int64(v)), nil + default: + return time.Time{}, timeErr + } + + // Special formats + case TimeFormatAuto: + switch s := v.(type) { + case string: + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + v = i + break + } + f, err := strconv.ParseFloat(s, 64) + if err == nil { + v = f + break + } + + dates := []TimeFormat{ + TimeFormat6TZ, TimeFormat6, TimeFormat3TZ, TimeFormat3, + TimeFormat5TZ, TimeFormat5, TimeFormat2TZ, TimeFormat2, + TimeFormat1, + } + for _, f := range dates { + t, err := time.Parse(string(f), s) + if err == nil { + return t, nil + } + } + + times := []TimeFormat{ + TimeFormat9TZ, TimeFormat9, TimeFormat8TZ, TimeFormat8, + } + for _, f := range times { + t, err := time.Parse(string(f), s) + if err == nil { + return t.AddDate(2000, 0, 0), nil + } + } + } + switch v := v.(type) { + case float64: + if 0 <= v && v < 5373484.5 { + return TimeFormatJulianDay.Decode(v) + } + if v < 253402300800 { + return TimeFormatUnixFrac.Decode(v) + } + if v < 253402300800_000 { + return TimeFormatUnixMilli.Decode(v) + } + if v < 253402300800_000000 { + return TimeFormatUnixMicro.Decode(v) + } + return TimeFormatUnixNano.Decode(v) + case int64: + if 0 <= v && v < 5373485 { + return TimeFormatJulianDay.Decode(v) + } + if v < 253402300800 { + return TimeFormatUnixFrac.Decode(v) + } + if v < 253402300800_000 { + return TimeFormatUnixMilli.Decode(v) + } + if v < 253402300800_000000 { + return TimeFormatUnixMicro.Decode(v) + } + return TimeFormatUnixNano.Decode(v) + default: + return time.Time{}, timeErr + } + + case + TimeFormat2, TimeFormat2TZ, + TimeFormat3, TimeFormat3TZ, + TimeFormat4, TimeFormat4TZ, + TimeFormat5, TimeFormat5TZ, + TimeFormat6, TimeFormat6TZ, + TimeFormat7, TimeFormat7TZ: + s, ok := v.(string) + if !ok { + return time.Time{}, timeErr + } + f := string(f) + f = strings.TrimSuffix(f, "Z07:00") + f = strings.TrimSuffix(f, ".000") + t, err := time.Parse(f+"Z07:00", s) + if err != nil { + t, err = time.Parse(f, s) + } + return t, err + + case + TimeFormat8, TimeFormat8TZ, + TimeFormat9, TimeFormat9TZ, + TimeFormat10, TimeFormat10TZ: + s, ok := v.(string) + if !ok { + return time.Time{}, timeErr + } + f := string(f) + f = strings.TrimSuffix(f, "Z07:00") + f = strings.TrimSuffix(f, ".000") + t, err := time.Parse(f+"Z07:00", s) + if err != nil { + t, err = time.Parse(f, s) + } + return t.AddDate(2000, 0, 0), err + + default: + s, ok := v.(string) + if !ok { + return time.Time{}, timeErr + } + f := string(f) + if f == "" { + f = time.RFC3339Nano + } + return time.Parse(f, s) + } +} diff --git a/time_test.go b/time_test.go new file mode 100644 index 0000000..af19cc5 --- /dev/null +++ b/time_test.go @@ -0,0 +1,114 @@ +package sqlite3 + +import ( + "reflect" + "testing" + "time" +) + +func TestTimeFormat_Encode(t *testing.T) { + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) + + tests := []struct { + fmt TimeFormat + time time.Time + want any + }{ + {TimeFormatDefault, reference, "2013-10-07T04:23:19.12-04:00"}, + {TimeFormatJulianDay, reference, 2456572.849526851851852}, + {TimeFormatUnix, reference, int64(1381134199)}, + {TimeFormatUnixFrac, reference, 1381134199.120}, + {TimeFormatUnixMilli, reference, int64(1381134199_120)}, + {TimeFormatUnixMicro, reference, int64(1381134199_120000)}, + {TimeFormatUnixNano, reference, int64(1381134199_120000000)}, + {TimeFormat7, reference, "2013-10-07T08:23:19.120"}, + } + for _, tt := range tests { + t.Run("", func(t *testing.T) { + if got := tt.fmt.Encode(tt.time); !reflect.DeepEqual(got, tt.want) { + t.Errorf("%q.Encode(%v) = %v, want %v", tt.fmt, tt.time, got, tt.want) + } + }) + } +} + +func TestTimeFormat_Decode(t *testing.T) { + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) + reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) + + tests := []struct { + fmt TimeFormat + val any + want time.Time + wantDelta time.Duration + wantErr bool + }{ + {TimeFormatJulianDay, "2456572.849526851851852", reference, 0, false}, + {TimeFormatJulianDay, 2456572.849526851851852, reference, time.Millisecond, false}, + {TimeFormatJulianDay, int64(2456572), reference, 24 * time.Hour, false}, + {TimeFormatJulianDay, false, time.Time{}, 0, true}, + + {TimeFormatUnix, "1381134199.120", reference, time.Microsecond, false}, + {TimeFormatUnix, 1381134199.120, reference, time.Microsecond, false}, + {TimeFormatUnix, int64(1381134199), reference, time.Second, false}, + {TimeFormatUnix, "abc", time.Time{}, 0, true}, + {TimeFormatUnix, false, time.Time{}, 0, true}, + + {TimeFormatUnixMilli, "1381134199120", reference, 0, false}, + {TimeFormatUnixMilli, 1381134199.120e3, reference, 0, false}, + {TimeFormatUnixMilli, int64(1381134199_120), reference, 0, false}, + {TimeFormatUnixMilli, "abc", time.Time{}, 0, true}, + {TimeFormatUnixMilli, false, time.Time{}, 0, true}, + + {TimeFormatUnixMicro, "1381134199120000", reference, 0, false}, + {TimeFormatUnixMicro, 1381134199.120e6, reference, 0, false}, + {TimeFormatUnixMicro, int64(1381134199_120000), reference, 0, false}, + {TimeFormatUnixMicro, "abc", time.Time{}, 0, true}, + {TimeFormatUnixMicro, false, time.Time{}, 0, true}, + + {TimeFormatUnixNano, "1381134199120000000", reference, 0, false}, + {TimeFormatUnixNano, 1381134199.120e9, reference, 0, false}, + {TimeFormatUnixNano, int64(1381134199_120000000), reference, 0, false}, + {TimeFormatUnixNano, "abc", time.Time{}, 0, true}, + {TimeFormatUnixNano, false, time.Time{}, 0, true}, + + {TimeFormatAuto, "2456572.849526851851852", reference, time.Millisecond, false}, + {TimeFormatAuto, "2456572", reference, 24 * time.Hour, false}, + {TimeFormatAuto, "1381134199.120", reference, time.Microsecond, false}, + {TimeFormatAuto, "1381134199.120e3", reference, time.Microsecond, false}, + {TimeFormatAuto, "1381134199.120e6", reference, time.Microsecond, false}, + {TimeFormatAuto, "1381134199.120e9", reference, time.Microsecond, false}, + {TimeFormatAuto, "1381134199", reference, time.Second, false}, + {TimeFormatAuto, "1381134199120", reference, 0, false}, + {TimeFormatAuto, "1381134199120000", reference, 0, false}, + {TimeFormatAuto, "1381134199120000000", reference, 0, false}, + {TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false}, + {TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false}, + {TimeFormatAuto, "abc", time.Time{}, 0, true}, + {TimeFormatAuto, false, time.Time{}, 0, true}, + + {TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false}, + {TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false}, + {TimeFormat9, "04:23:19.12-04:00", reftime, 0, false}, + {TimeFormat9, "08:23:19.12", reftime, 0, false}, + {TimeFormat3, false, time.Time{}, 0, true}, + {TimeFormat9, false, time.Time{}, 0, true}, + + {TimeFormatDefault, "2013-10-07T04:23:19.12-04:00", reference, 0, false}, + {TimeFormatDefault, "2013-10-07T08:23:19.12Z", reference, 0, false}, + {TimeFormatDefault, false, time.Time{}, 0, true}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + got, err := tt.fmt.Decode(tt.val) + if (err != nil) != tt.wantErr { + t.Errorf("%q.Decode(%v) error = %v, wantErr %v", tt.fmt, tt.val, err, tt.wantErr) + return + } + if tt.want.Sub(got).Abs() > tt.wantDelta { + t.Errorf("%q.Decode(%v) = %v, want %v", tt.fmt, tt.val, got, tt.want) + } + }) + } +}