|
|
@@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
|
|
// find max logit and calculate mean
|
|
|
float max = cur_p->data[0].logit;
|
|
|
float logits_sum = 0;
|
|
|
+ size_t valid_count = 0;
|
|
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
|
- if (cur_p->data[i].logit > max) {
|
|
|
- max = cur_p->data[i].logit;
|
|
|
+ // Only count logits that are not negative infinity
|
|
|
+ if (cur_p->data[i].logit != -INFINITY) {
|
|
|
+ if (cur_p->data[i].logit > max) {
|
|
|
+ max = cur_p->data[i].logit;
|
|
|
+ }
|
|
|
+ logits_sum += cur_p->data[i].logit;
|
|
|
+ valid_count++;
|
|
|
}
|
|
|
- logits_sum += cur_p->data[i].logit;
|
|
|
}
|
|
|
- float mean = logits_sum/cur_p->size;
|
|
|
+ float mean = valid_count > 0 ? logits_sum/valid_count : 0;
|
|
|
|
|
|
// calculate standard deviation
|
|
|
float acc = 0;
|
|
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
|
- acc += pow(cur_p->data[i].logit - mean, 2);
|
|
|
+ // Skip -infinity in std calculation
|
|
|
+ if (cur_p->data[i].logit != -INFINITY) {
|
|
|
+ acc += pow(cur_p->data[i].logit - mean, 2);
|
|
|
+ }
|
|
|
}
|
|
|
- float std = sqrt(acc/cur_p->size);
|
|
|
+ float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
|
|
|
|
|
|
//apply mask
|
|
|
for (size_t i = 0; i < cur_p->size; ++i) {
|