add avx2 intrinsics maybe

This commit is contained in:
Andrej Karpathy
2023-08-10 15:01:53 +00:00
parent 3f69c6cdc4
commit d0309ab2d4
2 changed files with 51 additions and 1 deletions
+9
View File
@@ -32,6 +32,15 @@ runfast: run.c
runomp: run.c
$(CC) -Ofast -fopenmp -march=native run.c -lm -o run
# compile with AVX2 intrinsics enabled
.PHONY: runavx2
runavx2: run.c
$(CC) -Ofast -march=native -mavx2 -DLLAMAC_AVX2 -o run run.c -lm
.PHONY: runompavx2
runompavx2: run.c
$(CC) -Ofast -fopenmp -march=native -mavx2 -DLLAMAC_AVX2 run.c -lm -o run
.PHONY: win64
win64:
x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
+42 -1
View File
@@ -199,6 +199,45 @@ void softmax(float* x, int size) {
}
}
#ifdef LLAMAC_AVX2
#include <immintrin.h>
// AVX2 intrinsics for matmul
void matmul(float* xout, const float* x, const float* w, int n, int d) {
int nn = n / 8 * 8; // ensure n is a multiple of 8
int i;
#pragma omp parallel for private(i)
for (int i = 0; i < d; i++) {
__m256 sum_vec = _mm256_setzero_ps(); // for AVX2, sum of 8 floats
int i_n = i * n;
for (int j = 0; j < nn; j += 8) {
// Load 8 values from w and x
__m256 w_vec = _mm256_loadu_ps(&w[i_n + j]);
__m256 x_vec = _mm256_loadu_ps(&x[j]);
// Multiply and accumulate
__m256 prod_vec = _mm256_mul_ps(w_vec, x_vec);
sum_vec = _mm256_add_ps(sum_vec, prod_vec);
}
// Perform horizontal add
sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
float vals[8];
_mm256_storeu_ps(vals, sum_vec);
float val = vals[0] + vals[4];
// handle remainder if n is not a multiple of 8
for (int j = nn; j < n; j++) {
val += w[i_n + j] * x[j];
}
xout[i] = val;
}
}
#else
// naive matmul
void matmul(float* xout, float* x, float* w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
@@ -206,12 +245,14 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
#pragma omp parallel for private(i)
for (i = 0; i < d; i++) {
float val = 0.0f;
int i_n = i * n;
for (int j = 0; j < n; j++) {
val += w[i * n + j] * x[j];
val += w[i_n + j] * x[j];
}
xout[i] = val;
}
}
#endif
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {