diff --git a/stmt_test.go b/stmt_test.go index 3acc319..54d7fa0 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -133,6 +133,7 @@ func TestStmt(t *testing.T) { if err != nil { t.Fatal(err) } + defer stmt.Close() if stmt.Step() { if got := stmt.ColumnType(0); got != INTEGER { diff --git a/tests/db_test.go b/tests/db_test.go index 103175c..0f31367 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -37,6 +37,7 @@ func testDB(t *testing.T, name string) { if err != nil { t.Fatal(err) } + defer stmt.Close() row := 0 ids := []int{0, 1, 2} diff --git a/tests/parallel_test.go b/tests/parallel_test.go index 1c4a4c9..57e2a73 100644 --- a/tests/parallel_test.go +++ b/tests/parallel_test.go @@ -14,7 +14,9 @@ import ( ) func TestParallel(t *testing.T) { - testParallel(t, t.TempDir(), 100) + name := filepath.Join(t.TempDir(), "test.db") + testParallel(t, name, 100) + testIntegrity(t, name) } func TestMultiProcess(t *testing.T) { @@ -22,10 +24,10 @@ func TestMultiProcess(t *testing.T) { return } - dir := t.TempDir() - t.Setenv("TestParallel_dir", dir) + name := filepath.Join(t.TempDir(), "test.db") + t.Setenv("TestMultiProcess_dbname", name) + cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess") - cmd.Stderr = os.Stderr out, err := cmd.StdoutPipe() if err != nil { t.Fatal(err) @@ -40,22 +42,25 @@ func TestMultiProcess(t *testing.T) { t.Fatal(err) } - testParallel(t, dir, 1000) - cmd.Wait() + testParallel(t, name, 1000) + if err := cmd.Wait(); err != nil { + t.Fatal(err) + } + testIntegrity(t, name) } func TestChildProcess(t *testing.T) { - dir := os.Getenv("TestParallel_dir") - if dir == "" || testing.Short() { + name := os.Getenv("TestMultiProcess_dbname") + if name == "" || testing.Short() { return } - testParallel(t, dir, 1000) + testParallel(t, name, 1000) } -func testParallel(t *testing.T, dir string, n int) { +func testParallel(t *testing.T, name string, n int) { writer := func() error { - db, err := sqlite3.Open(filepath.Join(dir, "test.db")) + db, err := sqlite3.Open(name) if err != nil { return err } @@ -83,7 +88,7 @@ func testParallel(t *testing.T, dir string, n int) { } reader := func() error { - db, err := sqlite3.Open(filepath.Join(dir, "test.db")) + db, err := sqlite3.Open(name) if err != nil { return err } @@ -101,6 +106,7 @@ func testParallel(t *testing.T, dir string, n int) { if err != nil { return err } + defer stmt.Close() row := 0 for stmt.Step() { @@ -140,3 +146,41 @@ func testParallel(t *testing.T, dir string, n int) { t.Fatal(err) } } + +func testIntegrity(t *testing.T, name string) { + db, err := sqlite3.Open(name) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + test := `PRAGMA integrity_check` + if testing.Short() { + test = `PRAGMA quick_check` + } + + stmt, _, err := db.Prepare(test) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + for stmt.Step() { + if row := stmt.ColumnText(0); row != "ok" { + t.Error(row) + } + } + if err := stmt.Err(); err != nil { + t.Fatal(err) + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + err = db.Close() + if err != nil { + t.Fatal(err) + } +}