optimize sample_topp by filtering out small value elements up front
This works because, in the worst case, only 1 element will be selected, and therefore the remaining (n-1) elements have to split the remaining (1-topp) probability. Probabilities smaller than (1-topp)/(n-1) cannot be selected and can be filtered out up front.
This commit is contained in:
@@ -465,17 +465,24 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
|||||||
// tokens that exceed probability topp. This way we never sample tokens that
|
// tokens that exceed probability topp. This way we never sample tokens that
|
||||||
// have very low probabilities and are less likely to go "off the rails".
|
// have very low probabilities and are less likely to go "off the rails".
|
||||||
|
|
||||||
|
int n0 = 0;
|
||||||
// quicksort indices in descending order of probabilities
|
// quicksort indices in descending order of probabilities
|
||||||
|
// elements smaller than (1 - topp) / (n - 1) cannot be part of the result
|
||||||
|
// and can be filtered out directly
|
||||||
|
const float cutoff = (1.0f - topp) / (n - 1);
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
probindex[i].index = i;
|
if (probabilities[i] >= cutoff) {
|
||||||
probindex[i].prob = probabilities[i];
|
probindex[n0].index = i;
|
||||||
|
probindex[n0].prob = probabilities[i];
|
||||||
|
n0++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
qsort(probindex, n, sizeof(ProbIndex), compare);
|
qsort(probindex, n0, sizeof(ProbIndex), compare);
|
||||||
|
|
||||||
// truncate the list where cumulative probability exceeds topp
|
// truncate the list where cumulative probability exceeds topp
|
||||||
float cumulative_prob = 0.0f;
|
float cumulative_prob = 0.0f;
|
||||||
int last_idx = 0;
|
int last_idx = n0 - 1; // in case of rounding errors consider all elements
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n0; i++) {
|
||||||
cumulative_prob += probindex[i].prob;
|
cumulative_prob += probindex[i].prob;
|
||||||
if (cumulative_prob > topp) {
|
if (cumulative_prob > topp) {
|
||||||
last_idx = i;
|
last_idx = i;
|
||||||
|
|||||||
Reference in New Issue
Block a user