Add strrchr.

This commit is contained in:
Nuno Cruces
2025-05-02 14:35:14 +01:00
parent d748d98e39
commit db7aacff9f
5 changed files with 278 additions and 61 deletions

View File

@@ -37,6 +37,7 @@ EOF
-Wl,--export=strcspn \
-Wl,--export=strlen \
-Wl,--export=strncmp \
-Wl,--export=strrchr \
-Wl,--export=strspn \
-Wl,--export=qsort

Binary file not shown.

View File

@@ -16,6 +16,7 @@
(export "strncmp" (func $strncmp))
(export "strchrnul" (func $strchrnul))
(export "strchr" (func $strchr))
(export "strrchr" (func $strrchr))
(export "strspn" (func $strspn))
(export "strcspn" (func $strcspn))
(export "qsort" (func $qsort))
@@ -910,6 +911,126 @@
)
)
)
;; strrchr(s=$0, c=$1): return a pointer to the LAST occurrence of
;; byte c in string s, or 0 (NULL) if there is none.  The search
;; covers strlen(s)+1 bytes, so a NUL c matches the terminator.
(func $strrchr (param $0 i32) (param $1 i32) (result i32)
  (local $2 i32)   ;; remaining byte count, starts at strlen(s)+1
  (local $3 v128)  ;; c splatted into every lane
  (local $4 v128)  ;; per-lane equality mask of the last vector load
  (block $block1 (result i32)
    (block $block
      ;; $2 = strlen(s) + 1; fewer than 16 bytes: skip the vector
      ;; path and fall through to the byte-wise loop.
      (br_if $block
        (i32.lt_u
          (local.tee $2
            (i32.add
              (call $strlen
                (local.get $0)
              )
              (i32.const 1)
            )
          )
          (i32.const 16)
        )
      )
      (local.set $3
        (i8x16.splat
          (local.get $1)
        )
      )
      ;; Scan backward 16 bytes at a time with unaligned loads of
      ;; s+$2-16 until a chunk containing at least one match is found.
      (loop $label
        (if
          (i32.eqz
            (v128.any_true
              (local.tee $4
                (i8x16.eq
                  (v128.load align=1
                    (i32.sub
                      (i32.add
                        (local.get $0)
                        (local.get $2)
                      )
                      (i32.const 16)
                    )
                  )
                  (local.get $3)
                )
              )
            )
          )
          (then
            ;; No match in this chunk: step back 16 bytes and keep
            ;; looping while at least 16 bytes remain; otherwise
            ;; finish with the byte-wise loop.
            (br_if $label
              (i32.gt_u
                (local.tee $2
                  (i32.sub
                    (local.get $2)
                    (i32.const 16)
                  )
                )
                (i32.const 15)
              )
            )
            (br $block)
          )
        )
      )
      ;; A lane matched.  The 16-bit lane mask sits in the low bits
      ;; of an i32, so clz(mask) is 16..31.  The chunk starts at
      ;; s+$2-16, and the last (highest, little-endian) matching byte
      ;; is chunk+31-clz = s+$2+15-clz.
      (br $block1
        (i32.add
          (i32.add
            (i32.sub
              (local.get $0)
              (i32.clz
                (i8x16.bitmask
                  (local.get $4)
                )
              )
            )
            (local.get $2)
          )
          (i32.const 15)
        )
      )
    )
    ;; Byte-wise tail: walk backward from s+$2, comparing
    ;; sign-extended bytes against sign-extended c.
    (local.set $0
      (i32.add
        (local.get $0)
        (local.get $2)
      )
    )
    (local.set $1
      (i32.extend8_s
        (local.get $1)
      )
    )
    (loop $label1
      ;; Count exhausted without a match: return 0 (NULL).
      (drop
        (br_if $block1
          (i32.const 0)
          (i32.eqz
            (local.get $2)
          )
        )
      )
      (local.set $2
        (i32.sub
          (local.get $2)
          (i32.const 1)
        )
      )
      ;; Pre-decrement $0 and keep looping while the byte differs.
      (br_if $label1
        (i32.ne
          (local.get $1)
          (i32.load8_s
            (local.tee $0
              (i32.sub
                (local.get $0)
                (i32.const 1)
              )
            )
          )
        )
      )
    )
    ;; Fell out of the loop: $0 points at the match.
    (local.get $0)
  )
)
(func $strspn (param $0 i32) (param $1 i32) (result i32)
(local $2 i32)
(local $3 i32)

View File

@@ -31,6 +31,7 @@ var (
strchr api.Function
strcmp api.Function
strspn api.Function
strrchr api.Function
strncmp api.Function
strcspn api.Function
stack [8]uint64
@@ -63,6 +64,7 @@ func TestMain(m *testing.M) {
strchr = mod.ExportedFunction("strchr")
strcmp = mod.ExportedFunction("strcmp")
strspn = mod.ExportedFunction("strspn")
strrchr = mod.ExportedFunction("strrchr")
strncmp = mod.ExportedFunction("strncmp")
strcspn = mod.ExportedFunction("strcspn")
memory, _ = mod.Memory().Read(0, mod.Memory().Size())
@@ -139,6 +141,18 @@ func Benchmark_strchr(b *testing.B) {
}
}
// Benchmark_strrchr fills the first half of the buffer with the needle
// byte (5) and the rest with a different byte (7).  The last 5 sits
// just before the midpoint, so the backward scan covers size/2+1
// bytes, which is what SetBytes reports per iteration.
func Benchmark_strrchr(b *testing.B) {
	clear(memory)
	fill(memory[ptr1:ptr1+size/2], 5)
	fill(memory[ptr1+size/2:ptr1+size-1], 7)
	b.SetBytes(size/2 + 1)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		call(strrchr, ptr1, 5)
	}
}
func Benchmark_strcmp(b *testing.B) {
clear(memory)
fill(memory[ptr1:ptr1+size-1], 7)
@@ -199,17 +213,16 @@ func Test_memchr(t *testing.T) {
for length := range 64 {
for pos := range length + 2 {
for alignment := range 24 {
clear(memory[:2*page])
ptr := (page - 8) + alignment
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr+pos] = 7
want := 0
if pos < length {
want = ptr + pos
}
clear(memory[:2*page])
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr+pos] = 7
got := call(memchr, uint64(ptr), 7, uint64(length))
if uint32(got) != uint32(want) {
t.Errorf("memchr(%d, %d, %d) = %d, want %d",
@@ -239,9 +252,9 @@ func Test_memchr(t *testing.T) {
func Test_strlen(t *testing.T) {
for length := range 64 {
for alignment := range 24 {
clear(memory[:2*page])
ptr := (page - 8) + alignment
clear(memory[:2*page])
fill(memory[ptr:ptr+length], 5)
got := call(strlen, uint64(ptr))
@@ -274,18 +287,18 @@ func Test_strchr(t *testing.T) {
for length := range 64 {
for pos := range length + 2 {
for alignment := range 24 {
clear(memory[:2*page])
ptr := (page - 8) + alignment
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr+pos] = 7
memory[ptr+length] = 0
want := 0
if pos < length {
want = ptr + pos
}
clear(memory[:2*page])
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr+pos] = 7
memory[ptr+pos+1] = 7
memory[ptr+length] = 0
got := call(strchr, uint64(ptr), 7)
if uint32(got) != uint32(want) {
t.Errorf("strchr(%d, %d) = %d, want %d",
@@ -312,21 +325,66 @@ func Test_strchr(t *testing.T) {
}
}
func Test_strrchr(t *testing.T) {
for length := range 64 {
for pos := range length + 2 {
for alignment := range 24 {
ptr := (page - 8) + alignment
want := 0
if pos < length {
want = ptr + pos
} else if length > 0 {
want = ptr
}
clear(memory[:2*page])
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr] = 7
memory[ptr+pos] = 7
memory[ptr+length] = 0
got := call(strrchr, uint64(ptr), 7)
if uint32(got) != uint32(want) {
t.Errorf("strrchr(%d, %d) = %d, want %d",
ptr, 7, uint32(got), uint32(want))
}
}
}
ptr := len(memory) - length
want := len(memory) - 2
if length <= 1 {
continue
}
clear(memory)
fill(memory[ptr:ptr+length], 5)
memory[ptr] = 7
memory[len(memory)-2] = 7
memory[len(memory)-1] = 0
got := call(strrchr, uint64(ptr), 7)
if uint32(got) != uint32(want) {
t.Errorf("strrchr(%d, %d) = %d, want %d",
ptr, 7, uint32(got), uint32(want))
}
}
}
func Test_strspn(t *testing.T) {
for length := range 64 {
for pos := range length + 2 {
for alignment := range 24 {
clear(memory[:2*page])
ptr := (page - 8) + alignment
want := min(pos, length)
clear(memory[:2*page])
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr+pos] = 7
memory[ptr+length] = 0
memory[128] = 3
memory[129] = 5
want := min(pos, length)
got := call(strspn, uint64(ptr), 129)
if uint32(got) != uint32(want) {
t.Errorf("strspn(%d, %d) = %d, want %d",
@@ -341,18 +399,18 @@ func Test_strspn(t *testing.T) {
}
}
clear(memory)
ptr := len(memory) - length
fill(memory[ptr:ptr+length], 5)
memory[len(memory)-1] = 7
memory[128] = 3
memory[129] = 5
want := length - 1
if length == 0 {
continue
}
clear(memory)
fill(memory[ptr:ptr+length], 5)
memory[len(memory)-1] = 7
memory[128] = 3
memory[129] = 5
got := call(strspn, uint64(ptr), 129)
if uint32(got) != uint32(want) {
t.Errorf("strspn(%d, %d) = %d, want %d",
@@ -371,17 +429,16 @@ func Test_strcspn(t *testing.T) {
for length := range 64 {
for pos := range length + 2 {
for alignment := range 24 {
clear(memory[:2*page])
ptr := (page - 8) + alignment
want := min(pos, length)
clear(memory[:2*page])
fill(memory[ptr:ptr+max(pos, length)], 5)
memory[ptr+pos] = 7
memory[ptr+length] = 0
memory[128] = 3
memory[129] = 7
want := min(pos, length)
got := call(strcspn, uint64(ptr), 129)
if uint32(got) != uint32(want) {
t.Errorf("strcspn(%d, %d) = %d, want %d",
@@ -396,18 +453,18 @@ func Test_strcspn(t *testing.T) {
}
}
clear(memory)
ptr := len(memory) - length
fill(memory[ptr:ptr+length], 5)
memory[len(memory)-1] = 7
memory[128] = 3
memory[129] = 7
want := length - 1
if length == 0 {
continue
}
clear(memory)
fill(memory[ptr:ptr+length], 5)
memory[len(memory)-1] = 7
memory[128] = 3
memory[129] = 7
got := call(strcspn, uint64(ptr), 129)
if uint32(got) != uint32(want) {
t.Errorf("strcspn(%d, %d) = %d, want %d",

View File

@@ -42,15 +42,17 @@ void *memmove(void *dest, const void *src, size_t n) {
__attribute__((weak))
int memcmp(const void *v1, const void *v2, size_t n) {
// memcmp can read up to n bytes from each object.
// Use unaligned loads to handle the case where
// the objects have mismatching alignments.
// memcmp is allowed to read up to n bytes from each object.
// Find the first different character in the objects.
// Unaligned loads handle the case where the objects
// have mismatching alignments.
const v128_t *w1 = (v128_t *)v1;
const v128_t *w2 = (v128_t *)v2;
for (; n >= sizeof(v128_t); n -= sizeof(v128_t)) {
const v128_t cmp = wasm_i8x16_eq(wasm_v128_load(w1), wasm_v128_load(w2));
// Bitmask is slow on AArch64, all_true is much faster.
if (!wasm_i8x16_all_true(cmp)) {
// Find the offset of the first zero bit (little-endian).
size_t ctz = __builtin_ctz(~wasm_i8x16_bitmask(cmp));
const unsigned char *u1 = (unsigned char *)w1 + ctz;
const unsigned char *u2 = (unsigned char *)w2 + ctz;
@@ -60,7 +62,7 @@ int memcmp(const void *v1, const void *v2, size_t n) {
w2++;
}
// Continue byte-by-byte.
// Baseline algorithm.
const unsigned char *u1 = (unsigned char *)w1;
const unsigned char *u2 = (unsigned char *)w2;
while (n--) {
@@ -74,7 +76,7 @@ int memcmp(const void *v1, const void *v2, size_t n) {
__attribute__((weak))
void *memchr(const void *v, int c, size_t n) {
// When n is zero, a function that locates a character finds no occurrence.
// Otherwise, decrement n to ensure __builtin_sub_overflow overflows
// Otherwise, decrement n to ensure sub_overflow overflows
// when n would go equal-to-or-below zero.
if (n-- == 0) {
return NULL;
@@ -82,16 +84,16 @@ void *memchr(const void *v, int c, size_t n) {
// memchr must behave as if it reads characters sequentially
// and stops as soon as a match is found.
// Aligning ensures loads can't fail.
// Aligning ensures loads beyond the first match don't fail.
uintptr_t align = (uintptr_t)v % sizeof(v128_t);
const v128_t *w = (v128_t *)((char *)v - align);
const v128_t wc = wasm_i8x16_splat(c);
while (true) {
for (;;) {
const v128_t cmp = wasm_i8x16_eq(*w, wc);
// Bitmask is slow on AArch64, any_true is much faster.
if (wasm_v128_any_true(cmp)) {
// Clear the bits corresponding to alignment
// Clear the bits corresponding to alignment (little-endian)
// so we can count trailing zeros.
int mask = wasm_i8x16_bitmask(cmp) >> align << align;
// At least one bit will be set, unless we cleared them.
@@ -115,18 +117,41 @@ void *memchr(const void *v, int c, size_t n) {
}
}
__attribute__((weak))
void *memrchr(const void *v, int c, size_t n) {
	// memrchr is allowed to read up to n bytes from the object.
	// Search backward for the last matching character.
	const v128_t *w = (v128_t *)((char *)v + n);
	const v128_t wc = wasm_i8x16_splat(c);
	for (; n >= sizeof(v128_t); n -= sizeof(v128_t)) {
		// Unaligned load of the previous 16-byte chunk.
		const v128_t cmp = wasm_i8x16_eq(wasm_v128_load(--w), wc);
		// Bitmask is slow on AArch64, any_true is much faster.
		if (wasm_v128_any_true(cmp)) {
			// The 16-bit lane mask sits in the low bits of a 32-bit
			// int, so clz(mask) is 16..31.  Subtracting 15 yields the
			// distance (1..16) from the chunk's end back to the last
			// (highest, little-endian) matching byte.
			size_t clz = __builtin_clz(wasm_i8x16_bitmask(cmp)) - 15;
			return (char *)(w + 1) - clz;
		}
	}
	// Baseline algorithm.
	const char *a = (char *)w;
	while (n--) {
		if (*(--a) == (char)c) return (char *)a;
	}
	return NULL;
}
__attribute__((weak))
size_t strlen(const char *s) {
// strlen must stop as soon as it finds the terminator.
// Aligning ensures loads can't fail.
// Aligning ensures loads beyond the terminator don't fail.
uintptr_t align = (uintptr_t)s % sizeof(v128_t);
const v128_t *w = (v128_t *)(s - align);
while (true) {
for (;;) {
// Bitmask is slow on AArch64, all_true is much faster.
if (!wasm_i8x16_all_true(*w)) {
const v128_t cmp = wasm_i8x16_eq(*w, (v128_t){});
// Clear the bits corresponding to alignment
// Clear the bits corresponding to alignment (little-endian)
// so we can count trailing zeros.
int mask = wasm_i8x16_bitmask(cmp) >> align << align;
// At least one bit will be set, unless we cleared them.
@@ -148,17 +173,19 @@ static int __strcmp(const char *s1, const char *s2) {
const v128_t *const limit =
(v128_t *)(__builtin_wasm_memory_size(0) * PAGESIZE) - 1;
// Use unaligned loads to handle the case where
// the strings have mismatching alignments.
// Unaligned loads handle the case where the strings
// have mismatching alignments.
const v128_t *w1 = (v128_t *)s1;
const v128_t *w2 = (v128_t *)s2;
while (w1 <= limit && w2 <= limit) {
// Find any single bit difference.
if (wasm_v128_any_true(wasm_v128_load(w1) ^ wasm_v128_load(w2))) {
// The strings may still be equal,
// if the terminator is found before that difference.
break;
}
// All bytes are equal.
// If any byte is zero (on both strings) the strings are equal.
// All characters are equal.
// If any is a terminator the strings are equal.
if (!wasm_i8x16_all_true(wasm_v128_load(w1))) {
return 0;
}
@@ -166,10 +193,10 @@ static int __strcmp(const char *s1, const char *s2) {
w2++;
}
// Continue byte-by-byte.
// Baseline algorithm.
const unsigned char *u1 = (unsigned char *)w1;
const unsigned char *u2 = (unsigned char *)w2;
while (true) {
for (;;) {
if (*u1 != *u2) return *u1 - *u2;
if (*u1 == 0) break;
u1++;
@@ -181,7 +208,7 @@ static int __strcmp(const char *s1, const char *s2) {
static int __strcmp_s(const char *s1, const char *s2) {
const unsigned char *u1 = (unsigned char *)s1;
const unsigned char *u2 = (unsigned char *)s2;
while (true) {
for (;;) {
if (*u1 != *u2) return *u1 - *u2;
if (*u1 == 0) break;
u1++;
@@ -207,17 +234,19 @@ int strncmp(const char *s1, const char *s2, size_t n) {
const v128_t *const limit =
(v128_t *)(__builtin_wasm_memory_size(0) * PAGESIZE) - 1;
// Use unaligned loads to handle the case where
// the strings have mismatching alignments.
// Unaligned loads handle the case where the strings
// have mismatching alignments.
const v128_t *w1 = (v128_t *)s1;
const v128_t *w2 = (v128_t *)s2;
for (; w1 <= limit && w2 <= limit && n >= sizeof(v128_t); n -= sizeof(v128_t)) {
// Find any single bit difference.
if (wasm_v128_any_true(wasm_v128_load(w1) ^ wasm_v128_load(w2))) {
// The strings may still be equal,
// if the terminator is found before that difference.
break;
}
// All bytes are equal.
// If any byte is zero (on both strings) the strings are equal.
// All characters are equal.
// If any is a terminator the strings are equal.
if (!wasm_i8x16_all_true(wasm_v128_load(w1))) {
return 0;
}
@@ -225,7 +254,7 @@ int strncmp(const char *s1, const char *s2, size_t n) {
w2++;
}
// Continue byte-by-byte.
// Baseline algorithm.
const unsigned char *u1 = (unsigned char *)w1;
const unsigned char *u2 = (unsigned char *)w2;
while (n--) {
@@ -239,16 +268,16 @@ int strncmp(const char *s1, const char *s2, size_t n) {
static char *__strchrnul(const char *s, int c) {
// strchrnul must stop as soon as a match is found.
// Aligning ensures loads can't fail.
// Aligning ensures loads beyond the first match don't fail.
uintptr_t align = (uintptr_t)s % sizeof(v128_t);
const v128_t *w = (v128_t *)(s - align);
const v128_t wc = wasm_i8x16_splat(c);
while (true) {
for (;;) {
const v128_t cmp = wasm_i8x16_eq(*w, (v128_t){}) | wasm_i8x16_eq(*w, wc);
// Bitmask is slow on AArch64, any_true is much faster.
if (wasm_v128_any_true(cmp)) {
// Clear the bits corresponding to alignment
// Clear the bits corresponding to alignment (little-endian)
// so we can count trailing zeros.
int mask = wasm_i8x16_bitmask(cmp) >> align << align;
// At least one bit will be set, unless we cleared them.
@@ -279,7 +308,16 @@ char *strchr(const char *s, int c) {
return (char *)s + strlen(s);
}
char *r = __strchrnul(s, c);
return *(char *)r == (char)c ? r : NULL;
return *r == (char)c ? r : NULL;
}
__attribute__((weak, always_inline))
char *strrchr(const char *s, int c) {
	// For finding the terminator, strlen is faster.
	if (__builtin_constant_p(c) && (char)c == 0) {
		return (char *)s + strlen(s);
	}
	// Search the whole string backward, terminator included,
	// so a runtime NUL c still matches.
	size_t n = strlen(s) + 1;
	return (char *)memrchr(s, c, n);
}
__attribute__((weak))
@@ -310,7 +348,7 @@ size_t strspn(const char *s, const char *c) {
w++;
}
// Continue byte-by-byte.
// Baseline algorithm.
s = (char *)w;
while (*s == *c) s++;
return s - a;