@@ -10,13 +10,7 @@

#include "ggml.h"

-#if !defined(GGML_USE_CUBLAS)
-# include "ggml-alloc.h"
-# define LLAMA_USE_ALLOCATOR
-#else
-# define LLAMA_USE_SCRATCH
-# define LLAMA_MAX_SCRATCH_BUFFERS 16
-#endif
+#include "ggml-alloc.h"

#ifdef GGML_USE_CUBLAS
# include "ggml-cuda.h"
@@ -588,14 +582,6 @@ struct llama_state {

static llama_state g_state;

-//
-// memory sizes (calculated for n_batch == 512)
-//
-
-// computed for n_ctx == 2048
-// TODO: dynamically determine these sizes
-// needs modifications in ggml
-
// available llama models
enum e_model {
    MODEL_UNKNOWN,
@@ -610,76 +596,6 @@ enum e_model {
static const size_t kB = 1024;
static const size_t MB = 1024*1024;

-static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
-{
-    std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB },
-        { MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
-        { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
-        { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
-        { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
-        { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
-    };
-    return k_sizes;
-}
-
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B, 128ull * MB },
-        { MODEL_7B, 160ull * MB },
-        { MODEL_13B, 192ull * MB },
-        { MODEL_30B, 256ull * MB },
-        { MODEL_65B, 384ull * MB }, // guess
-        { MODEL_70B, 304ull * MB },
-    };
-    return k_sizes;
-}
-
-// used to store the compute graph tensors + non-scratch data
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B, 8ull * MB },
-        { MODEL_7B, 10ull * MB },
-        { MODEL_13B, 12ull * MB },
-        { MODEL_30B, 16ull * MB },
-        { MODEL_65B, 24ull * MB }, // guess
-        { MODEL_70B, 24ull * MB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B, 512ull * kB },
-        { MODEL_7B, 512ull * kB },
-        { MODEL_13B, 640ull * kB },
-        { MODEL_30B, 768ull * kB },
-        { MODEL_65B, 1280ull * kB },
-        { MODEL_70B, 1280ull * kB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size and context to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B, 128ull },
-        { MODEL_7B, 128ull },
-        { MODEL_13B, 160ull },
-        { MODEL_30B, 208ull },
-        { MODEL_65B, 256ull },
-        { MODEL_70B, 256ull },
-    };
-    return k_sizes;
-}
-
// default hparams (LLaMA 7B)
struct llama_hparams {
    uint32_t n_vocab = 32000;
@@ -857,11 +773,9 @@ struct llama_context {
            ggml_metal_free(ctx_metal);
        }
#endif
-#ifdef LLAMA_USE_ALLOCATOR
        if (alloc) {
            ggml_allocr_free(alloc);
        }
-#endif
    }

    std::mt19937 rng;
@@ -901,17 +815,8 @@ struct llama_context {
    // memory buffers used to evaluate the model
    llama_buffer buf_compute;

-#ifdef LLAMA_USE_ALLOCATOR
    llama_buffer buf_alloc;
    ggml_allocr * alloc = NULL;
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-    llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
-
-    int buf_last = 0;
-    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
-#endif

#ifdef GGML_USE_METAL
    ggml_metal_context * ctx_metal = NULL;
@@ -920,37 +825,6 @@ struct llama_context {
#ifdef GGML_USE_MPI
    ggml_mpi_context * ctx_mpi = NULL;
#endif
-
-    void use_buf(struct ggml_context * ctx, int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        size_t last_size = 0;
-
-        if (i == -1) {
-            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-        } else {
-            auto & buf = buf_scratch[i];
-            last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
-        }
-
-        if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-        }
-
-        buf_last = i;
-#else
-        (void) i;
-        (void) ctx;
-#endif
-    }
-
-    size_t get_buf_max_mem(int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        return buf_max_size[i];
-#else
-        (void) i;
-        return 0;
-#endif
-    }
};

//
@@ -1620,7 +1494,6 @@ static void llama_model_load_internal(

    // prepare memory for the weights
    size_t vram_weights = 0;
-    size_t vram_scratch = 0;
    {
        const uint32_t n_embd = hparams.n_embd;
        const uint32_t n_embd_gqa = hparams.n_embd_gqa();
@@ -1701,13 +1574,6 @@ static void llama_model_load_internal(
            ctx_size +
            mmapped_size - vram_weights; // weights in VRAM not in memory

-#ifndef LLAMA_USE_ALLOCATOR
-        mem_required +=
-            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-#endif
-
        // this is the memory required by one llama_state
        const size_t mem_required_state =
            scale*hparams.kv_size();
@@ -1715,24 +1581,7 @@ static void llama_model_load_internal(
        LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);

-        (void) vram_scratch;
        (void) n_batch;
-#ifdef GGML_USE_CUBLAS
-        if (low_vram) {
-            LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
-            ggml_cuda_set_scratch_size(0); // disable scratch
-        } else {
-            const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
-            const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
-            vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
-            ggml_cuda_set_scratch_size(vram_scratch);
-            if (n_gpu_layers > 0) {
-                LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
-                        __func__, vram_scratch_base / kB, vram_scratch_per_context,
-                        (vram_scratch + MB - 1) / MB); // round up
-            }
-        }
-#endif // GGML_USE_CUBLAS

#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -1769,8 +1618,8 @@ static void llama_model_load_internal(

        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
                __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-        LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
-                __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
+        LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
+                __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
#else
        (void) n_gpu_layers;
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -1875,9 +1724,7 @@ static struct ggml_cgraph * llama_build_graph(
        /*.no_alloc =*/ false,
    };

-#ifdef LLAMA_USE_ALLOCATOR
    params.no_alloc = true;
-#endif

    struct ggml_context * ctx0 = ggml_init(params);

@@ -1889,14 +1736,10 @@ static struct ggml_cgraph * llama_build_graph(
    if (tokens) {
        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

-#ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_alloc(lctx.alloc, inp_tokens);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
        }
-#else
-        memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
-#endif
        ggml_set_name(inp_tokens, "inp_tokens");

        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@@ -1907,14 +1750,10 @@ static struct ggml_cgraph * llama_build_graph(

        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);

-#ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_alloc(lctx.alloc, inpL);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
        }
-#else
-        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
-#endif
    }

    const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1931,25 +1770,21 @@ static struct ggml_cgraph * llama_build_graph(

#ifdef GGML_USE_CUBLAS
    if (n_gpu_layers > n_layer) {
-        offload_func_nr = ggml_cuda_assign_buffers;
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
    }
    if (n_gpu_layers > n_layer + 1) {
-        offload_func_v = ggml_cuda_assign_buffers;
+        offload_func_v = ggml_cuda_assign_buffers_no_alloc;
    }
    if (n_gpu_layers > n_layer + 2) {
-        offload_func_kq = ggml_cuda_assign_buffers;
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
    }
#endif // GGML_USE_CUBLAS

    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-#ifdef LLAMA_USE_ALLOCATOR
    ggml_allocr_alloc(lctx.alloc, KQ_scale);
    if (!ggml_allocr_is_measure(lctx.alloc)) {
        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
    }
-#else
-    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
-#endif
    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

    for (int il = 0; il < n_layer; ++il) {
@@ -1959,14 +1794,12 @@ static struct ggml_cgraph * llama_build_graph(

#ifdef GGML_USE_CUBLAS
        if (il >= i_gpu_start) {
-            offload_func = ggml_cuda_assign_buffers;
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
        }
#endif // GGML_USE_CUBLAS

        struct ggml_tensor * inpSA = inpL;

-        lctx.use_buf(ctx0, 0);
-
        // norm
        {
            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2104,8 +1937,6 @@ static struct ggml_cgraph * llama_build_graph(
            ggml_set_name(cur, "result_wo");
        }

-        lctx.use_buf(ctx0, 1);
-
        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
        offload_func(inpFF);
        ggml_set_name(inpFF, "inpFF");
@@ -2160,8 +1991,6 @@ static struct ggml_cgraph * llama_build_graph(
        inpL = cur;
    }

-    lctx.use_buf(ctx0, 0);
-
    // norm
    {
        cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2178,8 +2007,6 @@ static struct ggml_cgraph * llama_build_graph(
    cur = ggml_mul_mat(ctx0, model.output, cur);
    ggml_set_name(cur, "result_output");

-    lctx.use_buf(ctx0, -1);
-
    // logits -> probs
    //cur = ggml_soft_max_inplace(ctx0, cur);

@@ -2189,15 +2016,6 @@ static struct ggml_cgraph * llama_build_graph(
        mem_per_token = ggml_used_mem(ctx0)/N;
    }

-#if 0
-    LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
-            ggml_used_mem(ctx0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.work_buffer.size()/1024.0/1024.0,
-            n_past, N);
-#endif
-
    ggml_free(ctx0);

    return gf;
@@ -2248,14 +2066,26 @@ static bool llama_eval_internal(
    const int64_t n_embd = hparams.n_embd;
    const int64_t n_vocab = hparams.n_vocab;

-#ifdef LLAMA_USE_ALLOCATOR
    ggml_allocr_reset(lctx.alloc);
-#endif

    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);

-#ifdef LLAMA_USE_ALLOCATOR
    ggml_allocr_alloc_graph(lctx.alloc, gf);
+
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < gf->n_leafs; i++) {
+        ggml_tensor * node = gf->leafs[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
+
+    for (int i = 0; i < gf->n_nodes; i++) {
+        ggml_tensor * node = gf->nodes[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
#endif

    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -4319,7 +4149,6 @@ struct llama_context * llama_new_context_with_model(
            ctx->embedding.resize(hparams.n_embd);
        }

-#ifdef LLAMA_USE_ALLOCATOR
        {
            static const size_t tensor_alignment = 32;
            // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
@@ -4350,13 +4179,6 @@ struct llama_context * llama_new_context_with_model(

            LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);

-            // debug - for comparison with scratch buffer
-            //size_t prev_req =
-            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
-            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
-            //    MEM_REQ_EVAL().at(ctx->model.type);
-            //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
-
            // recreate allocator with exact memory requirements
            ggml_allocr_free(ctx->alloc);

@@ -4367,15 +4189,16 @@ struct llama_context * llama_new_context_with_model(
                ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
            }
#endif
-        }
-#else
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
-        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
+#ifdef GGML_USE_CUBLAS
+            if (params.low_vram) {
+                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+                ggml_cuda_set_scratch_size(0); // disable scratch
+            } else {
+                ggml_cuda_set_scratch_size(alloc_size);
+                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+            }
#endif
+        }
    }

#ifdef GGML_USE_METAL