reshuffle blocks of code a bit
This commit is contained in:
@@ -337,46 +337,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
|
||||||
unsigned long long rng_seed;
|
|
||||||
unsigned int random_u32() {
|
|
||||||
rng_seed ^= rng_seed >> 12;
|
|
||||||
rng_seed ^= rng_seed << 25;
|
|
||||||
rng_seed ^= rng_seed >> 27;
|
|
||||||
return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
|
|
||||||
}
|
|
||||||
float random_f32() {
|
|
||||||
return (random_u32() >> 8) / 16777216.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// functions to sample the next token from the transformer's predicted distribution
|
|
||||||
|
|
||||||
int sample(float* probabilities, int n) {
|
|
||||||
// sample index from probabilities, they must sum to 1
|
|
||||||
float r = random_f32();
|
|
||||||
float cdf = 0.0f;
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
cdf += probabilities[i];
|
|
||||||
if (r < cdf) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return n - 1; // in case of rounding errors
|
|
||||||
}
|
|
||||||
|
|
||||||
int argmax(float* v, int n) {
|
|
||||||
// return argmax of v in elements 0..n
|
|
||||||
int max_i = 0;
|
|
||||||
float max_p = v[0];
|
|
||||||
for (int i = 1; i < n; i++) {
|
|
||||||
if (v[i] > max_p) {
|
|
||||||
max_i = i;
|
|
||||||
max_p = v[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return max_i;
|
|
||||||
}
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||||
|
|
||||||
@@ -441,11 +401,51 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// utilities
|
// utilities
|
||||||
|
|
||||||
long time_in_ms() {
|
long time_in_ms() {
|
||||||
|
// return time in milliseconds, for benchmarking the model speed
|
||||||
struct timespec time;
|
struct timespec time;
|
||||||
clock_gettime(CLOCK_REALTIME, &time);
|
clock_gettime(CLOCK_REALTIME, &time);
|
||||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned long long rng_seed;
|
||||||
|
unsigned int random_u32() {
|
||||||
|
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||||
|
rng_seed ^= rng_seed >> 12;
|
||||||
|
rng_seed ^= rng_seed << 25;
|
||||||
|
rng_seed ^= rng_seed >> 27;
|
||||||
|
return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
|
||||||
|
}
|
||||||
|
float random_f32() { // random float32 in [0,1)
|
||||||
|
return (random_u32() >> 8) / 16777216.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
int sample(float* probabilities, int n) {
|
||||||
|
// sample index from probabilities, they must sum to 1
|
||||||
|
float r = random_f32();
|
||||||
|
float cdf = 0.0f;
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
cdf += probabilities[i];
|
||||||
|
if (r < cdf) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n - 1; // in case of rounding errors
|
||||||
|
}
|
||||||
|
|
||||||
|
int argmax(float* v, int n) {
|
||||||
|
// return argmax of v in elements 0..n
|
||||||
|
int max_i = 0;
|
||||||
|
float max_p = v[0];
|
||||||
|
for (int i = 1; i < n; i++) {
|
||||||
|
if (v[i] > max_p) {
|
||||||
|
max_i = i;
|
||||||
|
max_p = v[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max_i;
|
||||||
|
}
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|||||||
Reference in New Issue
Block a user