Kaynağa Gözat

sampling : optimize dist sampler (#15704)

ggml-ci
Georgi Gerganov 4 ay önce
ebeveyn
işleme
cdedb70a99
1 değiştirilmiş dosya ile 65 ekleme ve 2 silme
  1. 65 2
      src/llama-sampling.cpp

+ 65 - 2
src/llama-sampling.cpp

@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    // sorting is not necessary here
-    llama_sampler_softmax_impl(cur_p, false);
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+#if 1
+    // sample from the obtained probabilities and normalize the probs in a single pass
+    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+    //
+    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+    const double rnd = dist(ctx->rng);
+
+          double sum_run = 0.0f;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            // accumulate probs until we reach the target sum
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    // fallback to the last token (don't think this can happen)
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+#else
+    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= sum_cum;
+    }
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {