diff --git a/config.go b/config.go index 17166b9..d9ce9f5 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strconv" + "sync/atomic" "github.com/tetratelabs/wazero/api" @@ -48,6 +49,15 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { return util.ReadBool(c.mod, argsPtr), c.error(rc) } +var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)] + +// ConfigLog sets up the default error logging callback for new connections. +// +// https://sqlite.org/errlog.html +func ConfigLog(cb func(code ExtendedErrorCode, msg string)) { + defaultLogger.Store(&cb) +} + // ConfigLog sets up the error logging callback for the connection. // // https://sqlite.org/errlog.html diff --git a/conn.go b/conn.go index 9f9251e..8814d5a 100644 --- a/conn.go +++ b/conn.go @@ -92,6 +92,9 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ }() c.ctx = context.WithValue(c.ctx, connKey{}, c) + if logger := defaultLogger.Load(); logger != nil { + c.ConfigLog(*logger) + } c.arena = c.newArena() c.handle, err = c.openDB(filename, flags) if err == nil { diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index 15129a9..6cdb682 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -1,7 +1,6 @@ package tests import ( - "errors" "fmt" "io" "log" @@ -24,15 +23,17 @@ import ( func TestMain(m *testing.M) { sqlite3.Initialize() - sqlite3.AutoExtension(func(c *sqlite3.Conn) error { - return c.ConfigLog(func(code sqlite3.ExtendedErrorCode, msg string) { - // Having to do journal recovery is unexpected. - if errors.Is(code, sqlite3.NOTICE) { - log.Panicf("%v (%d): %s", code, code, msg) - } else { - log.Printf("%v (%d): %s", code, code, msg) - } - }) + sqlite3.ConfigLog(func(code sqlite3.ExtendedErrorCode, msg string) { + switch code { + case sqlite3.NOTICE_RECOVER_WAL: + // Wal "recovery" is expected. + break + case sqlite3.NOTICE_RECOVER_ROLLBACK: + // Rollback journal recovery is an error. + log.Panicf("%v (%d): %s", code, code, msg) + default: + log.Printf("%v (%d): %s", code, code, msg) + } }) m.Run() } @@ -68,7 +69,7 @@ func Test_wal(t *testing.T) { if testing.Short() { iter = 1000 } else { - iter = 2500 + iter = 5000 } name := "file:" +