diff --git a/sqlite3/libc/build.sh b/sqlite3/libc/build.sh index 222fb5e..fd4661b 100755 --- a/sqlite3/libc/build.sh +++ b/sqlite3/libc/build.sh @@ -37,6 +37,7 @@ EOF -Wl,--export=strcspn \ -Wl,--export=strlen \ -Wl,--export=strncmp \ + -Wl,--export=strrchr \ -Wl,--export=strspn \ -Wl,--export=qsort diff --git a/sqlite3/libc/libc.wasm b/sqlite3/libc/libc.wasm index 42dbc39..1efb229 100755 Binary files a/sqlite3/libc/libc.wasm and b/sqlite3/libc/libc.wasm differ diff --git a/sqlite3/libc/libc.wat b/sqlite3/libc/libc.wat index b11ea9f..e8e5d8f 100644 --- a/sqlite3/libc/libc.wat +++ b/sqlite3/libc/libc.wat @@ -16,6 +16,7 @@ (export "strncmp" (func $strncmp)) (export "strchrnul" (func $strchrnul)) (export "strchr" (func $strchr)) + (export "strrchr" (func $strrchr)) (export "strspn" (func $strspn)) (export "strcspn" (func $strcspn)) (export "qsort" (func $qsort)) @@ -910,6 +911,126 @@ ) ) ) + (func $strrchr (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 v128) + (local $4 v128) + (block $block1 (result i32) + (block $block + (br_if $block + (i32.lt_u + (local.tee $2 + (i32.add + (call $strlen + (local.get $0) + ) + (i32.const 1) + ) + ) + (i32.const 16) + ) + ) + (local.set $3 + (i8x16.splat + (local.get $1) + ) + ) + (loop $label + (if + (i32.eqz + (v128.any_true + (local.tee $4 + (i8x16.eq + (v128.load align=1 + (i32.sub + (i32.add + (local.get $0) + (local.get $2) + ) + (i32.const 16) + ) + ) + (local.get $3) + ) + ) + ) + ) + (then + (br_if $label + (i32.gt_u + (local.tee $2 + (i32.sub + (local.get $2) + (i32.const 16) + ) + ) + (i32.const 15) + ) + ) + (br $block) + ) + ) + ) + (br $block1 + (i32.add + (i32.add + (i32.sub + (local.get $0) + (i32.clz + (i8x16.bitmask + (local.get $4) + ) + ) + ) + (local.get $2) + ) + (i32.const 15) + ) + ) + ) + (local.set $0 + (i32.add + (local.get $0) + (local.get $2) + ) + ) + (local.set $1 + (i32.extend8_s + (local.get $1) + ) + ) + (loop $label1 + (drop + (br_if $block1 + (i32.const 0) + (i32.eqz + (local.get $2) + ) + ) + ) + (local.set $2 + (i32.sub + (local.get $2) + (i32.const 1) + ) + ) + (br_if $label1 + (i32.ne + (local.get $1) + (i32.load8_s + (local.tee $0 + (i32.sub + (local.get $0) + (i32.const 1) + ) + ) + ) + ) + ) + ) + (local.get $0) + ) + ) (func $strspn (param $0 i32) (param $1 i32) (result i32) (local $2 i32) (local $3 i32) diff --git a/sqlite3/libc/libc_test.go b/sqlite3/libc/libc_test.go index 4366e5d..b8b856c 100644 --- a/sqlite3/libc/libc_test.go +++ b/sqlite3/libc/libc_test.go @@ -31,6 +31,7 @@ var ( strchr api.Function strcmp api.Function strspn api.Function + strrchr api.Function strncmp api.Function strcspn api.Function stack [8]uint64 @@ -63,6 +64,7 @@ func TestMain(m *testing.M) { strchr = mod.ExportedFunction("strchr") strcmp = mod.ExportedFunction("strcmp") strspn = mod.ExportedFunction("strspn") + strrchr = mod.ExportedFunction("strrchr") strncmp = mod.ExportedFunction("strncmp") strcspn = mod.ExportedFunction("strcspn") memory, _ = mod.Memory().Read(0, mod.Memory().Size()) @@ -139,6 +141,18 @@ func Benchmark_strchr(b *testing.B) { } } +func Benchmark_strrchr(b *testing.B) { + clear(memory) + fill(memory[ptr1:ptr1+size/2], 5) + fill(memory[ptr1+size/2:ptr1+size-1], 7) + + b.SetBytes(size/2 + 1) + b.ResetTimer() + for range b.N { + call(strrchr, ptr1, 5) + } +} + func Benchmark_strcmp(b *testing.B) { clear(memory) fill(memory[ptr1:ptr1+size-1], 7) @@ -199,17 +213,16 @@ func Test_memchr(t *testing.T) { for length := range 64 { for pos := range length + 2 { for alignment := range 24 { - 
clear(memory[:2*page]) - ptr := (page - 8) + alignment - fill(memory[ptr:ptr+max(pos, length)], 5) - memory[ptr+pos] = 7 - want := 0 if pos < length { want = ptr + pos } + clear(memory[:2*page]) + fill(memory[ptr:ptr+max(pos, length)], 5) + memory[ptr+pos] = 7 + got := call(memchr, uint64(ptr), 7, uint64(length)) if uint32(got) != uint32(want) { t.Errorf("memchr(%d, %d, %d) = %d, want %d", @@ -239,9 +252,9 @@ func Test_memchr(t *testing.T) { func Test_strlen(t *testing.T) { for length := range 64 { for alignment := range 24 { - clear(memory[:2*page]) - ptr := (page - 8) + alignment + + clear(memory[:2*page]) fill(memory[ptr:ptr+length], 5) got := call(strlen, uint64(ptr)) @@ -274,18 +287,18 @@ func Test_strchr(t *testing.T) { for length := range 64 { for pos := range length + 2 { for alignment := range 24 { - clear(memory[:2*page]) - ptr := (page - 8) + alignment - fill(memory[ptr:ptr+max(pos, length)], 5) - memory[ptr+pos] = 7 - memory[ptr+length] = 0 - want := 0 if pos < length { want = ptr + pos } + clear(memory[:2*page]) + fill(memory[ptr:ptr+max(pos, length)], 5) + memory[ptr+pos] = 7 + memory[ptr+pos+1] = 7 + memory[ptr+length] = 0 + got := call(strchr, uint64(ptr), 7) if uint32(got) != uint32(want) { t.Errorf("strchr(%d, %d) = %d, want %d", @@ -312,21 +325,66 @@ func Test_strchr(t *testing.T) { } } +func Test_strrchr(t *testing.T) { + for length := range 64 { + for pos := range length + 2 { + for alignment := range 24 { + ptr := (page - 8) + alignment + want := 0 + if pos < length { + want = ptr + pos + } else if length > 0 { + want = ptr + } + + clear(memory[:2*page]) + fill(memory[ptr:ptr+max(pos, length)], 5) + memory[ptr] = 7 + memory[ptr+pos] = 7 + memory[ptr+length] = 0 + + got := call(strrchr, uint64(ptr), 7) + if uint32(got) != uint32(want) { + t.Errorf("strrchr(%d, %d) = %d, want %d", + ptr, 7, uint32(got), uint32(want)) + } + } + } + + ptr := len(memory) - length + want := len(memory) - 2 + if length <= 1 { + continue + } + + clear(memory) + fill(memory[ptr:ptr+length], 5) + memory[ptr] = 7 + memory[len(memory)-2] = 7 + memory[len(memory)-1] = 0 + + got := call(strrchr, uint64(ptr), 7) + if uint32(got) != uint32(want) { + t.Errorf("strrchr(%d, %d) = %d, want %d", + ptr, 7, uint32(got), uint32(want)) + } + } +} + func Test_strspn(t *testing.T) { for length := range 64 { for pos := range length + 2 { for alignment := range 24 { - clear(memory[:2*page]) - ptr := (page - 8) + alignment + want := min(pos, length) + + clear(memory[:2*page]) fill(memory[ptr:ptr+max(pos, length)], 5) memory[ptr+pos] = 7 memory[ptr+length] = 0 memory[128] = 3 memory[129] = 5 - want := min(pos, length) - got := call(strspn, uint64(ptr), 129) if uint32(got) != uint32(want) { t.Errorf("strspn(%d, %d) = %d, want %d", @@ -341,18 +399,18 @@ func Test_strspn(t *testing.T) { } } - clear(memory) ptr := len(memory) - length - fill(memory[ptr:ptr+length], 5) - memory[len(memory)-1] = 7 - memory[128] = 3 - memory[129] = 5 - want := length - 1 if length == 0 { continue } + clear(memory) + fill(memory[ptr:ptr+length], 5) + memory[len(memory)-1] = 7 + memory[128] = 3 + memory[129] = 5 + got := call(strspn, uint64(ptr), 129) if uint32(got) != uint32(want) { t.Errorf("strspn(%d, %d) = %d, want %d", @@ -371,17 +429,16 @@ func Test_strcspn(t *testing.T) { for length := range 64 { for pos := range length + 2 { for alignment := range 24 { - clear(memory[:2*page]) - ptr := (page - 8) + alignment + want := min(pos, length) + + clear(memory[:2*page]) fill(memory[ptr:ptr+max(pos, length)], 5) memory[ptr+pos] = 7 
memory[ptr+length] = 0 memory[128] = 3 memory[129] = 7 - want := min(pos, length) - got := call(strcspn, uint64(ptr), 129) if uint32(got) != uint32(want) { t.Errorf("strcspn(%d, %d) = %d, want %d", @@ -396,18 +453,18 @@ func Test_strcspn(t *testing.T) { } } - clear(memory) ptr := len(memory) - length - fill(memory[ptr:ptr+length], 5) - memory[len(memory)-1] = 7 - memory[128] = 3 - memory[129] = 7 - want := length - 1 if length == 0 { continue } + clear(memory) + fill(memory[ptr:ptr+length], 5) + memory[len(memory)-1] = 7 + memory[128] = 3 + memory[129] = 7 + got := call(strcspn, uint64(ptr), 129) if uint32(got) != uint32(want) { t.Errorf("strcspn(%d, %d) = %d, want %d", diff --git a/sqlite3/libc/string.h b/sqlite3/libc/string.h index b56703f..7348c34 100644 --- a/sqlite3/libc/string.h +++ b/sqlite3/libc/string.h @@ -42,15 +42,17 @@ void *memmove(void *dest, const void *src, size_t n) { __attribute__((weak)) int memcmp(const void *v1, const void *v2, size_t n) { - // memcmp can read up to n bytes from each object. - // Use unaligned loads to handle the case where - // the objects have mismatching alignments. + // memcmp is allowed to read up to n bytes from each object. + // Find the first different character in the objects. + // Unaligned loads handle the case where the objects + // have mismatching alignments. const v128_t *w1 = (v128_t *)v1; const v128_t *w2 = (v128_t *)v2; for (; n >= sizeof(v128_t); n -= sizeof(v128_t)) { const v128_t cmp = wasm_i8x16_eq(wasm_v128_load(w1), wasm_v128_load(w2)); // Bitmask is slow on AArch64, all_true is much faster. if (!wasm_i8x16_all_true(cmp)) { + // Find the offset of the first zero bit (little-endian). size_t ctz = __builtin_ctz(~wasm_i8x16_bitmask(cmp)); const unsigned char *u1 = (unsigned char *)w1 + ctz; const unsigned char *u2 = (unsigned char *)w2 + ctz; @@ -60,7 +62,7 @@ int memcmp(const void *v1, const void *v2, size_t n) { w2++; } - // Continue byte-by-byte. + // Baseline algorithm. const unsigned char *u1 = (unsigned char *)w1; const unsigned char *u2 = (unsigned char *)w2; while (n--) { @@ -74,7 +76,7 @@ int memcmp(const void *v1, const void *v2, size_t n) { __attribute__((weak)) void *memchr(const void *v, int c, size_t n) { // When n is zero, a function that locates a character finds no occurrence. - // Otherwise, decrement n to ensure __builtin_sub_overflow overflows + // Otherwise, decrement n to ensure sub_overflow overflows // when n would go equal-to-or-below zero. if (n-- == 0) { return NULL; @@ -82,16 +84,16 @@ void *memchr(const void *v, int c, size_t n) { // memchr must behave as if it reads characters sequentially // and stops as soon as a match is found. - // Aligning ensures loads can't fail. + // Aligning ensures loads beyond the first match don't fail. uintptr_t align = (uintptr_t)v % sizeof(v128_t); const v128_t *w = (v128_t *)((char *)v - align); const v128_t wc = wasm_i8x16_splat(c); - while (true) { + for (;;) { const v128_t cmp = wasm_i8x16_eq(*w, wc); // Bitmask is slow on AArch64, any_true is much faster. if (wasm_v128_any_true(cmp)) { - // Clear the bits corresponding to alignment + // Clear the bits corresponding to alignment (little-endian) // so we can count trailing zeros. int mask = wasm_i8x16_bitmask(cmp) >> align << align; // At least one bit will be set, unless we cleared them. @@ -115,18 +117,41 @@ void *memchr(const void *v, int c, size_t n) { } } +__attribute__((weak)) +void *memrchr(const void *v, int c, size_t n) { + // memrchr is allowed to read up to n bytes from the object. 
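+ // Each 16-byte load lands fully inside those n bytes, so no
+ // memchr-style alignment fixup is needed for the unaligned loads.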
+ // Search backward for the last matching character. + const v128_t *w = (v128_t *)((char *)v + n); + const v128_t wc = wasm_i8x16_splat(c); + for (; n >= sizeof(v128_t); n -= sizeof(v128_t)) { + const v128_t cmp = wasm_i8x16_eq(wasm_v128_load(--w), wc); + // Bitmask is slow on AArch64, any_true is much faster. + if (wasm_v128_any_true(cmp)) { + size_t clz = __builtin_clz(wasm_i8x16_bitmask(cmp)) - 15; + return (char *)(w + 1) - clz; + } + } + + // Baseline algorithm. + const char *a = (char *)w; + while (n--) { + if (*(--a) == (char)c) return (char *)a; + } + return NULL; +} + __attribute__((weak)) size_t strlen(const char *s) { // strlen must stop as soon as it finds the terminator. - // Aligning ensures loads can't fail. + // Aligning ensures loads beyond the terminator don't fail. uintptr_t align = (uintptr_t)s % sizeof(v128_t); const v128_t *w = (v128_t *)(s - align); - while (true) { + for (;;) { // Bitmask is slow on AArch64, all_true is much faster. if (!wasm_i8x16_all_true(*w)) { const v128_t cmp = wasm_i8x16_eq(*w, (v128_t){}); - // Clear the bits corresponding to alignment + // Clear the bits corresponding to alignment (little-endian) // so we can count trailing zeros. int mask = wasm_i8x16_bitmask(cmp) >> align << align; // At least one bit will be set, unless we cleared them. @@ -148,17 +173,19 @@ static int __strcmp(const char *s1, const char *s2) { const v128_t *const limit = (v128_t *)(__builtin_wasm_memory_size(0) * PAGESIZE) - 1; - // Use unaligned loads to handle the case where - // the strings have mismatching alignments. + // Unaligned loads handle the case where the strings + // have mismatching alignments. const v128_t *w1 = (v128_t *)s1; const v128_t *w2 = (v128_t *)s2; while (w1 <= limit && w2 <= limit) { // Find any single bit difference. if (wasm_v128_any_true(wasm_v128_load(w1) ^ wasm_v128_load(w2))) { + // The strings may still be equal, + // if the terminator is found before that difference. break; } - // All bytes are equal. - // If any byte is zero (on both strings) the strings are equal. + // All characters are equal. + // If any is a terminator the strings are equal. if (!wasm_i8x16_all_true(wasm_v128_load(w1))) { return 0; } @@ -166,10 +193,10 @@ static int __strcmp(const char *s1, const char *s2) { w2++; } - // Continue byte-by-byte. + // Baseline algorithm. const unsigned char *u1 = (unsigned char *)w1; const unsigned char *u2 = (unsigned char *)w2; - while (true) { + for (;;) { if (*u1 != *u2) return *u1 - *u2; if (*u1 == 0) break; u1++; @@ -181,7 +208,7 @@ static int __strcmp(const char *s1, const char *s2) { static int __strcmp_s(const char *s1, const char *s2) { const unsigned char *u1 = (unsigned char *)s1; const unsigned char *u2 = (unsigned char *)s2; - while (true) { + for (;;) { if (*u1 != *u2) return *u1 - *u2; if (*u1 == 0) break; u1++; @@ -207,17 +234,19 @@ int strncmp(const char *s1, const char *s2, size_t n) { const v128_t *const limit = (v128_t *)(__builtin_wasm_memory_size(0) * PAGESIZE) - 1; - // Use unaligned loads to handle the case where - // the strings have mismatching alignments. + // Unaligned loads handle the case where the strings + // have mismatching alignments. const v128_t *w1 = (v128_t *)s1; const v128_t *w2 = (v128_t *)s2; for (; w1 <= limit && w2 <= limit && n >= sizeof(v128_t); n -= sizeof(v128_t)) { // Find any single bit difference. if (wasm_v128_any_true(wasm_v128_load(w1) ^ wasm_v128_load(w2))) { + // The strings may still be equal, + // if the terminator is found before that difference. 
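+ // For example, "ab\0x" and "ab\0y" differ only
+ // after the terminator, yet compare equal.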
break; } - // All bytes are equal. - // If any byte is zero (on both strings) the strings are equal. + // All characters are equal. + // If any is a terminator the strings are equal. if (!wasm_i8x16_all_true(wasm_v128_load(w1))) { return 0; } @@ -225,7 +254,7 @@ int strncmp(const char *s1, const char *s2, size_t n) { w2++; } - // Continue byte-by-byte. + // Baseline algorithm. const unsigned char *u1 = (unsigned char *)w1; const unsigned char *u2 = (unsigned char *)w2; while (n--) { @@ -239,16 +268,16 @@ int strncmp(const char *s1, const char *s2, size_t n) { static char *__strchrnul(const char *s, int c) { // strchrnul must stop as soon as a match is found. - // Aligning ensures loads can't fail. + // Aligning ensures loads beyond the first match don't fail. uintptr_t align = (uintptr_t)s % sizeof(v128_t); const v128_t *w = (v128_t *)(s - align); const v128_t wc = wasm_i8x16_splat(c); - while (true) { + for (;;) { const v128_t cmp = wasm_i8x16_eq(*w, (v128_t){}) | wasm_i8x16_eq(*w, wc); // Bitmask is slow on AArch64, any_true is much faster. if (wasm_v128_any_true(cmp)) { - // Clear the bits corresponding to alignment + // Clear the bits corresponding to alignment (little-endian) // so we can count trailing zeros. int mask = wasm_i8x16_bitmask(cmp) >> align << align; // At least one bit will be set, unless we cleared them. @@ -279,7 +308,16 @@ char *strchr(const char *s, int c) { return (char *)s + strlen(s); } char *r = __strchrnul(s, c); - return *(char *)r == (char)c ? r : NULL; + return *r == (char)c ? r : NULL; +} + +__attribute__((weak, always_inline)) +char *strrchr(const char *s, int c) { + // For finding the terminator, strlen is faster. + if (__builtin_constant_p(c) && (char)c == 0) { + return (char *)s + strlen(s); + } + return (char *)memrchr(s, c, strlen(s) + 1); } __attribute__((weak)) @@ -310,7 +348,7 @@ size_t strspn(const char *s, const char *c) { w++; } - // Continue byte-by-byte. + // Baseline algorithm. s = (char *)w; while (*s == *c) s++; return s - a;
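A note on the index math shared by the new memrchr and the $strrchr wat: i8x16.bitmask sets bit i when byte i of the 16-byte window matched (little-endian), so the last match in a window is the highest set bit, 31 - clz(mask) for a 16-bit mask widened to i32. memrchr's (char *)(w + 1) - clz with clz = __builtin_clz(mask) - 15 is the same quantity, since (w + 1) - (clz32 - 15) = w + 31 - clz32 in bytes, and the wat's (i32.sub ... (i32.clz (i8x16.bitmask ...))) chain reduces to it as well. A standalone sketch of that math, runnable on any host; the helper name is mine, not part of the patch:

    #include <assert.h>
    #include <stdint.h>

    // Hypothetical helper mirroring the memrchr math: bit i of mask is set
    // when byte i of a 16-byte window matched; return the LAST match index.
    static int last_match_index(uint16_t mask) {
      assert(mask != 0);  // memrchr only counts after any_true succeeds
      return 31 - __builtin_clz(mask);  // highest set bit of the i32 mask
    }

    int main(void) {
      assert(last_match_index(0x0001) == 0);   // only byte 0 matched
      assert(last_match_index(0x8000) == 15);  // only byte 15 matched
      assert(last_match_index(0x0105) == 8);   // bytes 0, 2, 8: last wins
      return 0;
    }

This is also why the scan direction differs from memchr: going backward, the first window with any match necessarily contains the last occurrence, so one bitmask/clz resolves it.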