From 9ca6f80da5e359b5bb4c46ea65c92caaacc6a64a Mon Sep 17 00:00:00 2001
From: Jiajie Chen
Date: Thu, 14 Dec 2023 00:31:05 +0800
Subject: [PATCH] Prepare for lasx

---
Review notes (not part of the commit message):
- union v256 declared its vector members as m128i/m128/m128d while the
  constructors assigned m256i/m256d/m256; the members and the conversion
  operators now use the m256* names.
- XFUZZ0..XFUZZ3 looped up to XFUZZ_N, which this patch never defines; they
  now use the existing FUZZ_N.
- The header names on the #include lines and the author address were already
  missing from the text under review and are left as found; the added include
  is presumably lasxintrin.h -- confirm against the tree.
- Hunk sizes and the diffstat account for the added comments; the blob ids on
  the index lines predate these edits.

 code/Makefile |   2 +-
 code/common.h | 149 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 150 insertions(+), 1 deletion(-)

diff --git a/code/Makefile b/code/Makefile
index 3706ca83..fe930012 100644
--- a/code/Makefile
+++ b/code/Makefile
@@ -14,7 +14,7 @@ run-%: %
 clean:
 	rm -rf $(EXES)
 
-%: %.cpp %.h
+%: %.cpp %.h common.h
 	$(CXX) $< -mlsx -mlasx -o $@
 
 .SUFFIXES:
\ No newline at end of file
diff --git a/code/common.h b/code/common.h
index abea3a38..05a0c334 100644
--- a/code/common.h
+++ b/code/common.h
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -130,6 +131,63 @@ union v128 {
   }
 };
 
+// Overlays the 256-bit vector types with plain lane arrays so that one value
+// can be handed to intrinsics and also inspected lane by lane.
+union v256 {
+  __m256i m256i;
+  __m256 m256;
+  __m256d m256d;
+  v4i64 __v4i64;
+  v4u64 __v4u64;
+  v8i32 __v8i32;
+  v8u32 __v8u32;
+  v16i16 __v16i16;
+  v16u16 __v16u16;
+  v32i8 __v32i8;
+  v32u8 __v32u8;
+
+  u8 byte[32];
+  u16 half[16];
+  u32 word[8];
+  u64 dword[4];
+  u128 qword[2];
+
+  float fp32[8];
+  double fp64[4];
+
+  // Wrap a raw vector value.
+  v256(__m256i other) { m256i = other; }
+  v256(__m256d other) { m256d = other; }
+  v256(__m256 other) { m256 = other; }
+  // Default-construct with pseudo-random data: each half[] lane gets rand().
+  v256() {
+    for (int i = 0; i < 16; i++) {
+      half[i] = rand();
+    }
+  }
+
+  // Implicit conversions back to the vector types.
+  operator __m256i() { return m256i; }
+  operator __m256() { return m256; }
+  operator __m256d() { return m256d; }
+  // duplicate with __m256i
+  // operator v4i64() { return __v4i64; }
+  operator v4u64() { return __v4u64; }
+  operator v8i32() { return __v8i32; }
+  operator v8u32() { return __v8u32; }
+  operator v16i16() { return __v16i16; }
+  operator v16u16() { return __v16u16; }
+  operator v32i8() { return __v32i8; }
+  operator v32u8() { return __v32u8; }
+  // Comparisons are bytewise over the 32 entries of byte[].
+  bool operator==(const v256 &other) const {
+    return memcmp(byte, other.byte, 32) == 0;
+  }
+  bool operator!=(const v256 &other) const {
+    return memcmp(byte, other.byte, 32) != 0;
+  }
+};
+
 void test();
 
 void print(const char *s, v128 num) {
@@ -151,6 +209,35 @@ void print(const char *s, __m128d num) {
   printf("__m128d %s: %lf %lf\n", s, num[0], num[1]);
 }
 
+// Dump a v256 three times: as four 64-bit hex words, as eight floats and as
+// four doubles.
+void print(const char *s, v256 num) {
+  printf("v256 as __m256i %s: %016lx %016lx %016lx %016lx\n", s, num.dword[0],
+         num.dword[1], num.dword[2], num.dword[3]);
+  printf("v256 as __m256 %s: %f %f %f %f %f %f %f %f\n", s, num.fp32[0],
+         num.fp32[1], num.fp32[2], num.fp32[3], num.fp32[4], num.fp32[5],
+         num.fp32[6], num.fp32[7]);
+  printf("v256 as __m256d %s: %lf %lf %lf %lf\n", s, num.fp64[0], num.fp64[1],
+         num.fp64[2], num.fp64[3]);
+}
+
+// Print elements 0..3 of an __m256i in hex.
+void print(const char *s, __m256i num) {
+  printf("__m256i %s: %016llx %016llx %016llx %016llx\n", s, num[0], num[1],
+         num[2], num[3]);
+}
+
+// Print elements 0..7 of an __m256.
+void print(const char *s, __m256 num) {
+  printf("__m256 %s: %f %f %f %f %f %f %f %f\n", s, num[0], num[1], num[2],
+         num[3], num[4], num[5], num[6], num[7]);
+}
+
+// Print elements 0..3 of an __m256d.
+void print(const char *s, __m256d num) {
+  printf("__m256d %s: %lf %lf %lf %lf\n", s, num[0], num[1], num[2], num[3]);
+}
+
 #define PRINT(x) print(#x, x)
 
 #define FUZZ_N 128
@@ -212,6 +299,68 @@ void print(const char *s, __m128d num) {
     } \
   } while (0);
 
+// XFUZZ0..XFUZZ3 compare func against __lasx_##func for FUZZ_N rounds. Each
+// round of XFUZZn default-constructs n v256 operands (a, b, c), passes them
+// plus any trailing macro arguments to both functions and, on a mismatch,
+// prints the operands and both results and then asserts equality. XFUZZ3
+// takes no trailing arguments.
+#define XFUZZ0(func, ...) \
+  do { \
+    for (int i = 0; i < FUZZ_N; i++) { \
+      if (func(__VA_ARGS__) != __lasx_##func(__VA_ARGS__)) { \
+        PRINT(__lasx_##func(__VA_ARGS__)); \
+        PRINT(func(__VA_ARGS__)); \
+        assert(func(__VA_ARGS__) == __lasx_##func(__VA_ARGS__)); \
+      } \
+    } \
+  } while (0);
+
+#define XFUZZ1(func, ...) \
+  do { \
+    for (int i = 0; i < FUZZ_N; i++) { \
+      v256 a; \
+      if (func(a __VA_OPT__(, ) __VA_ARGS__) != \
+          __lasx_##func(a __VA_OPT__(, ) __VA_ARGS__)) { \
+        PRINT(a); \
+        PRINT(__lasx_##func(a __VA_OPT__(, ) __VA_ARGS__)); \
+        PRINT(func(a __VA_OPT__(, ) __VA_ARGS__)); \
+        assert(func(a __VA_OPT__(, ) __VA_ARGS__) == \
+               __lasx_##func(a __VA_OPT__(, ) __VA_ARGS__)); \
+      } \
+    } \
+  } while (0);
+
+#define XFUZZ2(func, ...) \
+  do { \
+    for (int i = 0; i < FUZZ_N; i++) { \
+      v256 a, b; \
+      if (func(a, b __VA_OPT__(, ) __VA_ARGS__) != \
+          __lasx_##func(a, b __VA_OPT__(, ) __VA_ARGS__)) { \
+        PRINT(a); \
+        PRINT(b); \
+        PRINT(__lasx_##func(a, b __VA_OPT__(, ) __VA_ARGS__)); \
+        PRINT(func(a, b __VA_OPT__(, ) __VA_ARGS__)); \
+        assert(func(a, b __VA_OPT__(, ) __VA_ARGS__) == \
+               __lasx_##func(a, b __VA_OPT__(, ) __VA_ARGS__)); \
+      } \
+    } \
+  } while (0);
+
+#define XFUZZ3(func) \
+  do { \
+    for (int i = 0; i < FUZZ_N; i++) { \
+      v256 a, b, c; \
+      if (func(a, b, c) != __lasx_##func(a, b, c)) { \
+        PRINT(a); \
+        PRINT(b); \
+        PRINT(c); \
+        PRINT(__lasx_##func(a, b, c)); \
+        PRINT(func(a, b, c)); \
+        assert(func(a, b, c) == __lasx_##func(a, b, c)); \
+      } \
+    } \
+  } while (0);
+
 int main(int argc, char *argv[]) {
   printf("Testing %s\n", argv[0]);
   test();