Explorar o código

llama : add high-throughput mode (#14363)

* kv-cache : prepare K/V buffers for separation

ggml-ci

* batched-bench : fix oob write

ggml-ci

* llama : add "virtual sequences"

ggml-ci

* llama : use "stream" vs "virtual sequence"

ggml-ci

* graph : fix stream splitting when KV cache is not used

ggml-ci

* kv-cache : add multi-stream save/load support

ggml-ci

* llama : add "--attn-streams" flag

ggml-ci

* kv-cache : fix handling when find_slot fails

ggml-ci

* kv-cache : restore find_slot impl

ggml-ci

* kv-cache : add comments

* kv-cache : add bounds checks for sequence id

ggml-ci

* cont : add n_seq_max to batch allocr

ggml-ci

* kv-cache : perform stream copies lazily after llama_synchronize

ggml-ci

* kv-cache : avoid throwing exceptions across the C boundary

ggml-ci

* CUDA: 4D FlashAttention support (#14628)

* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel

* llama : rename attn_streams -> kv_unified

ggml-ci

* common : rename kv_split -> kv_unified

ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Georgi Gerganov hai 6 meses
pai
achega
225e7a1438

+ 8 - 0
common/arg.cpp

@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--kv-unified", "-kvu"},
+        string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
+        [](common_params & params) {
+            params.kv_unified = true;
+        }
+    ).set_env("LLAMA_ARG_KV_SPLIT"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

+ 1 - 0
common/common.cpp

@@ -1163,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
     cparams.swa_full          = params.swa_full;
+    cparams.kv_unified        = params.kv_unified;
 
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;

+ 1 - 0
common/common.h

@@ -341,6 +341,7 @@ struct common_params {
     bool no_perf           = false; // disable performance metrics
     bool ctx_shift         = true;  // context shift on inifinite text generation
     bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool kv_unified        = false; // enable unified KV cache
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap          = true;  // use mmap for faster loads

+ 1 - 1
examples/embedding/embedding.cpp

@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     const int n_ctx_train = llama_model_n_ctx_train(model);
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx       = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
 

+ 2 - 1
examples/parallel/parallel.cpp

@@ -224,6 +224,7 @@ int main(int argc, char ** argv) {
         auto & client = clients[i];
         client.id = i;
         client.smpl = common_sampler_init(model, params.sampling);
+        //params.sampling.seed++;
     }
 
     std::vector<llama_token> tokens_system;
@@ -345,7 +346,7 @@ int main(int argc, char ** argv) {
                     client.n_decoded = 0;
                     client.i_batch   = batch.n_tokens - 1;
 
-                    LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
+                    LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);
 
                     g_seq_id += 1;
 

+ 35 - 19
ggml/src/ggml-cuda/fattn-common.cuh

@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       +=                 D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>

+ 22 - 18
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int channel = kbc / (iter_k*iter_j);
-    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;

+ 20 - 14
ggml/src/ggml-cuda/fattn-tile-f16.cu

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

+ 18 - 12
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else

+ 13 - 10
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

+ 15 - 12
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);

+ 15 - 9
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float * Q_f   = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h   = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
     const half2 * mask2 = (const half2 *)  maskh;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }
 
         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);

+ 3 - 6
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }

+ 3 - 0
include/llama.h

@@ -335,6 +335,9 @@ extern "C" {
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                           // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                           //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+        bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
+                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
+                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
     };
 
     // model quantization parameters

+ 18 - 11
src/llama-batch.cpp

@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
         // initialize the starting position for each sequence based on the positions in the memory
         llama_pos p0[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (!memory) {
                 // if no memory -> start from 0
                 p0[s] = 0;
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
             seq_set_map[cur].push_back(i);
         }
 
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (seq_set_unq.test(s)) {
                 seq_idx[s] = seq_id_unq.size();
                 seq_id_unq.push_back(s);
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }
 
         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }
 
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
             ubatch.seq_id_unq.push_back(s);

+ 2 - 0
src/llama-batch.h

@@ -48,6 +48,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
@@ -100,6 +101,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

+ 25 - 7
src/llama-context.cpp

@@ -98,10 +98,20 @@ llama_context::llama_context(
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }
-
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+        const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+        if (!supports_set_rows && !cparams.kv_unified) {
+            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+            cparams.kv_unified = true;
+        }
+    }
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -112,6 +122,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
@@ -267,7 +278,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -300,7 +311,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,6 +322,10 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens = balloc->get_n_tokens();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
             batch.logits  [pos_batch]    = true;
         }
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             return;
         }
@@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
+        /*.kv_unified                  =*/ false,
     };
 
     return result;

+ 3 - 2
src/llama-cparams.h

@@ -11,8 +11,8 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing
 
     float rope_freq_base;
     float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
+    bool kv_unified;
 
     enum llama_pooling_type pooling_type;
 

+ 20 - 9
src/llama-graph.cpp

@@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
              float     kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    // split the batch into streams if needed
+    const auto n_stream = k->ne[3];
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_kv = k->ne[1];
 
     ggml_tensor * cur;
 
@@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(ubatch.equal_seqs == false);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
@@ -1156,13 +1164,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

+ 9 - 9
src/llama-graph.h

@@ -255,10 +255,10 @@ public:
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -289,14 +289,14 @@ public:
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

+ 40 - 0
src/llama-hparams.cpp

@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models

+ 8 - 0
src/llama-hparams.h

@@ -191,6 +191,14 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;

+ 11 - 5
src/llama-kv-cache-unified-iswa.cpp

@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   v_trans,
                      bool   offload,
                      bool   swa_full,
+                     bool   unified,
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams) {
+                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
     do {
+        if (!unified) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
         balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch, false);
+            auto ubatch = balloc.split_equal(n_ubatch, !unified);
 
             if (ubatch.n_tokens == 0) {
                 break;

+ 3 - 0
src/llama-kv-cache-unified-iswa.h

@@ -20,6 +20,7 @@ public:
                          bool   v_trans,
                          bool   offload,
                          bool   swa_full,
+                         bool   unified,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
                      uint32_t   n_ubatch,
@@ -68,6 +69,8 @@ public:
 private:
     const llama_hparams & hparams;
 
+    const bool unified;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };

A diferenza do arquivo foi suprimida porque é demasiado grande
+ 485 - 158
src/llama-kv-cache-unified.cpp


+ 76 - 22
src/llama-kv-cache-unified.h

@@ -35,16 +35,50 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
     // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
     //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
     struct slot_info {
         // data for ggml_set_rows
         using idx_vec_t = std::vector<uint32_t>;
 
-        idx_vec_t idxs;
+        // number of streams: ns = s1 - s0 + 1
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]
 
         uint32_t head() const {
-            return idxs.at(0);
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
         }
 
         bool empty() const {
@@ -54,9 +88,6 @@ public:
         void clear() {
             idxs.clear();
         }
-
-        // TODO: implement
-        //std::vector<idx_vec_t> seq_idxs;
     };
 
     using slot_info_vec_t = std::vector<slot_info>;
@@ -68,6 +99,7 @@ public:
                     ggml_type    type_v,
                          bool    v_trans,
                          bool    offload,
+                         bool    unified,
                      uint32_t    kv_size,
                      uint32_t    n_seq_max,
                      uint32_t    n_pad,
@@ -111,7 +143,8 @@ public:
     // llama_kv_cache_unified specific API
     //
 
-    uint32_t get_size() const;
+    uint32_t get_size()     const;
+    uint32_t get_n_stream() const;
 
     bool get_has_shift() const;
 
@@ -122,8 +155,8 @@ public:
     uint32_t get_n_kv() const;
 
     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@@ -137,7 +170,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
 
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -157,8 +190,9 @@ public:
     void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
     void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
 
+    void set_input_k_shift(ggml_tensor * dst) const;
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
@@ -172,15 +206,15 @@ private:
 
         ggml_tensor * k;
         ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
     };
 
     bool v_trans = true;  // the value tensor is transposed
 
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
     const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;
 
     // required padding
     const uint32_t n_pad = 1;
@@ -200,7 +234,17 @@ private:
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    llama_kv_cells_unified cells;
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells_unified> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;
 
     std::vector<kv_layer> layers;
 
@@ -237,18 +281,25 @@ private:
                     ggml_cgraph * gf,
               const defrag_info & dinfo) const;
 
-    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+    struct cell_ranges_t {
+        uint32_t strm;
 
-    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info     = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache_unified::defrag_info;
+    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -262,7 +313,8 @@ public:
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo);
+            defrag_info dinfo,
+            stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
@@ -320,6 +372,8 @@ private:
 
     defrag_info dinfo;
 
+    stream_copy_info sc_info;
+
     //
     // batch processing context
     //

+ 1 - 0
src/llama-memory-hybrid.cpp

@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         offload,
         kv_size,
         n_seq_max,
+        1,
         n_pad,
         n_swa,
         swa_type

+ 16 - 3
src/llama-model.cpp

@@ -16647,7 +16647,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 } else {
                     const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    uint32_t n_ctx_per_stream = cparams.n_ctx;
+
+                    if (!cparams.kv_unified) {
+                        n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
+                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                        cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
+                    } else {
+                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                        cparams.n_ctx = n_ctx_per_stream;
+                    }
 
                     LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
@@ -16661,7 +16672,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
                                 params.swa_full,
-                                cparams.n_ctx,
+                                cparams.kv_unified,
+                                n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
                                 padding);
@@ -16675,7 +16687,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 params.type_v,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
-                                cparams.n_ctx,
+                                cparams.kv_unified,
+                                n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,

+ 1 - 1
tests/test-backend-ops.cpp

@@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
             ggml_set_name(m, "m");
         }
 

+ 1 - 2
tools/batched-bench/batched-bench.cpp

@@ -127,10 +127,9 @@ int main(int argc, char ** argv) {
 
                 for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
                     for (int i = 0; i < pp; ++i) {
-                        common_batch_add(batch, 0, i, { j }, false);
+                        common_batch_add(batch, 0, i, { j }, i == pp - 1);
                     }
                 }
-                batch.logits[batch.n_tokens - 1] = true;
 
                 const auto t_pp_start = ggml_time_us();
 

Algúns arquivos non se mostraron porque demasiados arquivos cambiaron neste cambio