From 7f7a3b2d56acb7422f4c1858954bf55ea9b94095 Mon Sep 17 00:00:00 2001 From: richinseattle Date: Wed, 26 Jul 2023 22:06:23 -0700 Subject: [PATCH 1/2] update openmp pragmas for MSVC compatibility This has no negative impact on Linux and is in preparation for Windows support. Windows compiles will not work without additional timer and mmap compatibility patches. --- run.c | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/run.c b/run.c index 049b071..9d3c991 100644 --- a/run.c +++ b/run.c @@ -13,9 +13,11 @@ $ ./run #include #include #include -#include #include +#ifndef _WIN32 +#include #include +#endif // ---------------------------------------------------------------------------- // Transformer and RunState structs, and related memory management @@ -190,9 +192,9 @@ void softmax(float* x, int size) { void matmul(float* xout, float* x, float* w, int n, int d) { // W (d,n) @ x (n,) -> xout (d,) - #pragma omp parallel for - for (int i = 0; i < d; i++) { - float val = 0.0f; + int i; + #pragma omp parallel for private(i) + for (i = 0; i < d; i++) { float val = 0.0f; for (int j = 0; j < n; j++) { val += w[i * n + j] * x[j]; } @@ -255,8 +257,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row)); // multihead attention. 
iterate over all heads - #pragma omp parallel for - for (int h = 0; h < p->n_heads; h++) { + int h; + #pragma omp parallel for private(h) + for (h = 0; h < p->n_heads; h++) { // get the query vector for this head float* q = s->q + h * head_size; // attention scores for this head From 539dc73196e5f5c0aa8c38c924cb02fef7fadd29 Mon Sep 17 00:00:00 2001 From: richinseattle Date: Wed, 26 Jul 2023 22:12:32 -0700 Subject: [PATCH 2/2] fix whitespace --- run.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/run.c b/run.c index 9d3c991..940db6f 100644 --- a/run.c +++ b/run.c @@ -194,7 +194,8 @@ void matmul(float* xout, float* x, float* w, int n, int d) { // W (d,n) @ x (n,) -> xout (d,) int i; #pragma omp parallel for private(i) - for (i = 0; i < d; i++) { float val = 0.0f; + for (i = 0; i < d; i++) { + float val = 0.0f; for (int j = 0; j < n; j++) { val += w[i * n + j] * x[j]; }