@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;
 
     // handle any pending defrags/shifts
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();
 
-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
        {
            int32_t n_outputs_new = 0;
 
@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(
 
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    // this indicates we are doing pooled embedding
    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
    embd_seq.clear();
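
A minimal caller-side sketch (not part of this patch) of what the new check in decode() implies: when pooled embeddings are enabled (cparams.embeddings with pooling_type != LLAMA_POOLING_TYPE_NONE), every token in the batch must be marked for output, i.e. batch.logits[i] != 0 for all tokens, otherwise decode() now returns -1. The snippet fills a llama_batch via the public fields from llama.h; the `tokens` vector and the context/pooling setup are assumed to exist elsewhere.

// Sketch: build a batch for pooled embeddings where every token requests output.
// Assumes `tokens` holds the prompt tokens and `ctx` was created with embeddings
// enabled and a pooling type other than LLAMA_POOLING_TYPE_NONE.
llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd =*/ 0, /*n_seq_max =*/ 1);

for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
    batch.token   [i]    = tokens[i];
    batch.pos     [i]    = i;
    batch.n_seq_id[i]    = 1;
    batch.seq_id  [i][0] = 0;
    batch.logits  [i]    = true; // required: pooled embedding outputs all tokens
}
batch.n_tokens = (int32_t) tokens.size();

// ... llama_decode(ctx, batch), then read the pooled result with
// llama_get_embeddings_seq(ctx, 0) ...

llama_batch_free(batch);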