diff --git a/ext/zorder/zorder.go b/ext/zorder/zorder.go index 627b632..f95c605 100644 --- a/ext/zorder/zorder.go +++ b/ext/zorder/zorder.go @@ -19,9 +19,9 @@ func Register(db *sqlite3.Conn) error { } func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) { - var x [63]int64 - if len(arg) > len(x) { - ctx.ResultError(util.ErrorString("zorder: too many parameters")) + var x [24]int64 + if n := len(arg); n < 2 || n > 24 { + ctx.ResultError(util.ErrorString("zorder: needs between 2 and 24 dimensions")) return } for i := range arg { @@ -29,17 +29,15 @@ func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) { } var z int64 - if len(arg) > 0 { - for i := range x { - j := i % len(arg) - z |= (x[j] & 1) << i - x[j] >>= 1 - } + for i := range 63 { + j := i % len(arg) + z |= (x[j] & 1) << i + x[j] >>= 1 } for i := range arg { if x[i] != 0 { - ctx.ResultError(util.ErrorString("zorder: parameter too large")) + ctx.ResultError(util.ErrorString("zorder: argument out of range")) return } } @@ -51,6 +49,19 @@ func unzorder(ctx sqlite3.Context, arg ...sqlite3.Value) { n := arg[1].Int64() z := arg[0].Int64() + if n < 2 || n > 24 { + ctx.ResultError(util.ErrorString("unzorder: needs between 2 and 24 dimensions")) + return + } + if i < 0 || i >= n { + ctx.ResultError(util.ErrorString("unzorder: index out of range")) + return + } + if z < 0 { + ctx.ResultError(util.ErrorString("unzorder: argument out of range")) + return + } + var k int var x int64 for j := i; j < 63; j += n { diff --git a/ext/zorder/zorder_test.go b/ext/zorder/zorder_test.go index 4dbad49..01537b1 100644 --- a/ext/zorder/zorder_test.go +++ b/ext/zorder/zorder_test.go @@ -12,7 +12,7 @@ import ( "github.com/ncruces/go-sqlite3/vfs/memdb" ) -func TestRegister_zorder(t *testing.T) { +func Test_zorder(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) @@ -57,7 +57,7 @@ func TestRegister_zorder(t *testing.T) { } } -func TestRegister_unzorder(t *testing.T) { +func Test_unzorder(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) @@ -85,7 +85,7 @@ func TestRegister_unzorder(t *testing.T) { } } -func TestRegister_error(t *testing.T) { +func Test_zorder_error(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) @@ -103,7 +103,7 @@ func TestRegister_error(t *testing.T) { var buf strings.Builder buf.WriteString("SELECT zorder(0") - for i := 1; i < 80; i++ { + for i := 1; i < 25; i++ { buf.WriteByte(',') buf.WriteString(strconv.Itoa(0)) } @@ -113,3 +113,30 @@ func TestRegister_error(t *testing.T) { t.Error("want error") } } + +func Test_unzorder_error(t *testing.T) { + t.Parallel() + tmp := memdb.TestDB(t) + + db, err := driver.Open(tmp, zorder.Register) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var got int64 + err = db.QueryRow(`SELECT unzorder(-1, 2, 0)`).Scan(&got) + if err == nil { + t.Error("want error") + } + + err = db.QueryRow(`SELECT unzorder(0, 2, 2)`).Scan(&got) + if err == nil { + t.Error("want error") + } + + err = db.QueryRow(`SELECT unzorder(0, 25, 2)`).Scan(&got) + if err == nil { + t.Error("want error") + } +}