Compare commits
1 Commits
master
...
feature/avx2
| Author | SHA1 | Date | |
|---|---|---|---|
| d0309ab2d4 |
@@ -32,6 +32,15 @@ runfast: run.c
|
||||
runomp: run.c
|
||||
$(CC) -Ofast -fopenmp -march=native run.c -lm -o run
|
||||
|
||||
# compile with AVX2 intrinsics enabled
|
||||
.PHONY: runavx2
|
||||
runavx2: run.c
|
||||
$(CC) -Ofast -march=native -mavx2 -DLLAMAC_AVX2 -o run run.c -lm
|
||||
|
||||
.PHONY: runompavx2
|
||||
runompavx2: run.c
|
||||
$(CC) -Ofast -fopenmp -march=native -mavx2 -DLLAMAC_AVX2 run.c -lm -o run
|
||||
|
||||
.PHONY: win64
|
||||
win64:
|
||||
x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
|
||||
|
||||
@@ -199,6 +199,45 @@ void softmax(float* x, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef LLAMAC_AVX2
|
||||
#include <immintrin.h>
|
||||
|
||||
// AVX2 intrinsics for matmul
|
||||
void matmul(float* xout, const float* x, const float* w, int n, int d) {
|
||||
|
||||
int nn = n / 8 * 8; // ensure n is a multiple of 8
|
||||
int i;
|
||||
#pragma omp parallel for private(i)
|
||||
for (int i = 0; i < d; i++) {
|
||||
__m256 sum_vec = _mm256_setzero_ps(); // for AVX2, sum of 8 floats
|
||||
int i_n = i * n;
|
||||
for (int j = 0; j < nn; j += 8) {
|
||||
// Load 8 values from w and x
|
||||
__m256 w_vec = _mm256_loadu_ps(&w[i_n + j]);
|
||||
__m256 x_vec = _mm256_loadu_ps(&x[j]);
|
||||
// Multiply and accumulate
|
||||
__m256 prod_vec = _mm256_mul_ps(w_vec, x_vec);
|
||||
sum_vec = _mm256_add_ps(sum_vec, prod_vec);
|
||||
}
|
||||
|
||||
// Perform horizontal add
|
||||
sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
|
||||
sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
|
||||
float vals[8];
|
||||
_mm256_storeu_ps(vals, sum_vec);
|
||||
float val = vals[0] + vals[4];
|
||||
|
||||
// handle remainder if n is not a multiple of 8
|
||||
for (int j = nn; j < n; j++) {
|
||||
val += w[i_n + j] * x[j];
|
||||
}
|
||||
xout[i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// naive matmul
|
||||
void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||
// W (d,n) @ x (n,) -> xout (d,)
|
||||
// by far the most amount of time is spent inside this little function
|
||||
@@ -206,12 +245,14 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||
#pragma omp parallel for private(i)
|
||||
for (i = 0; i < d; i++) {
|
||||
float val = 0.0f;
|
||||
int i_n = i * n;
|
||||
for (int j = 0; j < n; j++) {
|
||||
val += w[i * n + j] * x[j];
|
||||
val += w[i_n + j] * x[j];
|
||||
}
|
||||
xout[i] = val;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user