@@ -180,11 +180,12 @@ static void llama_params_fit_impl(
         }
     }
 
-    int64_t sum_total          = 0;
-    int64_t sum_projected_free = 0;
-    int64_t min_projected_free = INT64_MAX;
-    int64_t sum_projected_used = 0;
-    int64_t sum_projected_ctx  = 0;
+    int64_t sum_total           = 0;
+    int64_t sum_projected_free  = 0;
+    int64_t min_projected_free  = INT64_MAX;
+    int64_t sum_projected_used  = 0;
+    int64_t sum_projected_model = 0;
+    int64_t sum_projected_ctx   = 0;
 
     if (nd > 1) {
         LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -195,11 +196,12 @@ static void llama_params_fit_impl(
         const int64_t projected_used = dmd.mb.total();
         const int64_t projected_free = dmd.free - projected_used;
 
-        sum_total          += dmd.total;
-        sum_projected_used += projected_used;
-        sum_projected_free += projected_free;
-        min_projected_free  = std::min(min_projected_free, projected_free);
-        sum_projected_ctx  += dmd.mb.context;
+        sum_total           += dmd.total;
+        sum_projected_used  += projected_used;
+        sum_projected_free  += projected_free;
+        min_projected_free   = std::min(min_projected_free, projected_free);
+        sum_projected_model += dmd.mb.model;
+        sum_projected_ctx   += dmd.mb.context;
 
         if (nd > 1) {
             LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@@ -234,10 +236,24 @@ static void llama_params_fit_impl(
     if (cparams->n_ctx == 0) {
         if (hp_nct > n_ctx_min) {
             const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
-            const uint32_t ctx_reduction = std::min(
-                uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
+
+            int64_t memory_reduction = -global_surplus;
+            if (nd > 1) {
+                // for multiple devices we need to be more conservative in terms of how much context we think can fit:
+                // - for dense models only whole layers can be assigned to devices
+                // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer
+                // - on average we expect a waste of 0.5 layers/tensors per device
+                // - use slightly more than the expected average for nd devices to be safe
+                const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
+                memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
+            }
+
+            uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
             cparams->n_ctx = hp_nct - ctx_reduction;
-            const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
+            cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
+
+            ctx_reduction = hp_nct - cparams->n_ctx;
+            memory_reduction = ctx_reduction * bytes_per_ctx;
             global_surplus += memory_reduction;
             LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
                 __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
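
To make the arithmetic in the last hunk concrete, below is a minimal, self-contained sketch (not part of the patch) of how a memory deficit is turned into a smaller context size. All numbers are hypothetical, and the names n_layers and is_moe stand in for the patch's mparams->n_gpu_layers / hp_ngl and hp_nex == 0 respectively; the slack term mirrors the (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6) heuristic, i.e. roughly half a layer of waste per device for dense models and about a sixth of a layer for MoE models.

// Illustrative sketch only: redoes the context-reduction math from the patch
// with made-up inputs for a 2-device setup.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t  nd                  = 2;                         // number of devices
    const int64_t  global_surplus      = -512LL * 1024 * 1024;      // 512 MiB short of fitting
    const int64_t  sum_projected_ctx   = 8LL  * 1024 * 1024 * 1024; // projected KV cache bytes at full context
    const int64_t  sum_projected_model = 24LL * 1024 * 1024 * 1024; // projected model weight bytes
    const uint32_t hp_nct              = 131072;                    // full training context
    const uint32_t n_ctx_min           = 4096;                      // never shrink below this
    const uint32_t n_layers            = 48;                        // layers that are offloaded
    const bool     is_moe              = false;                     // dense model in this example

    const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct; // KV bytes per context token

    // start from the deficit, then add slack for per-device assignment granularity
    int64_t memory_reduction = -global_surplus;
    if (nd > 1) {
        const int64_t model_per_layer = sum_projected_model / n_layers;
        memory_reduction += (nd + 1) * model_per_layer / (is_moe ? 6 : 2);
    }

    // ceiling division: cut just enough tokens to cover the reduction, capped at n_ctx_min
    uint32_t ctx_reduction = std::min(
        uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
    uint32_t n_ctx = hp_nct - ctx_reduction;

    // round down to a multiple of 256 as the patch does for the CUDA backend
    n_ctx = std::max(n_ctx - n_ctx % 256, n_ctx_min);

    printf("bytes per token: %lld, reduced n_ctx: %u\n", (long long) bytes_per_ctx, n_ctx);
    return 0;
}

Note that the rounding to a multiple of 256 can only shrink the context further, which is why the patch recomputes ctx_reduction and memory_reduction from the rounded value before adding the freed memory back to global_surplus.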