1
0

fit-params.cpp 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #include "llama.h"
  2. #include "arg.h"
  3. #include "common.h"
  4. #include "log.h"
  5. #include <chrono>
  6. #include <cinttypes>
  7. #include <thread>
  8. using namespace std::chrono_literals;
  9. #if defined(_MSC_VER)
  10. #pragma warning(disable: 4244 4267) // possible loss of data
  11. #endif
  12. int main(int argc, char ** argv) {
  13. common_params params;
  14. if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
  15. return 1;
  16. }
  17. common_init();
  18. llama_backend_init();
  19. llama_numa_init(params.numa);
  20. auto mparams = common_model_params_to_llama(params);
  21. auto cparams = common_context_params_to_llama(params);
  22. const bool success = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
  23. params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
  24. params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
  25. if (!success) {
  26. LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
  27. exit(1);
  28. }
  29. LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__);
  30. std::this_thread::sleep_for(10ms); // to avoid a race between stderr and stdout
  31. printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers);
  32. size_t nd = llama_max_devices();
  33. while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) {
  34. nd--;
  35. }
  36. if (nd > 1) {
  37. for (size_t id = 0; id < nd; id++) {
  38. if (id == 0) {
  39. printf(" -ts ");
  40. }
  41. printf("%s%" PRIu32, id > 0 ? "," : "", uint32_t(mparams.tensor_split[id]));
  42. }
  43. }
  44. const size_t ntbo = llama_max_tensor_buft_overrides();
  45. bool any_tbo = false;
  46. for (size_t itbo = 0; itbo < ntbo && mparams.tensor_buft_overrides[itbo].pattern != nullptr; itbo++) {
  47. if (itbo == 0) {
  48. printf(" -ot \"");
  49. }
  50. printf("%s%s=%s", itbo > 0 ? "," : "", mparams.tensor_buft_overrides[itbo].pattern, ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft));
  51. any_tbo = true;
  52. }
  53. printf("%s\n", any_tbo ? "\"" : "");
  54. return 0;
  55. }