@@ -8801,12 +8801,14 @@ static int llama_decode_impl(
     //llama_synchronize(&lctx);

     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;

         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

             llama_kv_cache_defrag(kv_self);
         }
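
For reference, a minimal standalone sketch of the updated fragmentation metric. The cache numbers and the padding value below are made up for illustration; in the patch they come from kv_self, cparams, and llama_kv_cache_get_padding().

// sketch.cpp - not part of the patch; shows how the new formula behaves
#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical cache state (in llama.cpp these come from kv_self / cparams)
    const int   kv_n         = 4096; // current size of the KV cache window
    const int   kv_used      = 3000; // cells actually occupied
    const int   padding      = 32;   // assumed result of llama_kv_cache_get_padding(cparams)
    const float defrag_thold = 0.1f; // defragmentation threshold from the context params

    // small contexts (< 2048 cells) are never defragged; padding counts as "used"
    const float fragmentation = kv_n >= 2048
        ? std::max(0.0f, 1.0f - float(kv_used + padding)/float(kv_n))
        : 0.0f;

    if (fragmentation > defrag_thold) {
        printf("fragmentation: %.2f - would request defrag\n", fragmentation);
    } else {
        printf("fragmentation: %.2f - no defrag needed\n", fragmentation);
    }
    return 0;
}

With the padding counted towards the used cells and the floor raised from 128 to 2048, small contexts and nearly-full caches no longer trigger unnecessary defragmentation requests.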