optimize sample_topp by filtering out small value elements up front
This works because we know that in worst case only 1 element will be selected and therefore the remaining (n-1) elements have to split the remaining (1-topp) probability. Probabilities smaller than that cannot be selected and can be filtered out up front.
This commit is contained in:
@@ -465,17 +465,24 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
// tokens that exceed probability topp. This way we never sample tokens that
|
||||
// have very low probabilities and are less likely to go "off the rails".
|
||||
|
||||
int n0 = 0;
|
||||
// quicksort indices in descending order of probabilities
|
||||
// elements smaller than (1 - topp) / (n - 1) cannot be part of the result
|
||||
// and can be filtered out directly
|
||||
const float cutoff = (1.0f - topp) / (n - 1);
|
||||
for (int i = 0; i < n; i++) {
|
||||
probindex[i].index = i;
|
||||
probindex[i].prob = probabilities[i];
|
||||
if (probabilities[i] >= cutoff) {
|
||||
probindex[n0].index = i;
|
||||
probindex[n0].prob = probabilities[i];
|
||||
n0++;
|
||||
}
|
||||
}
|
||||
qsort(probindex, n, sizeof(ProbIndex), compare);
|
||||
qsort(probindex, n0, sizeof(ProbIndex), compare);
|
||||
|
||||
// truncate the list where cumulative probability exceeds topp
|
||||
float cumulative_prob = 0.0f;
|
||||
int last_idx = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
int last_idx = n0 - 1; // in case of rounding errors consider all elements
|
||||
for (int i = 0; i < n0; i++) {
|
||||
cumulative_prob += probindex[i].prob;
|
||||
if (cumulative_prob > topp) {
|
||||
last_idx = i;
|
||||
|
||||
Reference in New Issue
Block a user