From d421a95b2bfe593b2d9e5c147f3efc8d128afe0e Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Sat, 12 Aug 2023 20:31:19 +0200 Subject: [PATCH] optimize sample_topp by filtering out small value elements up front This works because we know that in worst case only 1 element will be selected and therefore the remaining (n-1) elements have to split the remaining (1-topp) probability. Probabilities smaller than that cannot be selected and can be filtered out up front. --- run.c | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/run.c b/run.c index afe695f..9fd8f76 100644 --- a/run.c +++ b/run.c @@ -465,17 +465,24 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) { // tokens that exceed probability topp. This way we never sample tokens that // have very low probabilities and are less likely to go "off the rails". + int n0 = 0; // quicksort indices in descending order of probabilities + // elements smaller than (1 - topp) / (n - 1) cannot be part of the result + // and can be filtered out directly + const float cutoff = (1.0f - topp) / (n - 1); for (int i = 0; i < n; i++) { - probindex[i].index = i; - probindex[i].prob = probabilities[i]; + if (probabilities[i] >= cutoff) { + probindex[n0].index = i; + probindex[n0].prob = probabilities[i]; + n0++; + } } - qsort(probindex, n, sizeof(ProbIndex), compare); + qsort(probindex, n0, sizeof(ProbIndex), compare); // truncate the list where cumulative probability exceeds topp float cumulative_prob = 0.0f; - int last_idx = 0; - for (int i = 0; i < n; i++) { + int last_idx = n0 - 1; // in case of rounding errors consider all elements + for (int i = 0; i < n0; i++) { cumulative_prob += probindex[i].prob; if (cumulative_prob > topp) { last_idx = i;