فهرست منبع

sampling : don't consider -infinity values in top_n_sigma (#13344)

oobabooga 8 ماه پیش
والد
کامیت
91a86a6f35
1فایلهای تغییر یافته به همراه14 افزوده شده و 6 حذف شده
  1. 14 6
      src/llama-sampling.cpp

+ 14 - 6
src/llama-sampling.cpp

@@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     // find max logit and calculate mean
     float max = cur_p->data[0].logit;
     float logits_sum = 0;
+    size_t valid_count = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].logit > max) {
-            max = cur_p->data[i].logit;
+        // Only count non-negative infinity values
+        if (cur_p->data[i].logit != -INFINITY) {
+            if (cur_p->data[i].logit > max) {
+                max = cur_p->data[i].logit;
+            }
+            logits_sum += cur_p->data[i].logit;
+            valid_count++;
         }
-        logits_sum += cur_p->data[i].logit;
     }
-    float mean = logits_sum/cur_p->size;
+    float mean = valid_count > 0 ? logits_sum/valid_count : 0;
 
     // calculate standard deviation
     float acc = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        acc += pow(cur_p->data[i].logit - mean, 2);
+        // Skip -infinity in std calculation
+        if (cur_p->data[i].logit != -INFINITY) {
+            acc += pow(cur_p->data[i].logit - mean, 2);
+        }
     }
-    float std = sqrt(acc/cur_p->size);
+    float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
 
     //apply mask
     for (size_t i = 0; i < cur_p->size; ++i) {