From d0309ab2d46698ff3350818c123a9eec43b41e56 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 10 Aug 2023 15:01:53 +0000 Subject: [PATCH] add avx2 intrinsics maybe --- Makefile | 9 +++++++++ run.c | 43 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 8debdc6..31c228a 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,15 @@ runfast: run.c runomp: run.c $(CC) -Ofast -fopenmp -march=native run.c -lm -o run +# compile with AVX2 intrinsics enabled +.PHONY: runavx2 +runavx2: run.c + $(CC) -Ofast -march=native -mavx2 -DLLAMAC_AVX2 -o run run.c -lm + +.PHONY: runompavx2 + runompavx2: run.c + $(CC) -Ofast -fopenmp -march=native -mavx2 -DLLAMAC_AVX2 run.c -lm -o run + .PHONY: win64 win64: x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c diff --git a/run.c b/run.c index 9f4a1b2..fa5afd0 100644 --- a/run.c +++ b/run.c @@ -199,6 +199,45 @@ void softmax(float* x, int size) { } } +#ifdef LLAMAC_AVX2 +#include + +// AVX2 intrinsics for matmul +void matmul(float* xout, const float* x, const float* w, int n, int d) { + + int nn = n / 8 * 8; // ensure n is a multiple of 8 + int i; + #pragma omp parallel for private(i) + for (int i = 0; i < d; i++) { + __m256 sum_vec = _mm256_setzero_ps(); // for AVX2, sum of 8 floats + int i_n = i * n; + for (int j = 0; j < nn; j += 8) { + // Load 8 values from w and x + __m256 w_vec = _mm256_loadu_ps(&w[i_n + j]); + __m256 x_vec = _mm256_loadu_ps(&x[j]); + // Multiply and accumulate + __m256 prod_vec = _mm256_mul_ps(w_vec, x_vec); + sum_vec = _mm256_add_ps(sum_vec, prod_vec); + } + + // Perform horizontal add + sum_vec = _mm256_hadd_ps(sum_vec, sum_vec); + sum_vec = _mm256_hadd_ps(sum_vec, sum_vec); + float vals[8]; + _mm256_storeu_ps(vals, sum_vec); + float val = vals[0] + vals[4]; + + // handle remainder if n is not a multiple of 8 + for (int j = nn; j < n; j++) { + val += w[i_n + j] * x[j]; + } + xout[i] = val; + } +} + +#else + +// naive matmul void matmul(float* xout, float* x, float* w, int n, int d) { // W (d,n) @ x (n,) -> xout (d,) // by far the most amount of time is spent inside this little function @@ -206,12 +245,14 @@ void matmul(float* xout, float* x, float* w, int n, int d) { #pragma omp parallel for private(i) for (i = 0; i < d; i++) { float val = 0.0f; + int i_n = i * n; for (int j = 0; j < n; j++) { - val += w[i * n + j] * x[j]; + val += w[i_n + j] * x[j]; } xout[i] = val; } } +#endif void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {