[openmp] 1.5x inference speedup
Problem: - clock() measures CPU time, so it does not report elapsed time correctly under parallel execution. - inference performance is bound by the matmul against the weights. Solution: - use gettimeofday instead. - use OpenMP to parallelize matmul. Note: - if not compiled with -fopenmp, the #pragma is ignored and execution is single-threaded. - OpenMP can optionally be configured through additional environment variables that set the number of threads, the scheduler, etc. Benchmarks: ``` clang -Ofast -march=native run.c -lm -o run // achieved tok/s: 340.878828 clang -Ofast -fopenmp -march=native run.c -lm -o run // achieved tok/s: 524.590164 ```
This commit is contained in:
@@ -13,6 +13,7 @@ $ ./run
|
|||||||
#include <time.h>
|
#include <time.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
#include <sys/time.h>
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Transformer and RunState structs, and related memory management
|
// Transformer and RunState structs, and related memory management
|
||||||
@@ -212,6 +213,7 @@ void softmax(float* x, int size) {
|
|||||||
|
|
||||||
void matmul(float* xout, float* x, float* w, int n, int d) {
|
void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||||
// W (d,n) @ x (n,) -> xout (d,)
|
// W (d,n) @ x (n,) -> xout (d,)
|
||||||
|
#pragma omp parallel for
|
||||||
for (int i = 0; i < d; i++) {
|
for (int i = 0; i < d; i++) {
|
||||||
float val = 0.0f;
|
float val = 0.0f;
|
||||||
for (int j = 0; j < n; j++) {
|
for (int j = 0; j < n; j++) {
|
||||||
@@ -372,6 +374,12 @@ int argmax(float* v, int n) {
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Returns the current wall-clock time in milliseconds, as reported by
// gettimeofday() (seconds and microseconds since the Epoch).
// Intended for measuring elapsed time as the difference of two calls; unlike
// clock(), the result does not depend on how many threads consumed CPU time.
// NOTE: this is wall-clock time, so a system clock adjustment between two
// calls will skew the measured interval.
long time_in_ms(void) {
    // Not named `time`: that would shadow time() from <time.h>.
    struct timeval now;
    gettimeofday(&now, NULL);
    // Widen before scaling: tv_sec * 1000 evaluated in a 32-bit time_t would
    // be signed overflow (undefined behavior). The final conversion to long
    // keeps the original return type for callers.
    long long ms = (long long)now.tv_sec * 1000 + (long long)now.tv_usec / 1000;
    return (long)ms;
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|
||||||
// poor man's C argparse
|
// poor man's C argparse
|
||||||
@@ -438,7 +446,8 @@ int main(int argc, char *argv[]) {
|
|||||||
malloc_run_state(&state, &config);
|
malloc_run_state(&state, &config);
|
||||||
|
|
||||||
// the current position we are in
|
// the current position we are in
|
||||||
clock_t start = clock();
|
long start = time_in_ms();
|
||||||
|
|
||||||
int next;
|
int next;
|
||||||
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
@@ -469,9 +478,8 @@ int main(int argc, char *argv[]) {
|
|||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
// report our achieved tok/s
|
// report our achieved tok/s
|
||||||
clock_t end = clock();
|
long end = time_in_ms();
|
||||||
double elapsed = (double)(end - start) / CLOCKS_PER_SEC;
|
printf("achieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
|
||||||
printf("achieved tok/s: %f\n", config.seq_len / elapsed);
|
|
||||||
|
|
||||||
// memory cleanup
|
// memory cleanup
|
||||||
free_run_state(&state);
|
free_run_state(&state);
|
||||||
|
|||||||
Reference in New Issue
Block a user