metal.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. // Evaluate a statically exported ggml computation graph with Metal
  2. //
  3. // - First, export a LLaMA graph:
  4. //
  5. // $ ./bin/main -m ../models/7B/ggml-model-q4_0.bin --export
  6. //
  7. // - Run this tool to evaluate the exported graph:
  8. //
  9. // $ ./bin/metal llama.ggml
  10. //
  11. // This tool is intended mostly for debugging and demonstration.
  12. // The main limitation of exporting computation graphs is that their sizes are static, which can
  13. // often be a problem for real-world applications.
  14. //
  15. #include "ggml.h"
  16. #include "ggml-metal.h"
  17. #include <cstdio>
  18. #include <cstring>
  19. #include <cstdlib>
  20. int main(int argc, char ** argv) {
  21. ggml_time_init();
  22. if (argc != 2) {
  23. fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
  24. return -1;
  25. }
  26. const char * fname_cgraph = argv[1];
  27. // load the compute graph
  28. struct ggml_context * ctx_data = NULL;
  29. struct ggml_context * ctx_eval = NULL;
  30. struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
  31. gf.n_threads = 1;
  32. // this allocates all Metal resources and memory buffers
  33. auto * ctx_metal = ggml_metal_init();
  34. ggml_metal_add_buffer(ctx_metal, "data", ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data));
  35. ggml_metal_add_buffer(ctx_metal, "eval", ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval));
  36. // main
  37. {
  38. struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
  39. *(int32_t *) input->data = 1; // BOS
  40. ggml_metal_set_tensor(ctx_metal, input);
  41. // warmup
  42. ggml_metal_graph_compute(ctx_metal, &gf);
  43. const int n_iter = 16;
  44. const int64_t t0 = ggml_time_us();
  45. // the actual inference happens here
  46. for (int i = 0; i < n_iter; ++i) {
  47. ggml_metal_graph_compute(ctx_metal, &gf);
  48. }
  49. const int64_t t1 = ggml_time_us();
  50. printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
  51. }
  52. // debug output
  53. {
  54. struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
  55. ggml_metal_get_tensor(ctx_metal, logits);
  56. float * ptr = (float *) ggml_get_data(logits);
  57. printf("logits: ");
  58. for (int i = 0; i < 10; i++) {
  59. printf("%8.4f ", ptr[i]);
  60. }
  61. printf("\n");
  62. int imax = 0;
  63. double sum = 0.0;
  64. double vmax = -1e9;
  65. for (int i = 0; i < 32000; i++) {
  66. sum += (double) ptr[i];
  67. if (ptr[i] > vmax) {
  68. vmax = ptr[i];
  69. imax = i;
  70. }
  71. }
  72. printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
  73. }
  74. ggml_metal_free(ctx_metal);
  75. ggml_free(ctx_data);
  76. ggml_free(ctx_eval);
  77. return 0;
  78. }