slight tweaks to softmax
This commit is contained in:
@@ -196,10 +196,6 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void softmax(float* x, int size) {
|
void softmax(float* x, int size) {
|
||||||
if(size == 1) {
|
|
||||||
x[0] = 1.0f;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// find max value (for numerical stability)
|
// find max value (for numerical stability)
|
||||||
float max_val = x[0];
|
float max_val = x[0];
|
||||||
for (int i = 1; i < size; i++) {
|
for (int i = 1; i < size; i++) {
|
||||||
@@ -207,14 +203,13 @@ void softmax(float* x, int size) {
|
|||||||
max_val = x[i];
|
max_val = x[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// exp and sum
|
||||||
// normalize
|
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
x[i] = exp(x[i] - max_val);
|
x[i] = exp(x[i] - max_val);
|
||||||
sum += x[i];
|
sum += x[i];
|
||||||
}
|
}
|
||||||
|
// normalize
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
x[i] /= sum;
|
x[i] /= sum;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user