فهرست منبع

llama : always sort logits before nucleus sampling (#812)

* Always sort logits before nucleus sampling

* remove second normalization

- fix windows build
- remove normalization since std::discrete_distribution does not require it
Ivan Stepanov 2 سال پیش
والد
کامیت
4953e9007f
1فایلهای تغییر یافته به همراه3 افزوده شده و 14 حذف شده
  1. 3 14
      llama.cpp

+ 3 - 14
llama.cpp

@@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
         }
     }
 
-    if (top_k > 0 && top_k < n_logits) {
-        sample_top_k(logits_id, top_k);
-    }
-
-    float maxl = -std::numeric_limits<float>::infinity();
-    for (const auto & kv : logits_id) {
-        maxl = Max(maxl, kv.first);
-    }
+    sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
 
     // compute probs for the top k tokens
     std::vector<float> probs;
     probs.reserve(logits_id.size());
 
+    float maxl = logits_id[0].first;
     double sum = 0.0;
     for (const auto & kv : logits_id) {
         const float p = expf(kv.first - maxl);
@@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
                 break;
             }
         }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
     }
 
     //printf("\n");
     //for (int i = 0; i < (int) 10; i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //    printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
     //}
     //printf("\n\n");
     //exit(0);