diff --git a/ext/bloom/bloom.go b/ext/bloom/bloom.go index b71f90a..b203ada 100644 --- a/ext/bloom/bloom.go +++ b/ext/bloom/bloom.go @@ -16,6 +16,7 @@ import ( "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/sql3util" ) // Register registers the bloom_filter virtual table: @@ -55,11 +56,9 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, } if len(arg) > 1 { - b.prob, err = strconv.ParseFloat(arg[1], 64) - if err != nil { - return nil, err - } - if b.prob <= 0 || b.prob >= 1 { + var ok bool + b.prob, ok = sql3util.ParseFloat(arg[1]) + if !ok || b.prob <= 0 || b.prob >= 1 { return nil, util.ErrorString("bloom: probability must be in the range (0,1)") } } else { diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 2ea4871..14e9b89 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -254,19 +254,15 @@ func (c *cursor) Column(ctx sqlite3.Context, col int) error { switch typ { case numeric, integer: - if strings.TrimLeft(txt, "+-0123456789") == "" { - if i, err := strconv.ParseInt(txt, 10, 64); err == nil { - ctx.ResultInt64(i) - return nil - } + if i, err := strconv.ParseInt(txt, 10, 64); err == nil { + ctx.ResultInt64(i) + return nil } fallthrough case real: - if strings.TrimLeft(txt, "+-.0123456789Ee") == "" { - if f, err := strconv.ParseFloat(txt, 64); err == nil { - ctx.ResultFloat(f) - return nil - } + if f, ok := sql3util.ParseFloat(txt); ok { + ctx.ResultFloat(f) + return nil } fallthrough default: diff --git a/litestream/README.md b/litestream/README.md index 4c68253..ec46c85 100644 --- a/litestream/README.md +++ b/litestream/README.md @@ -4,3 +4,8 @@ This package implements the **EXPERIMENTAL** `"litestream"` SQLite VFS that offers Litestream [lightweight read-replicas](https://fly.io/blog/litestream-revamped/#lightweight-read-replicas). See the [example](example_test.go) for how to use. + +Our `PRAGMA litestream_time` accepts: +- Go [duration strings](https://pkg.go.dev/time#ParseDuration) +- SQLite [time values](https://sqlite.org/lang_datefunc.html#time_values) +- SQLite [time modifiers 1 through 13](https://sqlite.org/lang_datefunc.html#modifiers) diff --git a/litestream/time.go b/litestream/time.go new file mode 100644 index 0000000..8328cb9 --- /dev/null +++ b/litestream/time.go @@ -0,0 +1,63 @@ +package litestream + +import ( + "math" + "strings" + "time" + + "github.com/ncruces/go-sqlite3/util/sql3util" +) + +func parseTimeDelta(s string) (years, months, days int, duration time.Duration, ok bool) { + duration, err := time.ParseDuration(s) + if err == nil { + return 0, 0, 0, duration, true + } + + if strings.EqualFold(s, "now") { + return 0, 0, 0, 0, true + } + + ss := strings.TrimSuffix(strings.ToLower(s), "s") + switch { + case strings.HasSuffix(ss, " year"): + years, duration, ok = parseDateUnit(ss, " year", 365*86400) + + case strings.HasSuffix(ss, " month"): + months, duration, ok = parseDateUnit(ss, " month", 30*86400) + + case strings.HasSuffix(ss, " day"): + months, duration, ok = parseDateUnit(ss, " day", 86400) + + case strings.HasSuffix(ss, " hour"): + duration, ok = parseTimeUnit(ss, " hour", time.Hour) + + case strings.HasSuffix(ss, " minute"): + duration, ok = parseTimeUnit(ss, " minute", time.Minute) + + case strings.HasSuffix(ss, " second"): + duration, ok = parseTimeUnit(ss, " second", time.Second) + + default: + return sql3util.ParseTimeShift(s) + } + return +} + +func parseDateUnit(s, unit string, seconds float64) (int, time.Duration, bool) { + f, ok := sql3util.ParseFloat(s[:len(s)-len(unit)]) + if !ok { + return 0, 0, false + } + + i, f := math.Modf(f) + if math.MinInt <= i && i <= math.MaxInt { + return int(i), time.Duration(f * seconds * float64(time.Second)), true + } + return 0, 0, false +} + +func parseTimeUnit(s, unit string, scale time.Duration) (time.Duration, bool) { + f, ok := sql3util.ParseFloat(s[:len(s)-len(unit)]) + return time.Duration(f * float64(scale)), ok +} diff --git a/litestream/vfs.go b/litestream/vfs.go index 94c06fa..a116dbf 100644 --- a/litestream/vfs.go +++ b/litestream/vfs.go @@ -15,7 +15,6 @@ import ( "github.com/superfly/ltx" "github.com/ncruces/go-sqlite3" - "github.com/ncruces/go-sqlite3/util/sql3util" "github.com/ncruces/go-sqlite3/util/vfsutil" "github.com/ncruces/go-sqlite3/vfs" "github.com/ncruces/wbt" @@ -224,7 +223,7 @@ func (f *liteFile) Pragma(name, value string) (string, error) { } var syncTime time.Time - if years, months, days, duration, ok := sql3util.ParseTimeShift(value); ok { + if years, months, days, duration, ok := parseTimeDelta(value); ok { syncTime = time.Now().AddDate(years, months, days).Add(duration) } else { syncTime, _ = sqlite3.TimeFormatAuto.Decode(value) diff --git a/litestream/vfs_test.go b/litestream/vfs_test.go index dd2dd77..1a47a7e 100644 --- a/litestream/vfs_test.go +++ b/litestream/vfs_test.go @@ -91,7 +91,17 @@ func Test_integration(t *testing.T) { t.Errorf("got %q", txid) } - _, err = replica.ExecContext(t.Context(), `PRAGMA litestream_time='00:01'`) + _, err = replica.ExecContext(t.Context(), `PRAGMA litestream_time='-1.5h'`) + if err != nil { + t.Fatal(err) + } + + _, err = replica.ExecContext(t.Context(), `PRAGMA litestream_time='-00:01'`) + if err != nil { + t.Fatal(err) + } + + _, err = replica.ExecContext(t.Context(), `PRAGMA litestream_time='-2.5 years'`) if err != nil { t.Fatal(err) } diff --git a/time.go b/time.go index 19bcd2b..280c766 100644 --- a/time.go +++ b/time.go @@ -7,6 +7,7 @@ import ( "time" "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/sql3util" "github.com/ncruces/julianday" ) @@ -157,11 +158,13 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { case TimeFormatUnix, TimeFormatUnixFrac: if s, ok := v.(string); ok { - f, err := strconv.ParseFloat(s, 64) - if err != nil { - return time.Time{}, err + if i, err := strconv.ParseInt(s, 10, 64); err == nil { + v = i + } else if f, ok := sql3util.ParseFloat(s); ok { + v = f + } else { + return time.Time{}, util.TimeErr } - v = f } switch v := v.(type) { case float64: @@ -234,8 +237,8 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { v = i break } - f, err := strconv.ParseFloat(s, 64) - if err == nil { + f, ok := sql3util.ParseFloat(s) + if ok { v = f break } diff --git a/util/sql3util/arg.go b/util/sql3util/arg.go index 68fa8ab..6995891 100644 --- a/util/sql3util/arg.go +++ b/util/sql3util/arg.go @@ -1,6 +1,7 @@ package sql3util import ( + "strconv" "strings" "time" ) @@ -67,6 +68,15 @@ func ParseBool(s string) (b, ok bool) { return false, false } +// ParseFloat parses a decimal floating point number. +func ParseFloat(s string) (f float64, ok bool) { + if strings.TrimLeft(s, "+-.0123456789Ee") != "" { + return + } + f, err := strconv.ParseFloat(s, 64) + return f, err == nil +} + // ParseTimeShift parses a time shift modifier, // also the output of timediff. //