Skip to content

Commit

Permalink
Prepare for lasx
Browse files Browse the repository at this point in the history
  • Loading branch information
jiegec committed Dec 13, 2023
1 parent afa4adf commit 9ca6f80
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 1 deletion.
2 changes: 1 addition & 1 deletion code/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ run-%: %
# Remove every built artifact listed in $(EXES).
clean:
rm -rf $(EXES)

%: %.cpp %.h
# Build an executable from its same-named .cpp ($< is the first prerequisite).
# The matching .h and the shared common.h are prerequisites too, so editing
# either one triggers a rebuild.
# NOTE(review): -mlsx/-mlasx presumably enable the LoongArch LSX/LASX SIMD
# extensions — confirm against the compiler documentation.
%: %.cpp %.h common.h
$(CXX) $< -mlsx -mlasx -o $@

.SUFFIXES:
133 changes: 133 additions & 0 deletions code/common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <algorithm>
#include <assert.h>
#include <lasxintrin.h>
#include <limits>
#include <lsxintrin.h>
#include <stdint.h>
Expand Down Expand Up @@ -130,6 +131,57 @@ union v128 {
}
};

// A 256-bit vector value exposed through every view the tests need: the
// intrinsic vector types, the element-typed vector types, and plain arrays.
// All members overlay the same 32 bytes, so a value stored through one view
// can be read back through any other.
union v256 {
  // Intrinsic vector views. The member names match the 256-bit types they
  // hold and are the names the constructors below assign to.
  __m256i m256i;
  __m256 m256;
  __m256d m256d;

  // Element-typed vector views.
  v4i64 __v4i64;
  v4u64 __v4u64;
  v8i32 __v8i32;
  v8u32 __v8u32;
  v16i16 __v16i16;
  v16u16 __v16u16;
  v32i8 __v32i8;
  v32u8 __v32u8;

  // Array views for per-lane access.
  u8 byte[32];
  u16 half[16];
  u32 word[8];
  u64 dword[4];
  u128 qword[2];

  float fp32[8];
  double fp64[4];

  // Wrap an existing intrinsic vector value.
  v256(__m256i other) { m256i = other; }
  v256(__m256d other) { m256d = other; }
  v256(__m256 other) { m256 = other; }

  // Default construction fills all 16 half lanes from rand(), truncated to
  // u16, so a freshly declared v256 is a pseudo-random test input.
  v256() {
    for (int i = 0; i < 16; i++) {
      half[i] = rand();
    }
  }

  // Implicit conversions so a v256 can be passed wherever a vector type is
  // expected.
  operator __m256i() { return m256i; }
  operator __m256() { return m256; }
  operator __m256d() { return m256d; }
  // duplicate with __m256i
  // operator v4i64() { return __v4i64; }
  operator v4u64() { return __v4u64; }
  operator v8i32() { return __v8i32; }
  operator v8u32() { return __v8u32; }
  operator v16i16() { return __v16i16; }
  operator v16u16() { return __v16u16; }
  operator v32i8() { return __v32i8; }
  operator v32u8() { return __v32u8; }

  // Equality is a bytewise comparison of all 32 bytes, so floating-point
  // lanes are compared by bit pattern rather than by numeric value.
  bool operator==(const v256 &other) const {
    return memcmp(byte, other.byte, 32) == 0;
  }
  bool operator!=(const v256 &other) const {
    return memcmp(byte, other.byte, 32) != 0;
  }
};

void test();

void print(const char *s, v128 num) {
Expand All @@ -151,6 +203,30 @@ void print(const char *s, __m128d num) {
printf("__m128d %s: %lf %lf\n", s, num[0], num[1]);
}

void print(const char *s, v256 num) {
printf("v256 as __m256i %s: %016lx %016lx %016lx %016lx\n", s, num.dword[0],
num.dword[1], num.dword[2], num.dword[3]);
printf("v256 as __m256 %s: %f %f %f %f %f %f %f %f\n", s, num.fp32[0],
num.fp32[1], num.fp32[2], num.fp32[3], num.fp32[4], num.fp32[5],
num.fp32[6], num.fp32[7]);
printf("v256 as __m256d %s: %lf %lf %lf %lf\n", s, num.fp64[0], num.fp64[1],
num.fp64[2], num.fp64[3]);
}

// Print the four 64-bit lanes of an integer vector as zero-padded hex,
// prefixed with the label `s`.
void print(const char *s, __m256i num) {
  // %llx expects unsigned long long; cast each lane explicitly so the
  // argument type matches the conversion specifier. The printed bit
  // pattern is unchanged.
  printf("__m256i %s: %016llx %016llx %016llx %016llx\n", s,
         (unsigned long long)num[0], (unsigned long long)num[1],
         (unsigned long long)num[2], (unsigned long long)num[3]);
}

// Print the eight lanes of a single-precision vector, prefixed with the
// label `s`.
void print(const char *s, __m256 num) {
  float lane[8];
  for (int k = 0; k < 8; k++) {
    lane[k] = num[k];
  }
  printf("__m256 %s: %f %f %f %f %f %f %f %f\n", s, lane[0], lane[1], lane[2],
         lane[3], lane[4], lane[5], lane[6], lane[7]);
}

// Print the four lanes of a double-precision vector, prefixed with the
// label `s`.
void print(const char *s, __m256d num) {
  double lane[4];
  for (int k = 0; k < 4; k++) {
    lane[k] = num[k];
  }
  printf("__m256d %s: %lf %lf %lf %lf\n", s, lane[0], lane[1], lane[2],
         lane[3]);
}

// Print an expression using its own source text as the label: #x
// stringizes the argument and the matching print() overload is chosen by
// the argument's type.
#define PRINT(x) print(#x, x)

// NOTE(review): presumably the iteration count used by the FUZZ* macros;
// their loops are not fully visible here — confirm.
#define FUZZ_N 128
Expand Down Expand Up @@ -212,6 +288,63 @@ void print(const char *s, __m128d num) {
} \
} while (0);

// XFUZZ0(func, args...): repeat XFUZZ_N times, comparing func(args...) with
// __lasx_func(args...) on the caller-supplied arguments only (no random
// vector inputs are generated). On a mismatch both results are printed and
// the equality is then asserted, which aborts if the mismatch reproduces.
// The calls are re-evaluated for the comparison, the prints and the assert.
// The expansion already ends with ';' after `while (0)`.
// NOTE(review): XFUZZ_N is not defined in the visible part of this header —
// confirm it is defined before these macros are used.
#define XFUZZ0(func, ...) \
do { \
for (int i = 0; i < XFUZZ_N; i++) { \
if (func(__VA_ARGS__) != __lasx_##func(__VA_ARGS__)) { \
PRINT(__lasx_##func(__VA_ARGS__)); \
PRINT(func(__VA_ARGS__)); \
assert(func(__VA_ARGS__) == __lasx_##func(__VA_ARGS__)); \
} \
} \
} while (0);

// XFUZZ1(func, extra...): repeat XFUZZ_N times with one default-constructed
// v256 `a` (its default constructor fills it from rand()), comparing
// func(a, extra...) with __lasx_func(a, extra...). __VA_OPT__ inserts the
// separating comma only when extra arguments are given. On a mismatch the
// input and both results are printed and the equality is then asserted,
// which aborts if the mismatch reproduces.
// The expansion already ends with ';' after `while (0)`.
// NOTE(review): XFUZZ_N is not defined in the visible part of this header —
// confirm it is defined before these macros are used.
#define XFUZZ1(func, ...) \
do { \
for (int i = 0; i < XFUZZ_N; i++) { \
v256 a; \
if (func(a __VA_OPT__(, ) __VA_ARGS__) != \
__lasx_##func(a __VA_OPT__(, ) __VA_ARGS__)) { \
PRINT(a); \
PRINT(__lasx_##func(a __VA_OPT__(, ) __VA_ARGS__)); \
PRINT(func(a __VA_OPT__(, ) __VA_ARGS__)); \
assert(func(a __VA_OPT__(, ) __VA_ARGS__) == \
__lasx_##func(a __VA_OPT__(, ) __VA_ARGS__)); \
} \
} \
} while (0);

// XFUZZ2(func, extra...): repeat XFUZZ_N times with two default-constructed
// v256 inputs `a` and `b` (each filled from rand() by the default
// constructor), comparing func(a, b, extra...) with
// __lasx_func(a, b, extra...). __VA_OPT__ inserts the separating comma only
// when extra arguments are given. On a mismatch the inputs and both results
// are printed and the equality is then asserted, which aborts if the
// mismatch reproduces.
// The expansion already ends with ';' after `while (0)`.
// NOTE(review): XFUZZ_N is not defined in the visible part of this header —
// confirm it is defined before these macros are used.
#define XFUZZ2(func, ...) \
do { \
for (int i = 0; i < XFUZZ_N; i++) { \
v256 a, b; \
if (func(a, b __VA_OPT__(, ) __VA_ARGS__) != \
__lasx_##func(a, b __VA_OPT__(, ) __VA_ARGS__)) { \
PRINT(a); \
PRINT(b); \
PRINT(__lasx_##func(a, b __VA_OPT__(, ) __VA_ARGS__)); \
PRINT(func(a, b __VA_OPT__(, ) __VA_ARGS__)); \
assert(func(a, b __VA_OPT__(, ) __VA_ARGS__) == \
__lasx_##func(a, b __VA_OPT__(, ) __VA_ARGS__)); \
} \
} \
} while (0);

// XFUZZ3(func): repeat XFUZZ_N times with three default-constructed v256
// inputs `a`, `b` and `c` (each filled from rand() by the default
// constructor), comparing func(a, b, c) with __lasx_func(a, b, c). Unlike
// XFUZZ1/XFUZZ2 it takes no extra arguments. On a mismatch the inputs and
// both results are printed and the equality is then asserted, which aborts
// if the mismatch reproduces.
// The expansion already ends with ';' after `while (0)`.
// NOTE(review): XFUZZ_N is not defined in the visible part of this header —
// confirm it is defined before these macros are used.
#define XFUZZ3(func) \
do { \
for (int i = 0; i < XFUZZ_N; i++) { \
v256 a, b, c; \
if (func(a, b, c) != __lasx_##func(a, b, c)) { \
PRINT(a); \
PRINT(b); \
PRINT(c); \
PRINT(__lasx_##func(a, b, c)); \
PRINT(func(a, b, c)); \
assert(func(a, b, c) == __lasx_##func(a, b, c)); \
} \
} \
} while (0);

int main(int argc, char *argv[]) {
printf("Testing %s\n", argv[0]);
test();
Expand Down

0 comments on commit 9ca6f80

Please sign in to comment.