|
@@ -270,19 +270,7 @@ llama_context::llama_context(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // resolve automatic Flash Attention use and reserve worst-case graph
|
|
|
|
|
if (!hparams.vocab_only) {
|
|
if (!hparams.vocab_only) {
|
|
|
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
|
|
|
|
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
|
|
|
-
|
|
|
|
|
- LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
|
|
|
-
|
|
|
|
|
- int n_splits_pp = -1;
|
|
|
|
|
- int n_nodes_pp = -1;
|
|
|
|
|
-
|
|
|
|
|
- int n_splits_tg = -1;
|
|
|
|
|
- int n_nodes_tg = -1;
|
|
|
|
|
-
|
|
|
|
|
llama_memory_context_ptr mctx;
|
|
llama_memory_context_ptr mctx;
|
|
|
if (memory) {
|
|
if (memory) {
|
|
|
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
|
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
|
@@ -293,6 +281,59 @@ llama_context::llama_context(
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
cross.v_embd.clear();
|
|
cross.v_embd.clear();
|
|
|
|
|
+ // resolve automatic Flash Attention use
|
|
|
|
|
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
|
|
|
|
|
+ auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
|
|
|
|
|
+ if (!gf) {
|
|
|
|
|
+ throw std::runtime_error("failed to split graph for Flash Attention check");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
|
|
|
|
|
+ bool fa_device_mismatch = false;
|
|
|
|
|
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
|
|
|
+ ggml_tensor * n = ggml_graph_node(gf, i);
|
|
|
|
|
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
|
|
|
|
|
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
|
|
|
+
|
|
|
|
|
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
|
|
|
|
|
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
|
|
|
|
+ const int il = std::stoi(n->name + prefix_len);
|
|
|
|
|
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
|
|
|
+ if (device_fa != device_kv) {
|
|
|
|
|
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
|
|
|
|
|
+ "is assigned to device %s (usually due to missing support)\n",
|
|
|
|
|
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
|
|
|
|
|
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
|
|
|
|
|
+ fa_device_mismatch = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if (fa_device_mismatch) {
|
|
|
|
|
+ cparams.flash_attn = false;
|
|
|
|
|
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
|
|
|
|
|
+ if (ggml_is_quantized(params.type_v)) {
|
|
|
|
|
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ cparams.flash_attn = true;
|
|
|
|
|
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // reserve worst-case graph
|
|
|
|
|
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
|
|
|
|
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
|
|
|
+
|
|
|
|
|
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
|
|
|
+
|
|
|
|
|
+ int n_splits_pp = -1;
|
|
|
|
|
+ int n_nodes_pp = -1;
|
|
|
|
|
+
|
|
|
|
|
+ int n_splits_tg = -1;
|
|
|
|
|
+ int n_nodes_tg = -1;
|
|
|
|
|
|
|
|
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
|
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
|
|
{
|
|
{
|
|
@@ -301,48 +342,6 @@ llama_context::llama_context(
|
|
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
|
|
|
|
|
- ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
|
|
|
-
|
|
|
|
|
- const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
|
|
|
|
|
- bool fa_device_mismatch = false;
|
|
|
|
|
- for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
|
|
|
- ggml_tensor * n = ggml_graph_node(gf, i);
|
|
|
|
|
- if (n->op != GGML_OP_FLASH_ATTN_EXT) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
- ggml_backend_dev_t device_fa = ggml_backend_get_device(
|
|
|
|
|
- ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
|
|
|
-
|
|
|
|
|
- // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
|
|
|
|
|
- GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
|
|
|
|
- const int il = std::stoi(n->name + prefix_len);
|
|
|
|
|
- ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
|
|
|
- if (device_fa != device_kv) {
|
|
|
|
|
- LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
|
|
|
|
|
- "is assigned to device %s (usually due to missing support)\n",
|
|
|
|
|
- __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
|
|
|
|
|
- // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
|
|
|
|
|
- fa_device_mismatch = true;
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- if (fa_device_mismatch) {
|
|
|
|
|
- cparams.flash_attn = false;
|
|
|
|
|
- LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
|
|
|
|
|
- if (ggml_is_quantized(params.type_v)) {
|
|
|
|
|
- throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
|
|
|
|
|
- }
|
|
|
|
|
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
|
|
|
- if (!gf) {
|
|
|
|
|
- throw std::runtime_error("failed to allocate compute pp buffers");
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- cparams.flash_attn = true;
|
|
|
|
|
- LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
|
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
|
|
n_nodes_pp = ggml_graph_n_nodes(gf);
|
|
n_nodes_pp = ggml_graph_n_nodes(gf);
|
|
|
}
|
|
}
|
|
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
|
|
|
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
|
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
|
|
|
|
|
|
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
|
|
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
|
|
|
|
|
|
if (n_tokens % n_seqs != 0) {
|
|
if (n_tokens % n_seqs != 0) {
|
|
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
|
this->n_outputs = save_n_outputs;
|
|
this->n_outputs = save_n_outputs;
|
|
|
|
|
|
|
|
// initialize scheduler with the specified graph
|
|
// initialize scheduler with the specified graph
|
|
|
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
|
|
|
|
|
+ if (split_only) {
|
|
|
|
|
+ ggml_backend_sched_split_graph(sched.get(), gf);
|
|
|
|
|
+ } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
}
|
|
}
|