|
|
@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|
|
p1 = std::numeric_limits<llama_pos>::max();
|
|
|
}
|
|
|
|
|
|
- // models like Mamba or RWKV can't have a state partially erased
|
|
|
+ // models like Mamba or RWKV can't have a state partially erased at the end
|
|
|
+ // of the sequence because their state isn't preserved for previous tokens
|
|
|
if (seq_id >= (int64_t) size) {
|
|
|
// could be fatal
|
|
|
return false;
|
|
|
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|
|
int32_t & tail_id = cells[seq_id].tail;
|
|
|
if (tail_id >= 0) {
|
|
|
const auto & cell = cells[tail_id];
|
|
|
- // partial intersection is invalid
|
|
|
- if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
|
|
+ // partial intersection is invalid if it includes the final pos
|
|
|
+ if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
|
|
|
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
|
|
|
return false;
|
|
|
}
|