8
8
#include <random>
9
9
#include <stdlib.h>
10
10
11
-
#define EXPECT_MATRIX_EQ(A, B, R, C) \
12
-
do { \
13
-
for (unsigned r = 0; r < R; r++) \
14
-
for (unsigned c = 0; c < C; c++) \
15
-
if (A[r + c * R] != B[r + c * R]) { \
16
-
std::cerr << "mismatch at " << r << ":" << c << "\n"; \
17
-
exit(1); \
18
-
} \
19
-
} while (false)
11
+
#define ABSTOL 0.000001
12
+
#define RELTOL 0.00001
13
+
bool fpcmp(double V1, double V2, double AbsTolerance, double RelTolerance) {
14
+
// Check to see if these are inside the absolute tolerance
15
+
if (AbsTolerance < fabs(V1 - V2)) {
16
+
// Nope, check the relative tolerance...
17
+
double Diff;
18
+
if (V2)
19
+
Diff = fabs(V1 / V2 - 1.0);
20
+
else if (V1)
21
+
Diff = fabs(V2 / V1 - 1.0);
22
+
else
23
+
Diff = 0; // Both zero.
24
+
if (Diff > RelTolerance) {
25
+
return true;
26
+
}
27
+
}
28
+
return false;
29
+
}
30
+
31
+
template <typename ElementTy, typename std::enable_if_t<
32
+
std::is_integral<ElementTy>::value, int> = 0>
33
+
void expectMatrixEQ(ElementTy *A, ElementTy *B, unsigned R, unsigned C) {
34
+
do {
35
+
for (unsigned r = 0; r < R; r++)
36
+
for (unsigned c = 0; c < C; c++)
37
+
if (A[r + c * R] != B[r + c * R]) {
38
+
std::cerr << "mismatch at " << r << ":" << c << "\n";
39
+
exit(1);
40
+
}
41
+
} while (false);
42
+
}
43
+
44
+
template <typename ElementTy,
45
+
typename std::enable_if_t<std::is_floating_point<ElementTy>::value,
46
+
int> = 0>
47
+
void expectMatrixEQ(ElementTy *A, ElementTy *B, unsigned R,
48
+
unsigned C) {
49
+
do {
50
+
for (unsigned r = 0; r < R; r++)
51
+
for (unsigned c = 0; c < C; c++)
52
+
if (fpcmp(A[r + c * R], B[r + c * R], ABSTOL, RELTOL)) {
53
+
std::cerr << "mismatch at " << r << ":" << c << "\n";
54
+
exit(1);
55
+
}
56
+
} while (false);
57
+
}
58
+
20
59
21
60
template <typename EltTy>
22
61
void zeroMatrix(EltTy *M, unsigned Rows, unsigned Cols) {
@@ -33,13 +72,25 @@ template <typename EltTy> void print(EltTy *X, unsigned Rows, unsigned Cols) {
33
72
}
34
73
}
35
74
36
-
template <typename Ty> void initRandom(Ty *A, unsigned Rows, unsigned Cols) {
75
+
template <typename ElementTy,
76
+
typename std::enable_if_t<std::is_floating_point<ElementTy>::value,
77
+
int> = 0>
78
+
void initRandom(ElementTy *A, unsigned Rows, unsigned Cols) {
79
+
std::default_random_engine generator;
80
+
std::uniform_real_distribution<ElementTy> distribution(-10.0, 10.0);
81
+
82
+
for (unsigned i = 0; i < Rows * Cols; i++)
83
+
A[i] = distribution(generator);
84
+
}
85
+
86
+
template <typename ElementTy, typename std::enable_if_t<
87
+
std::is_integral<ElementTy>::value, int> = 0>
88
+
void initRandom(ElementTy *A, unsigned Rows, unsigned Cols) {
37
89
std::default_random_engine generator;
38
-
std::uniform_int_distribution<double> distribution(-10.0, 10.0);
39
-
auto random_double = std::bind(distribution, generator);
90
+
std::uniform_int_distribution<ElementTy> distribution(-10, 10);
40
91
41
92
for (unsigned i = 0; i < Rows * Cols; i++)
42
-
A[i] = random_double();
93
+
A[i] = distribution(generator);
43
94
}
44
95
45
96
template <typename EltTy, unsigned R, unsigned C>
@@ -82,8 +133,8 @@ template <typename EltTy, unsigned R0, unsigned C0> void testTranspose() {
82
133
transposeSpec<EltTy, R0, C0>(ResSpec, X);
83
134
transposeBuiltin<EltTy, R0, C0>(ResBuiltin, X);
84
135
85
-
EXPECT_MATRIX_EQ(ResBase, ResBuiltin, R0, C0);
86
-
EXPECT_MATRIX_EQ(ResBase, ResSpec, C0, R0);
136
+
expectMatrixEQ(ResBase, ResBuiltin, R0, C0);
137
+
expectMatrixEQ(ResBase, ResSpec, C0, R0);
87
138
}
88
139
89
140
template <typename EltTy, unsigned R0, unsigned C0, unsigned C1>
@@ -150,9 +201,9 @@ void testMultiply() {
150
201
multiplySpec<EltTy, R0, C0, C1>(ResSpec, X, Y);
151
202
multiplyBuiltin<EltTy, R0, C0, C1>(ResBuiltin, X, Y);
152
203
153
-
EXPECT_MATRIX_EQ(ResSpec, ResBuiltin, R0, C1);
154
-
EXPECT_MATRIX_EQ(ResBase, ResBuiltin, R0, C1);
155
-
EXPECT_MATRIX_EQ(ResBase, ResSpec, R0, C1);
204
+
expectMatrixEQ(ResSpec, ResBuiltin, R0, C1);
205
+
expectMatrixEQ(ResBase, ResBuiltin, R0, C1);
206
+
expectMatrixEQ(ResBase, ResSpec, R0, C1);
156
207
}
157
208
158
209
int main(void) {
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4