Procházet zdrojové kódy

hellaswag: display estimated score confidence interval (#12797)

stduhpf před 9 měsíci
rodič
revize
4ccea213bc
1 změnil soubory, kde provedl 17 přidání a 3 odebrání
  1. 17 3
      examples/perplexity/perplexity.cpp

+ 17 - 3
examples/perplexity/perplexity.cpp

@@ -851,7 +851,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
 
     LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);
 
-    LOG("\ntask\tacc_norm\n");
+    LOG("\ntask\tacc_norm\t95%% confidence interval\n");
 
     double acc = 0.0f;
 
@@ -985,8 +985,22 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 acc += 1.0;
             }
 
-            // Print the accumulated accuracy mean x 100
-            LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+            double freq = acc / double(i + 1);
+
+            const double za = 1.95996398454;
+
+            // // Wald normal approx
+            // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
+            // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
+
+            // Wilson score interval, more accurate
+            double z   = za * za / double(i + 1);
+            double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
+            double a   = (freq + z * 0.5 - cnf) / (1.0 + z);
+            double b   = (freq + z * 0.5 + cnf) / (1.0 + z);
+
+            // Print the accumulated accuracy mean x 100 and confidence interval
+            LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
         }
 
         i0 = i1 - 1;