optimize sample_topp by filtering out low-probability elements up front

This works because we know that, in the worst case, only one element will be
selected, and therefore the remaining (n-1) elements have to split the
remaining (1-topp) probability. Elements with a probability smaller than
(1-topp)/(n-1) cannot be selected and can be filtered out up front.
This commit is contained in:
Johannes Rudolph
2023-08-12 20:31:19 +02:00
parent c42641205f
commit d421a95b2b
+12 -5
View File
@@ -465,17 +465,24 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
int n0 = 0;
// quicksort indices in descending order of probabilities
// elements smaller than (1 - topp) / (n - 1) cannot be part of the result
// and can be filtered out directly
const float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < n; i++) {
probindex[i].index = i;
probindex[i].prob = probabilities[i];
if (probabilities[i] >= cutoff) {
probindex[n0].index = i;
probindex[n0].prob = probabilities[i];
n0++;
}
}
qsort(probindex, n, sizeof(ProbIndex), compare);
qsort(probindex, n0, sizeof(ProbIndex), compare);
// truncate the list where cumulative probability exceeds topp
float cumulative_prob = 0.0f;
int last_idx = 0;
for (int i = 0; i < n; i++) {
int last_idx = n0 - 1; // in case of rounding errors consider all elements
for (int i = 0; i < n0; i++) {
cumulative_prob += probindex[i].prob;
if (cumulative_prob > topp) {
last_idx = i;