update openmp pragmas for MSVC compatibility
This has no negative impact on Linux and is in preparation for windows support. Windows compiles will not work without additional timer and mmap compatibility patches
This commit is contained in:
@@ -13,9 +13,11 @@ $ ./run
|
|||||||
#include <time.h>
|
#include <time.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <unistd.h>
|
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
|
#ifndef _WIN32
|
||||||
|
#include <unistd.h>
|
||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Transformer and RunState structs, and related memory management
|
// Transformer and RunState structs, and related memory management
|
||||||
@@ -190,9 +192,9 @@ void softmax(float* x, int size) {
|
|||||||
|
|
||||||
void matmul(float* xout, float* x, float* w, int n, int d) {
|
void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||||
// W (d,n) @ x (n,) -> xout (d,)
|
// W (d,n) @ x (n,) -> xout (d,)
|
||||||
#pragma omp parallel for
|
int i;
|
||||||
for (int i = 0; i < d; i++) {
|
#pragma omp parallel for private(i)
|
||||||
float val = 0.0f;
|
for (i = 0; i < d; i++) { float val = 0.0f;
|
||||||
for (int j = 0; j < n; j++) {
|
for (int j = 0; j < n; j++) {
|
||||||
val += w[i * n + j] * x[j];
|
val += w[i * n + j] * x[j];
|
||||||
}
|
}
|
||||||
@@ -255,8 +257,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
|
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
|
||||||
|
|
||||||
// multihead attention. iterate over all heads
|
// multihead attention. iterate over all heads
|
||||||
#pragma omp parallel for
|
int h;
|
||||||
for (int h = 0; h < p->n_heads; h++) {
|
#pragma omp parallel for private(h)
|
||||||
|
for (h = 0; h < p->n_heads; h++) {
|
||||||
// get the query vector for this head
|
// get the query vector for this head
|
||||||
float* q = s->q + h * head_size;
|
float* q = s->q + h * head_size;
|
||||||
// attention scores for this head
|
// attention scores for this head
|
||||||
|
|||||||
Reference in New Issue
Block a user