diff --git a/ext/bloom/bloom.go b/ext/bloom/bloom.go index 174aef1..1d1c32e 100644 --- a/ext/bloom/bloom.go +++ b/ext/bloom/bloom.go @@ -30,7 +30,7 @@ type bloom struct { schema string storage string prob float64 - nfilter int64 + bytes int64 hashes int } @@ -41,15 +41,17 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, storage: table + "_storage", } - nelem := 100 + var nelem int64 if len(arg) > 0 { - nelem, err = strconv.Atoi(arg[0]) + nelem, err = strconv.ParseInt(arg[0], 10, 64) if err != nil { return nil, err } if nelem <= 0 { return nil, errors.New("bloom: number of elements in filter must be positive") } + } else { + nelem = 100 } if len(arg) > 1 { @@ -73,10 +75,10 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, return nil, errors.New("bloom: number of hash functions must be positive") } } else { - t.hashes = int(math.Round(-math.Log2(t.prob))) + t.hashes = max(1, numHashes(t.prob)) } - t.nfilter = computeLength(nelem, t.prob) + t.bytes = numBytes(nelem, t.prob) err = db.Exec(fmt.Sprintf( `CREATE TABLE %s.%s (data BLOB, p REAL, n INTEGER, m INTEGER, k INTEGER)`, @@ -89,7 +91,7 @@ func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, `INSERT INTO %s.%s (rowid, data, p, n, m, k) VALUES (1, zeroblob(%d), %f, %d, %d, %d)`, sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage), - t.nfilter, t.prob, nelem, t.nfilter*8, t.hashes)) + t.bytes, t.prob, nelem, 8*t.bytes, t.hashes)) if err != nil { return nil, err } @@ -131,7 +133,7 @@ func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom return nil, err } - t.nfilter = load.ColumnInt64(0) + t.bytes = load.ColumnInt64(0) t.prob = load.ColumnFloat(1) t.hashes = load.ColumnInt(2) return &t, nil @@ -188,7 +190,7 @@ func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) { for n := 0; n < b.hashes; n++ { hash := calcHash(n, blob) - hash %= uint64(b.nfilter * 8) + hash %= uint64(b.bytes * 8) bitpos := byte(hash % 8) bytepos := int64(hash / 8) @@ -202,7 +204,7 @@ func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) { return 0, err } - buf[0] |= (1 << bitpos) + buf[0] |= 1 << bitpos _, err = f.Seek(bytepos, io.SeekStart) if err != nil { @@ -241,9 +243,9 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { } defer f.Close() - for n := 0; n < c.hashes; n++ { + for n := 0; n < c.hashes && !c.eof; n++ { hash := calcHash(n, blob) - hash %= uint64(c.nfilter * 8) + hash %= uint64(c.bytes * 8) bitpos := byte(hash % 8) bytepos := int64(hash / 8) @@ -257,10 +259,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { return err } - c.eof = (buf[0] & (1 << bitpos)) == 0 - if c.eof { - break - } + c.eof = buf[0]&(1<