Compare commits
1 Commits
master
...
feature/avx2
| Author | SHA1 | Date | |
|---|---|---|---|
| d0309ab2d4 |
@@ -32,6 +32,15 @@ runfast: run.c
|
|||||||
runomp: run.c
|
runomp: run.c
|
||||||
$(CC) -Ofast -fopenmp -march=native run.c -lm -o run
|
$(CC) -Ofast -fopenmp -march=native run.c -lm -o run
|
||||||
|
|
||||||
|
# compile with AVX2 intrinsics enabled
|
||||||
|
.PHONY: runavx2
|
||||||
|
runavx2: run.c
|
||||||
|
$(CC) -Ofast -march=native -mavx2 -DLLAMAC_AVX2 -o run run.c -lm
|
||||||
|
|
||||||
|
.PHONY: runompavx2
|
||||||
|
runompavx2: run.c
|
||||||
|
$(CC) -Ofast -fopenmp -march=native -mavx2 -DLLAMAC_AVX2 run.c -lm -o run
|
||||||
|
|
||||||
.PHONY: win64
|
.PHONY: win64
|
||||||
win64:
|
win64:
|
||||||
x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
|
x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
|
||||||
|
|||||||
@@ -199,6 +199,45 @@ void softmax(float* x, int size) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef LLAMAC_AVX2
|
||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
// AVX2 intrinsics for matmul
|
||||||
|
void matmul(float* xout, const float* x, const float* w, int n, int d) {
|
||||||
|
|
||||||
|
int nn = n / 8 * 8; // ensure n is a multiple of 8
|
||||||
|
int i;
|
||||||
|
#pragma omp parallel for private(i)
|
||||||
|
for (int i = 0; i < d; i++) {
|
||||||
|
__m256 sum_vec = _mm256_setzero_ps(); // for AVX2, sum of 8 floats
|
||||||
|
int i_n = i * n;
|
||||||
|
for (int j = 0; j < nn; j += 8) {
|
||||||
|
// Load 8 values from w and x
|
||||||
|
__m256 w_vec = _mm256_loadu_ps(&w[i_n + j]);
|
||||||
|
__m256 x_vec = _mm256_loadu_ps(&x[j]);
|
||||||
|
// Multiply and accumulate
|
||||||
|
__m256 prod_vec = _mm256_mul_ps(w_vec, x_vec);
|
||||||
|
sum_vec = _mm256_add_ps(sum_vec, prod_vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform horizontal add
|
||||||
|
sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
|
||||||
|
sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
|
||||||
|
float vals[8];
|
||||||
|
_mm256_storeu_ps(vals, sum_vec);
|
||||||
|
float val = vals[0] + vals[4];
|
||||||
|
|
||||||
|
// handle remainder if n is not a multiple of 8
|
||||||
|
for (int j = nn; j < n; j++) {
|
||||||
|
val += w[i_n + j] * x[j];
|
||||||
|
}
|
||||||
|
xout[i] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
// naive matmul
|
||||||
void matmul(float* xout, float* x, float* w, int n, int d) {
|
void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||||
// W (d,n) @ x (n,) -> xout (d,)
|
// W (d,n) @ x (n,) -> xout (d,)
|
||||||
// by far the most amount of time is spent inside this little function
|
// by far the most amount of time is spent inside this little function
|
||||||
@@ -206,12 +245,14 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
|
|||||||
#pragma omp parallel for private(i)
|
#pragma omp parallel for private(i)
|
||||||
for (i = 0; i < d; i++) {
|
for (i = 0; i < d; i++) {
|
||||||
float val = 0.0f;
|
float val = 0.0f;
|
||||||
|
int i_n = i * n;
|
||||||
for (int j = 0; j < n; j++) {
|
for (int j = 0; j < n; j++) {
|
||||||
val += w[i * n + j] * x[j];
|
val += w[i_n + j] * x[j];
|
||||||
}
|
}
|
||||||
xout[i] = val;
|
xout[i] = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user