More SIMD.

This commit is contained in:
Nuno Cruces
2025-04-05 02:03:31 +01:00
parent 4c19387535
commit 39f3fa64eb
5 changed files with 363 additions and 170 deletions

View File

@@ -20,6 +20,7 @@ trap 'rm -f libc.tmp' EXIT
-Wl,--export=memset \
-Wl,--export=memcpy \
-Wl,--export=memcmp \
-Wl,--export=strcmp \
-Wl,--export=strncmp
"$BINARYEN/wasm-ctor-eval" -g -c _initialize libc.wasm -o libc.tmp

Binary file not shown.

View File

@@ -1,11 +1,13 @@
(module $libc.wasm
(type $0 (func (param i32 i32 i32) (result i32)))
(type $1 (func (param i32 i32) (result i32)))
(memory $0 256)
(data $0 (i32.const 1024) "\01")
(export "memory" (memory $0))
(export "memset" (func $memset))
(export "memcpy" (func $memcpy))
(export "memcmp" (func $memcmp))
(export "strcmp" (func $strcmp))
(export "strncmp" (func $strncmp))
(func $memset (param $0 i32) (param $1 i32) (param $2 i32) (result i32)
(memory.fill
@@ -126,108 +128,130 @@
)
(i32.const 0)
)
(func $strncmp (param $0 i32) (param $1 i32) (param $2 i32) (result i32)
(func $strcmp (param $0 i32) (param $1 i32) (result i32)
(local $2 i32)
(local $3 i32)
(local $4 i32)
(local $5 i32)
(local $6 v128)
(block $block2
(block $block1
(block $block
(if
(i32.ge_u
(local.get $2)
(i32.const 16)
)
(then
(loop $label
(br_if $block
(v128.any_true
(v128.xor
(v128.load align=1
(local.get $1)
)
(local.tee $6
(v128.load align=1
(local.get $0)
)
)
)
)
)
(if
(i32.eqz
(i8x16.all_true
(local.get $6)
)
)
(then
(return
(i32.const 0)
)
)
)
(local.set $1
(i32.add
(local.get $1)
(i32.const 16)
)
)
(local.set $0
(i32.add
(local.get $0)
(i32.const 16)
)
)
(br_if $label
(i32.gt_u
(local.tee $2
(i32.sub
(local.get $2)
(i32.const 16)
)
)
(i32.const 15)
)
)
)
)
)
(br_if $block
(local.get $2)
)
(local.set $4
(i32.load8_u
(local $4 v128)
(local $5 v128)
(local.set $3
(block $block (result i32)
(if
(i32.and
(i32.or
(local.get $0)
(local.get $1)
)
(i32.const 15)
)
(local.set $3
(i32.load8_u
(local.get $0)
)
)
(br $block1)
)
(br_if $block1
(i32.ne
(local.tee $3
(then
(local.set $2
(i32.load8_u
(local.get $0)
)
)
(local.tee $4
(br $block
(i32.load8_u
(local.get $1)
)
)
)
)
(local.set $5
(i32.add
(local.get $1)
(i32.const 1)
(if
(v128.any_true
(v128.xor
(local.tee $5
(v128.load
(local.get $1)
)
)
(local.tee $4
(v128.load
(local.get $0)
)
)
)
)
(then
(local.set $2
(i8x16.extract_lane_u 0
(local.get $4)
)
)
(br $block
(i8x16.extract_lane_u 0
(local.get $5)
)
)
)
)
(loop $label
(if
(i32.eqz
(i8x16.all_true
(local.get $4)
)
)
(then
(return
(i32.const 0)
)
)
)
(local.set $4
(v128.load offset=16
(local.get $0)
)
)
(local.set $5
(v128.load offset=16
(local.get $1)
)
)
(local.set $1
(i32.add
(local.get $1)
(i32.const 16)
)
)
(local.set $0
(i32.add
(local.get $0)
(i32.const 16)
)
)
(br_if $label
(i32.eqz
(v128.any_true
(v128.xor
(local.get $5)
(local.get $4)
)
)
)
)
)
(local.set $2
(i8x16.extract_lane_u 0
(local.get $4)
)
)
(i8x16.extract_lane_u 0
(local.get $5)
)
)
)
(if
(i32.eq
(i32.and
(local.get $2)
(i32.const 255)
)
(i32.and
(local.get $3)
(i32.const 255)
)
)
(then
(local.set $0
(i32.add
(local.get $0)
@@ -235,44 +259,184 @@
)
)
(local.set $1
(i32.sub
(local.get $2)
(i32.add
(local.get $1)
(i32.const 1)
)
)
(local.set $2
(local.get $3)
)
(loop $label1
(local.set $2
(i32.const 0)
)
(br_if $block2
(if
(i32.eqz
(local.get $1)
(i32.and
(local.get $2)
(i32.const 255)
)
)
)
(br_if $block2
(i32.eqz
(local.get $3)
)
)
(local.set $1
(i32.sub
(local.get $1)
(i32.const 1)
)
)
(local.set $4
(i32.load8_u
(local.get $5)
(then
(return
(i32.const 0)
)
)
)
(local.set $3
(i32.load8_u
(local.get $1)
)
)
(local.set $2
(i32.load8_u
(local.get $0)
)
)
(local.set $5
(local.set $0
(i32.add
(local.get $5)
(local.get $0)
(i32.const 1)
)
)
(local.set $1
(i32.add
(local.get $1)
(i32.const 1)
)
)
(br_if $label1
(i32.eq
(local.get $2)
(local.get $3)
)
)
)
)
)
(i32.sub
(i32.and
(local.get $2)
(i32.const 255)
)
(i32.and
(local.get $3)
(i32.const 255)
)
)
)
(func $strncmp (param $0 i32) (param $1 i32) (param $2 i32) (result i32)
(local $3 i32)
(local $4 i32)
(local $5 v128)
(block $block
(if
(i32.ge_u
(local.get $2)
(i32.const 16)
)
(then
(loop $label
(br_if $block
(v128.any_true
(v128.xor
(v128.load align=1
(local.get $1)
)
(local.tee $5
(v128.load align=1
(local.get $0)
)
)
)
)
)
(if
(i32.eqz
(i8x16.all_true
(local.get $5)
)
)
(then
(return
(i32.const 0)
)
)
)
(local.set $1
(i32.add
(local.get $1)
(i32.const 16)
)
)
(local.set $0
(i32.add
(local.get $0)
(i32.const 16)
)
)
(br_if $label
(i32.gt_u
(local.tee $2
(i32.sub
(local.get $2)
(i32.const 16)
)
)
(i32.const 15)
)
)
)
)
)
(br_if $block
(local.get $2)
)
(return
(i32.const 0)
)
)
(local.set $2
(i32.sub
(local.get $2)
(i32.const 1)
)
)
(loop $label1
(if
(i32.ne
(local.tee $3
(i32.load8_u
(local.get $0)
)
)
(local.tee $4
(i32.load8_u
(local.get $1)
)
)
)
(then
(return
(i32.sub
(local.get $3)
(local.get $4)
)
)
)
)
(if
(local.get $3)
(then
(local.set $2
(i32.sub
(local.tee $3
(local.get $2)
)
(i32.const 1)
)
)
(local.set $1
(i32.add
(local.get $1)
(i32.const 1)
)
)
@@ -283,21 +447,12 @@
)
)
(br_if $label1
(i32.eq
(local.get $3)
(local.get $4)
)
(local.get $3)
)
)
)
(local.set $2
(i32.sub
(local.get $3)
(local.get $4)
)
)
)
(local.get $2)
(i32.const 0)
)
;; features section: mutable-globals, nontrapping-float-to-int, simd, bulk-memory, sign-ext, reference-types, multivalue, bulk-memory-opt
)

View File

@@ -25,12 +25,15 @@ var (
memset api.Function
memcpy api.Function
memcmp api.Function
strcmp api.Function
strncmp api.Function
stack [8]uint64
)
// call invokes an exported wasm function and returns its first result as a
// 32-bit value.
//
// The arguments are copied into the package-level stack buffer, which is then
// handed to CallWithStack as the combined parameter/result area; the result is
// read back from slot 0. Reusing one buffer keeps the variadic slice from
// being passed through the interface call on every invocation.
//
// The buffer holds 8 slots, so copy silently drops any arguments beyond the
// eighth. The shared buffer also means call must not be used concurrently.
func call(fn api.Function, arg ...uint64) uint32 {
	copy(stack[:], arg)
	fn.CallWithStack(context.Background(), stack[:])
	return uint32(stack[0])
}
func TestMain(m *testing.M) {
@@ -46,6 +49,7 @@ func TestMain(m *testing.M) {
memset = mod.ExportedFunction("memset")
memcpy = mod.ExportedFunction("memcpy")
memcmp = mod.ExportedFunction("memcmp")
strcmp = mod.ExportedFunction("strcmp")
strncmp = mod.ExportedFunction("strncmp")
memory, _ = mod.Memory().Read(0, mod.Memory().Size())
@@ -55,6 +59,7 @@ func TestMain(m *testing.M) {
func Benchmark_memset(b *testing.B) {
clear(memory)
b.SetBytes(size)
b.ResetTimer()
for range b.N {
call(memset, ptr1, 3, size)
@@ -72,6 +77,7 @@ func Benchmark_memcpy(b *testing.B) {
clear(memory)
call(memset, ptr2, 5, size)
b.SetBytes(size)
b.ResetTimer()
for range b.N {
call(memcpy, ptr1, ptr2, size)
@@ -91,6 +97,7 @@ func Benchmark_memcmp(b *testing.B) {
call(memset, ptr2, 7, size)
call(memset, ptr2+size/2, 5, size)
b.SetBytes(size / 2)
b.ResetTimer()
for range b.N {
call(memcmp, ptr1, ptr2, size)
@@ -107,12 +114,42 @@ func Benchmark_memcmp(b *testing.B) {
}
}
// Benchmark_strcmp times the exported strcmp on two buffers that share their
// first size/2 bytes, then sanity-checks the sign of the result for the
// greater-than, less-than and equal cases.
func Benchmark_strcmp(b *testing.B) {
	clear(memory)
	// Fill both buffers with 7s, leaving the last byte of each range zero.
	call(memset, ptr1, 7, size-1)
	call(memset, ptr2, 7, size-1)
	// Overwrite ptr2 from its midpoint with 5s so the first mismatch is at
	// offset size/2 (7 in ptr1 versus 5 in ptr2).
	call(memset, ptr2+size/2, 5, size)

	// Only the common prefix of size/2 bytes is compared per call.
	b.SetBytes(size / 2)
	b.ResetTimer()
	for range b.N {
		// strcmp takes exactly two pointers; no length argument.
		call(strcmp, ptr1, ptr2)
	}
	b.StopTimer()

	// ptr1 > ptr2
	if got := int32(call(strcmp, ptr1, ptr2)); got <= 0 {
		b.Fatal(got)
	}

	// make ptr1 < ptr2: ptr1 now ends at the mismatch offset
	memory[ptr1+size/2] = 0
	if got := int32(call(strcmp, ptr1, ptr2)); got >= 0 {
		b.Fatal(got)
	}

	// make ptr1 == ptr2: both strings now end at the same offset
	memory[ptr2+size/2] = 0
	if got := int32(call(strcmp, ptr1, ptr2)); got != 0 {
		b.Fatal(got)
	}
}
func Benchmark_strncmp(b *testing.B) {
clear(memory)
call(memset, ptr1, 7, size)
call(memset, ptr2, 7, size)
call(memset, ptr2+size/2, 5, size)
b.SetBytes(size / 2)
b.ResetTimer()
for range b.N {
call(strncmp, ptr1, ptr2, size)
@@ -120,16 +157,16 @@ func Benchmark_strncmp(b *testing.B) {
b.StopTimer()
// ptr1 > ptr2
if got := int32(call(memcmp, ptr1, ptr2, size)); got <= 0 {
if got := int32(call(strncmp, ptr1, ptr2, size)); got <= 0 {
b.Fatal(got)
}
// make ptr1 < ptr2
memory[ptr1+size/2] = 0
if got := int32(call(memcmp, ptr1, ptr2, size)); got >= 0 {
if got := int32(call(strncmp, ptr1, ptr2, size)); got >= 0 {
b.Fatal(got)
}
// ptr1[:size/2] == ptr2[:size/2]
if got := int32(call(memcmp, ptr1, ptr2, size/2)); got != 0 {
if got := int32(call(strncmp, ptr1, ptr2, size/2)); got != 0 {
b.Fatal(got)
}
}

View File

@@ -21,6 +21,8 @@ void *memmove(void *dest, const void *src, size_t n) {
#ifdef __wasm_simd128__
#define UNALIGNED(x) ((uintptr_t)x % sizeof(*x))
int memcmp(const void *v1, const void *v2, size_t n) {
const v128_t *w1 = v1;
const v128_t *w2 = v2;
@@ -32,8 +34,8 @@ int memcmp(const void *v1, const void *v2, size_t n) {
w2++;
}
const unsigned char *u1 = (const void *)w1;
const unsigned char *u2 = (const void *)w2;
const uint8_t *u1 = (void *)w1;
const uint8_t *u2 = (void *)w2;
while (n--) {
if (*u1 != *u2) return *u1 - *u2;
u1++;
@@ -42,60 +44,58 @@ int memcmp(const void *v1, const void *v2, size_t n) {
return 0;
}
// Compares two NUL-terminated strings byte by byte as unsigned values.
// Returns zero when they are equal, otherwise the difference between the
// first pair of bytes that differ.
int strcmp(const char *c1, const char *c2) {
  const v128_t *v1 = (void *)c1;
  const v128_t *v2 = (void *)c2;

  // The vector path is taken only when UNALIGNED is zero for both pointers.
  if ((UNALIGNED(v1) | UNALIGNED(v2)) == 0) {
    for (;; v1++, v2++) {
      // Any set bit in the XOR means the two vectors differ somewhere;
      // let the byte loop below locate the exact position.
      if (wasm_v128_any_true(*v1 ^ *v2)) break;
      // The vectors are identical; a zero lane means the strings ended.
      if (!wasm_i8x16_all_true(*v1)) return 0;
    }
  }

  // Scalar tail: resumes wherever the vector loop stopped, or at the start
  // of the strings when the vector path was not taken.
  const uint8_t *p1 = (void *)v1;
  const uint8_t *p2 = (void *)v2;
  for (; *p1 == *p2; p1++, p2++) {
    if (*p1 == 0) return 0;
  }
  return *p1 - *p2;
}
int strncmp(const char *c1, const char *c2, size_t n) {
const v128_t *w1 = (const void *)c1;
const v128_t *w2 = (const void *)c2;
const v128_t *w1 = (void *)c1;
const v128_t *w2 = (void *)c2;
for (; n >= sizeof(v128_t); n -= sizeof(v128_t)) {
if (wasm_v128_any_true(wasm_v128_load(w1) ^ wasm_v128_load(w2))) {
break; // *w1 != *w2
}
if (!wasm_i8x16_all_true(wasm_v128_load(w1))) {
return 0; // *w1 == *w2 and they have a NUL
return 0; // *w1 == *w2 and have a NUL
}
w1++;
w2++;
}
c1 = (const void *)w1;
c2 = (const void *)w2;
while (n-- && *c1 == *c2) {
if (n == 0 || *c1 == 0) return 0;
c1++;
c2++;
const uint8_t *u1 = (void *)w1;
const uint8_t *u2 = (void *)w2;
while (n--) {
if (*u1 != *u2) return *u1 - *u2;
if (*u1 == 0) break;
u1++;
u2++;
}
return *(unsigned char *)c1 - *(unsigned char *)c2;
}
#endif
#define ONES (~(uintmax_t)(0) / UCHAR_MAX)
#define HIGHS (ONES * (UCHAR_MAX / 2 + 1))
#define HASZERO(x) ((x) - (typeof(x))(ONES) & ~(x) & (typeof(x))(HIGHS))
#define UNALIGNED(x) ((uintptr_t)(x) & (sizeof(*x) - 1))
int strcmp(const char *c1, const char *c2) {
typedef uintptr_t __attribute__((__may_alias__)) word;
const word *w1 = (const void *)c1;
const word *w2 = (const void *)c2;
if (!(UNALIGNED(w1) | UNALIGNED(w2))) {
while (*w1 == *w2) {
if (HASZERO(*w1)) return 0;
w1++;
w2++;
}
c1 = (const void *)w1;
c2 = (const void *)w2;
}
while (*c1 == *c2 && *c1) {
c1++;
c2++;
}
return *(unsigned char *)c1 - *(unsigned char *)c2;
return 0;
}
#undef UNALIGNED
#undef HASZERO
#undef HIGHS
#undef ONES
#endif