@@ -380,6 +380,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
             const int batch_size = std::min(end - batch_start, n_batch);

             //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+            // TODO: use llama_batch.logits instead of relying on logits_all == true
             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 //fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
@@ -552,6 +553,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);

+            int n_outputs = 0;
+
             batch.n_tokens = 0;
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 int seq_start = batch_start + seq*n_ctx;
@@ -566,11 +569,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

                 for (int k = 0; k < batch_size; ++k) {
                     const int idx = seq*n_ctx + k;
-                    batch.token[idx] = tokens[seq_start + k];
-                    batch.pos[idx] = j*n_batch + k;
-                    batch.n_seq_id[idx] = 1;
-                    batch.seq_id[idx][0] = seq;
-                    batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+                    batch.token   [idx]    = tokens[seq_start + k];
+                    batch.pos     [idx]    = j*n_batch + k;
+                    batch.n_seq_id[idx]    = 1;
+                    batch.seq_id  [idx][0] = seq;
+                    batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+
+                    n_outputs += batch.logits[idx] != 0;
                 }
                 batch.n_tokens += batch_size;

@@ -583,9 +588,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 return {tokens, -1, logit_history, prob_history};
             }

-            if (num_batches > 1) {
+            if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
             }
         }

@@ -604,14 +609,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         }

         for (int seq = 0; seq < n_seq_batch; seq++) {
-            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+
             llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
             if (!params.logits_file.empty()) {
-                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                process_logits(logits_stream, n_vocab, all_logits,
                         tokens_data, n_ctx - 1 - first,
                         workers, log_probs, nll, nll2);
             } else {
-                process_logits(n_vocab, all_logits + first*n_vocab,
+                process_logits(n_vocab, all_logits,
                         tokens_data, n_ctx - 1 - first,
                         workers, nll, nll2,
                         logit_history.data() + start + seq*n_ctx + first,
@@ -652,6 +658,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 }

 static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+    int prev_outputs = 0;
     for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
         const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

@@ -672,7 +679,14 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             return false;
         }

-        memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
+        int n_outputs = 0;
+        for (int i = 0; i < n_tokens; ++i) {
+            n_outputs += batch_view.logits[i] != 0;
+        }
+
+        memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
+        prev_outputs += n_outputs;
     }

     return true;
@@ -779,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         size_t ending_logprob_count[4];
         double ending_logprob[4];

-        size_t i_batch; // starting index in the llama_batch
+        size_t i_logits; // starting index of logits in the llama_batch
         size_t common_prefix; // max number of initial tokens that are the same in all sentences
         size_t required_tokens; // needed number of tokens to evaluate all 4 endings
         std::vector<llama_token> seq_tokens[4];
@@ -844,9 +858,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+    llama_batch batch = llama_batch_init(n_ctx, 0, 4);

     std::vector<float> tok_logits(n_vocab);
+    // TODO: this could be made smaller; it's currently the worst-case size
     std::vector<float> batch_logits(n_vocab*n_ctx);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
@@ -857,16 +872,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         int n_cur = 0;

         size_t i1 = i0;
-        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

         llama_batch_clear(batch);

         // batch as much tasks as possible into the available context
-        // each task has 4 unique seuqnce ids - one for each ending
+        // each task has 4 unique sequence ids - one for each ending
         // the common prefix is shared among the 4 sequences to save tokens
         // we extract logits only from the last common token and from all ending tokens of each sequence
         while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
             auto & hs_cur = hs_data[i1];
+            int n_logits = 0;

             const int s0 = 4*(i1 - i0);
             if (s0 + 4 > max_seq) {
@@ -874,18 +890,23 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             }

             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+                llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
             }
             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            n_logits += 1;

             for (int s = 0; s < 4; ++s) {
-                for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
-                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+                const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
+                // TODO: don't evaluate the last token of each sequence
+                for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
+                    const bool needs_logits = i < seq_tokens_size - 1;
+                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    n_logits += needs_logits;
                 }
             }

-            hs_cur.i_batch = i_batch;
-            i_batch += hs_cur.required_tokens;
+            hs_cur.i_logits = i_logits;
+            i_logits += n_logits;

             n_cur += hs_data[i1].required_tokens;
             if (++i1 == hs_task_count) {
@@ -911,12 +932,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         eval_pairs.clear();
         for (size_t i = i0; i < i1; ++i) {
             auto & hs_cur = hs_data[i];
-            size_t li = hs_cur.common_prefix;
+            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
             for (int s = 0; s < 4; ++s) {
                 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
-                    eval_pairs.emplace_back(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]);
+                    eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
                 }
-                ++li;
             }
         }
         // Then we do the actual calculation
@@ -928,7 +948,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         for (size_t i = i0; i < i1; ++i) {
             auto & hs_cur = hs_data[i];

-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
+            // get the logits of the last token of the common prefix
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -978,7 +999,7 @@ struct winogrande_entry {
     std::array<std::string, 2> choices;
     int answer;

-    size_t i_batch;
+    size_t i_logits;
     size_t common_prefix;
     size_t required_tokens;
     size_t n_base1; // number of tokens for context + choice 1
@@ -1104,6 +1125,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             task.common_prefix++;
         }

+        // TODO: the last token of each of the sequences don't need to be evaluated
         task.required_tokens = task.common_prefix +
             task.seq_tokens[0].size() - task.common_prefix +
             task.seq_tokens[1].size() - task.common_prefix;
@@ -1121,9 +1143,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+    llama_batch batch = llama_batch_init(n_ctx, 0, 2);

     std::vector<float> tok_logits(n_vocab);
+    // TODO: this could be made smaller; it's currently the worst-case size
     std::vector<float> batch_logits(n_vocab*n_ctx);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
@@ -1137,29 +1160,33 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         int n_cur = 0;

         size_t i1 = i0;
-        size_t i_batch = 0;
+        size_t i_logits = 0;

         llama_batch_clear(batch);

         while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+            int n_logits = 0;
+
             const int s0 = 2*(i1 - i0);
             if (s0 + 2 > max_seq) {
                 break;
             }

             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
+                llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
             }
             batch.logits[batch.n_tokens - 1] = true;
+            n_logits += 1;

             for (int s = 0; s < 2; ++s) {
+                // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
                     llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    n_logits += 1;
                 }
             }

-            data[i1].i_batch = i_batch;
-            i_batch += data[i1].required_tokens;
+            data[i1].i_logits = i_logits;
+            i_logits += n_logits;

             n_cur += data[i1].required_tokens;
             if (++i1 == data.size()) {
@@ -1190,15 +1217,16 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

             const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
             const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
-            size_t li = n_base1 - 1;
+            size_t li = n_base1 - task.common_prefix;
             for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
-                eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[0][j+1]);
+                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
             }
             const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
             const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
-            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+            // FIXME: this uses the wrong first logits when not skipping the choice word
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
             for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
-                eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[1][j+1]);
+                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
             }
         }
         compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
@@ -1287,7 +1315,7 @@ struct multiple_choice_task {
     }

     // For evaluation
-    size_t i_batch; // starting index in the llama_batch
+    size_t i_logits; // starting index of logits in the llama_batch
     size_t common_prefix; // max number of initial tokens that are the same in all sentences
     size_t required_tokens; // needed number of tokens to evaluate all answers
     std::vector<std::vector<llama_token>> seq_tokens;
@@ -1366,7 +1394,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         std::vector<uint32_t> task_pos(n_task);
         strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
         if (strstream.fail()) {
-            printf("%s: failed to raad task positions from prompt\n", __func__);
+            printf("%s: failed to read task positions from prompt\n", __func__);
             return;
         }

@@ -1447,7 +1475,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             return;
         }
     } else {
-        int n_dot = n_task/100;
+        int n_dot = std::max((int) n_task/100, 1);
         int i_task = 0;
         for (auto& task : tasks) {
             ++i_task;
@@ -1491,17 +1519,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         int n_cur = 0;

         size_t i1 = i0;
-        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

         llama_batch_clear(batch);

         // batch as much tasks as possible into the available context
-        // each task has 4 unique seuqnce ids - one for each ending
+        // each task has 4 unique sequence ids - one for each ending
         // the common prefix is shared among the 4 sequences to save tokens
         // we extract logits only from the last common token and from all ending tokens of each sequence
         int s0 = 0;
         while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
             auto& cur_task = tasks[i1];
+            int n_logits = 0;

             int num_answers = cur_task.seq_tokens.size();
             if (s0 + num_answers > max_seq) {
@@ -1518,17 +1547,22 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            n_logits += 1;

             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
-                for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) {
-                    llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true);
+                const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
+                // TODO: don't evaluate the last token of each sequence
+                for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
+                    const bool needs_logits = i < seq_tokens_size - 1;
+                    llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    n_logits += needs_logits;
                 }
             }

             s0 += num_answers;

-            cur_task.i_batch = i_batch;
-            i_batch += cur_task.required_tokens;
+            cur_task.i_logits = i_logits;
+            i_logits += n_logits;

             n_cur += cur_task.required_tokens;
             if (++i1 == tasks.size()) {
@@ -1554,12 +1588,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         eval_pairs.clear();
         for (size_t i = i0; i < i1; ++i) {
             auto& cur_task = tasks[i];
-            size_t li = cur_task.common_prefix;
+            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                 for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
-                    eval_pairs.emplace_back(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]);
+                    eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
                 }
-                ++li;
             }
         }
         // Then we do the actual calculation
@@ -1578,7 +1611,8 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             //}
             //printf("\n common_prefix: %zu\n", cur_task.common_prefix);

-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
+            // get the logits of the last token of the common prefix
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -1730,6 +1764,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

+            // TODO: use llama_batch.logits instead of relying on logits_all == true
             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
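Note (illustration, not part of the patch): the hunks above replace the per-task token offset i_batch with i_logits, the number of logits requested before the task, because logits are now stored contiguously only for tokens whose batch.logits flag is set. A minimal standalone sketch of that prefix-count indexing, with made-up data and names:

    #include <cstdio>
    #include <vector>

    int main() {
        // one flag per token in the batch: does this token request logits?
        std::vector<char> request = { 0, 0, 1, 1, 0, 1 };

        // i_logits for token t = number of logits requested before t,
        // i.e. the row where token t's logits land in the packed buffer
        std::vector<int> i_logits(request.size() + 1, 0);
        for (size_t t = 0; t < request.size(); ++t) {
            i_logits[t + 1] = i_logits[t] + (request[t] != 0);
        }

        // token 3 requests logits and is preceded by one other request (token 2),
        // so its logits are read from packed row 1, not from raw batch index 3
        std::printf("packed row for token 3: %d\n", i_logits[3]); // prints 1
        return 0;
    }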