From ca93c498e7e899405b8eeeb9d0064f8c2c48a954 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 4 Dec 2025 15:27:16 +0000 Subject: [PATCH] Relative time, fixes. --- litestream/vfs.go | 30 +++++----- litestream/vfs_test.go | 19 +++++++ util/sql3util/arg.go | 117 +++++++++++++++++++++++++++++++++++++- util/sql3util/arg_test.go | 41 +++++++++++++ 4 files changed, 191 insertions(+), 16 deletions(-) diff --git a/litestream/vfs.go b/litestream/vfs.go index 55feeed..94c06fa 100644 --- a/litestream/vfs.go +++ b/litestream/vfs.go @@ -15,6 +15,7 @@ import ( "github.com/superfly/ltx" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/util/sql3util" "github.com/ncruces/go-sqlite3/util/vfsutil" "github.com/ncruces/go-sqlite3/vfs" "github.com/ncruces/wbt" @@ -212,7 +213,7 @@ func (f *liteFile) Pragma(name, value string) (string, error) { return syncTime.Format(time.RFC3339Nano), nil } - if !f.locked { + if f.locked { return "", sqlite3.MISUSE } @@ -222,12 +223,17 @@ func (f *liteFile) Pragma(name, value string) (string, error) { return "", nil } - syncTime, err := sqlite3.TimeFormatAuto.Decode(value) - if err != nil { - return "", err + var syncTime time.Time + if years, months, days, duration, ok := sql3util.ParseTimeShift(value); ok { + syncTime = time.Now().AddDate(years, months, days).Add(duration) + } else { + syncTime, _ = sqlite3.TimeFormatAuto.Decode(value) + } + if syncTime.IsZero() { + return "", sqlite3.MISUSE } - err = f.buildIndex(f.context(), syncTime) + err := f.buildIndex(f.context(), syncTime) if err != nil { f.db.opts.Logger.Error("build index", "error", err) } @@ -251,11 +257,8 @@ func (f *liteFile) context() context.Context { func (f *liteFile) buildIndex(ctx context.Context, syncTime time.Time) error { // Build the index from scratch from a Litestream restore plan. infos, err := litestream.CalcRestorePlan(ctx, f.db.client, 0, syncTime, f.db.opts.Logger) - if err != nil { - if !errors.Is(err, litestream.ErrTxNotAvailable) { - return fmt.Errorf("calc restore plan: %w", err) - } - return nil + if err != nil && !errors.Is(err, litestream.ErrTxNotAvailable) { + return fmt.Errorf("calc restore plan: %w", err) } var txid ltx.TXID @@ -295,11 +298,8 @@ func (d *liteDB) buildIndex(ctx context.Context) error { // Build the index from scratch from a Litestream restore plan. infos, err := litestream.CalcRestorePlan(ctx, d.client, 0, time.Time{}, d.opts.Logger) - if err != nil { - if !errors.Is(err, litestream.ErrTxNotAvailable) { - return fmt.Errorf("calc restore plan: %w", err) - } - return nil + if err != nil && !errors.Is(err, litestream.ErrTxNotAvailable) { + return fmt.Errorf("calc restore plan: %w", err) } for _, info := range infos { diff --git a/litestream/vfs_test.go b/litestream/vfs_test.go index dd150e3..7102bf8 100644 --- a/litestream/vfs_test.go +++ b/litestream/vfs_test.go @@ -93,6 +93,25 @@ func Test_integration(t *testing.T) { if txid != "0000000000000001" { t.Errorf("got %q", txid) } + + _, err = replica.ExecContext(t.Context(), `PRAGMA litestream_time='00:01'`) + if err != nil { + t.Fatal(err) + } + + _, err = replica.ExecContext(t.Context(), `PRAGMA litestream_time='1970-01-01'`) + if err != nil { + t.Fatal(err) + } + + var sync time.Time + err = replica.QueryRowContext(t.Context(), `PRAGMA litestream_time`).Scan(&sync) + if err != nil { + t.Fatal(err) + } + if !sync.Equal(time.Unix(0, 0)) { + t.Errorf("got %v", sync) + } } func setupPrimary(tb testing.TB, path string, client ReplicaClient) error { diff --git a/util/sql3util/arg.go b/util/sql3util/arg.go index 3e8c728..ddb8aa8 100644 --- a/util/sql3util/arg.go +++ b/util/sql3util/arg.go @@ -1,6 +1,9 @@ package sql3util -import "strings" +import ( + "strings" + "time" +) // NamedArg splits an named arg into a key and value, // around an equals sign. @@ -63,3 +66,115 @@ func ParseBool(s string) (b, ok bool) { } return false, false } + +// ParseTimeShift parses a time shift modifier, +// also the output of timediff. +// +// https://sqlite.org/lang_datefunc.html +func ParseTimeShift(s string) (years, months, days int, duration time.Duration, ok bool) { + // Sign part: ± + neg := strings.HasPrefix(s, "-") + sign := neg || strings.HasPrefix(s, "+") + if sign { + s = s[1:] + } + + if ok = len(s) >= 5; !ok { + return // !ok + } + + defer func() { + if neg { + years = -years + months = -months + days = -days + duration = -duration + } + }() + + // Date part: YYYY-MM-DD + if s[4] == '-' { + if ok = sign && len(s) >= 10 && s[7] == '-'; !ok { + return // !ok + } + if years, ok = parseInt(s[0:4]); !ok { + return // !ok + } + if months, ok = parseInt(s[5:7]); !ok { + return // !ok + } + if days, ok = parseInt(s[8:10]); !ok { + return // !ok + } + if len(s) == 10 { + return + } + if ok = s[10] == ' '; !ok { + return // !ok + } + s = s[11:] + } + + // Time part: HH:MM + if ok = len(s) >= 5 && s[2] == ':'; !ok { + return // !ok + } + + var hours, minutes int + if hours, ok = parseInt(s[0:2]); !ok { + return + } + if minutes, ok = parseInt(s[3:5]); !ok { + return + } + duration = time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + + if len(s) == 5 { + return + } + if ok = len(s) >= 8 && s[5] == ':'; !ok { + return // !ok + } + + // Seconds part: HH:MM:SS + var seconds int + if seconds, ok = parseInt(s[6:8]); !ok { + return + } + duration += time.Duration(seconds) * time.Second + + if len(s) == 8 { + return + } + if ok = len(s) >= 10 && s[8] == '.'; !ok { + return // !ok + } + s = s[9:] + + // Nanosecond part: HH:MM:SS.SSS + var nanos int + if nanos, ok = parseInt(s[0:min(9, len(s))]); !ok { + return + } + for i := len(s); i < 9; i++ { + nanos *= 10 + } + duration += time.Duration(nanos) + + // Subnanosecond part. + if len(s) > 9 { + _, ok = parseInt(s[9:]) + } + return +} + +func parseInt(s string) (i int, _ bool) { + for _, r := range []byte(s) { + r -= '0' + if r > 9 { + return + } + i = i*10 + int(r) + } + return i, true +} diff --git a/util/sql3util/arg_test.go b/util/sql3util/arg_test.go index fcd34c0..a110cbc 100644 --- a/util/sql3util/arg_test.go +++ b/util/sql3util/arg_test.go @@ -2,6 +2,7 @@ package sql3util_test import ( "testing" + "time" "github.com/ncruces/go-sqlite3/util/sql3util" ) @@ -53,3 +54,43 @@ func TestParseBool(t *testing.T) { }) } } + +func TestParseTimeShift(t *testing.T) { + epoch := time.Unix(0, 0) + tests := []struct { + str string + val time.Time + ok bool + }{ + {"", epoch, false}, + {"0001-12-30", epoch, false}, + {"+_001-12-30", epoch, false}, + {"+0001-_2-30", epoch.AddDate(1, 0, 0), false}, + {"+0001-12-_0", epoch.AddDate(1, 12, 0), false}, + {"+0001-12-30", epoch.AddDate(1, 12, 30), true}, + {"-0001-12-30", epoch.AddDate(-1, -12, -30), true}, + {"+0001-12-30T", epoch.AddDate(1, 12, 30), false}, + {"+0001-12-30 12", epoch.AddDate(1, 12, 30), false}, + {"+0001-12-30 _2:30", epoch.AddDate(1, 12, 30), false}, + {"+0001-12-30 12:_0", epoch.AddDate(1, 12, 30), false}, + {"+0001-12-30 12:30", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 30*time.Minute), true}, + {"+0001-12-30 12:30:", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 30*time.Minute), false}, + {"+0001-12-30 12:30:_0", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 30*time.Minute), false}, + {"+0001-12-30 12:30:60", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 31*time.Minute), true}, + {"+0001-12-30 12:30:60.", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 31*time.Minute), false}, + {"+0001-12-30 12:30:60._", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 31*time.Minute), false}, + {"+0001-12-30 12:30:60.1", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 31*time.Minute + 100*time.Millisecond), true}, + {"+0001-12-30 12:30:60.123456789_", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 31*time.Minute + 123456789), false}, + {"+0001-12-30 12:30:60.1234567890", epoch.AddDate(1, 12, 30).Add(12*time.Hour + 31*time.Minute + 123456789), true}, + {"-12:30:60.1234567890", epoch.Add(-12*time.Hour - 31*time.Minute - 123456789), true}, + } + for _, tt := range tests { + t.Run(tt.str, func(t *testing.T) { + years, months, days, duration, gotOK := sql3util.ParseTimeShift(tt.str) + gotVal := epoch.AddDate(years, months, days).Add(duration) + if !gotVal.Equal(tt.val) || gotOK != tt.ok { + t.Errorf("ParseTimeShift(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok) + } + }) + } +}