Просмотр исходного кода

batch : remove logits_all flag (#14141)

ggml-ci
Georgi Gerganov 7 месяцев назад
Родитель
Сommit
c3ee46fab4

+ 2 - 8
src/llama-batch.cpp

@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);

+ 1 - 3
src/llama-batch.h

@@ -39,8 +39,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
 // temporary allocate memory for the input batch if needed

+ 3 - 3
src/llama-context.cpp

@@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate) {
             return -2;
         }
@@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;

+ 2 - 2
src/llama-kv-cache-recurrent.cpp

@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 

+ 1 - 2
src/llama-kv-cache-recurrent.h

@@ -32,8 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 

+ 3 - 3
src/llama-kv-cache-unified-iswa.cpp

@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
         std::vector<llama_ubatch> ubatches;
 

+ 1 - 2
src/llama-kv-cache-unified-iswa.h

@@ -34,8 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 

+ 2 - 3
src/llama-kv-cache-unified.cpp

@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) {
+            bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
         while (sbatch.n_tokens > 0) {

+ 1 - 2
src/llama-kv-cache-unified.h

@@ -59,8 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 

+ 1 - 2
src/llama-memory.h

@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_pooled) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;