main.cpp 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. #include "ggml.h"
  2. #include "utils.h"
  3. #include <cassert>
  4. #include <cmath>
  5. #include <cstdio>
  6. #include <cstring>
  7. #include <fstream>
  8. #include <map>
  9. #include <string>
  10. #include <vector>
  // Model hyperparameters. The defaults correspond to LLaMA 7B.
// llama_model_load overwrites every field except n_ctx with the value stored
// in the model file; n_ctx is not part of the file and is set from that
// function's n_ctx argument instead.
struct llama_hparams {
    int32_t n_vocab = 32000; // vocabulary size
    int32_t n_ctx   = 512;   // context length (number of key/value memory slots per layer)
    int32_t n_embd  = 4096;  // embedding width
    int32_t n_mult  = 256;   // feed-forward width is rounded up to a multiple of this
    int32_t n_head  = 32;    // attention heads
    int32_t n_layer = 32;    // transformer layers
    int32_t n_rot   = 64;    // dimension count passed to ggml_rope
    int32_t f16     = 1;     // weight type: 0 = f32, 1 = f16, 2 = q4_0, 3 = q4_1
};
  // Weights of one transformer layer. The tensors are created in (and owned by)
// the model's ggml context; these are non-owning pointers that stay null until
// llama_model_load fills them in.
struct llama_layer {
    // normalization applied before self-attention
    struct ggml_tensor * attention_norm = nullptr;

    // attention projections
    struct ggml_tensor * wq = nullptr;
    struct ggml_tensor * wk = nullptr;
    struct ggml_tensor * wv = nullptr;
    struct ggml_tensor * wo = nullptr;

    // normalization applied before the feed-forward network
    struct ggml_tensor * ffn_norm = nullptr;

    // feed-forward network
    struct ggml_tensor * w1 = nullptr;
    struct ggml_tensor * w2 = nullptr;
    struct ggml_tensor * w3 = nullptr;
};
  37. struct llama_model {
  38. llama_hparams hparams;
  39. struct ggml_tensor * tok_embeddings;
  40. struct ggml_tensor * norm;
  41. struct ggml_tensor * output;
  42. std::vector<llama_layer> layers;
  43. // key + value memory
  44. struct ggml_tensor * memory_k;
  45. struct ggml_tensor * memory_v;
  46. //
  47. struct ggml_context * ctx;
  48. std::map<std::string, struct ggml_tensor *> tensors;
  49. };
  50. // load the model's weights from a file
  51. bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
  52. printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
  53. auto fin = std::ifstream(fname, std::ios::binary);
  54. if (!fin) {
  55. fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
  56. return false;
  57. }
  58. // verify magic
  59. {
  60. uint32_t magic;
  61. fin.read((char *) &magic, sizeof(magic));
  62. if (magic != 0x67676d6c) {
  63. fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
  64. return false;
  65. }
  66. }
  67. int n_ff = 0;
  68. // load hparams
  69. {
  70. auto & hparams = model.hparams;
  71. fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
  72. //fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
  73. fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
  74. fin.read((char *) &hparams.n_mult, sizeof(hparams.n_mult));
  75. fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
  76. fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
  77. fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
  78. fin.read((char *) &hparams.f16, sizeof(hparams.f16));
  79. hparams.n_ctx = n_ctx;
  80. n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
  81. printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
  82. printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
  83. printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
  84. printf("%s: n_mult = %d\n", __func__, hparams.n_mult);
  85. printf("%s: n_head = %d\n", __func__, hparams.n_head);
  86. printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
  87. printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
  88. printf("%s: f16 = %d\n", __func__, hparams.f16);
  89. printf("%s: n_ff = %d\n", __func__, n_ff);
  90. }
  91. // load vocab
  92. {
  93. const int32_t n_vocab = model.hparams.n_vocab;
  94. if (n_vocab != model.hparams.n_vocab) {
  95. fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
  96. __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
  97. return false;
  98. }
  99. std::string word;
  100. for (int i = 0; i < n_vocab; i++) {
  101. uint32_t len;
  102. fin.read((char *) &len, sizeof(len));
  103. word.resize(len);
  104. fin.read((char *) word.data(), len);
  105. vocab.token_to_id[word] = i;
  106. vocab.id_to_token[i] = word;
  107. //if (i < 30000) {
  108. // printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
  109. //}
  110. }
  111. }
  112. // for the big tensors, we have the option to store the data in 16-bit floats or quantized
  113. // in order to save memory and also to speed up the computation
  114. ggml_type wtype = GGML_TYPE_COUNT;
  115. switch (model.hparams.f16) {
  116. case 0: wtype = GGML_TYPE_F32; break;
  117. case 1: wtype = GGML_TYPE_F16; break;
  118. case 2: wtype = GGML_TYPE_Q4_0; break;
  119. case 3: wtype = GGML_TYPE_Q4_1; break;
  120. default:
  121. {
  122. fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
  123. __func__, fname.c_str(), model.hparams.f16);
  124. return false;
  125. }
  126. }
  127. const ggml_type wtype2 = GGML_TYPE_F32;
  128. auto & ctx = model.ctx;
  129. size_t ctx_size = 0;
  130. {
  131. const auto & hparams = model.hparams;
  132. const int n_embd = hparams.n_embd;
  133. const int n_layer = hparams.n_layer;
  134. const int n_ctx = hparams.n_ctx;
  135. const int n_vocab = hparams.n_vocab;
  136. ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // tok_embeddings
  137. ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
  138. ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // output
  139. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
  140. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
  141. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
  142. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
  143. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
  144. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
  145. ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
  146. ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
  147. ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
  148. ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
  149. ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
  150. ctx_size += (5 + 10*n_layer)*256; // object overhead
  151. printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
  152. }
  153. // create the ggml context
  154. {
  155. struct ggml_init_params params = {
  156. .mem_size = ctx_size,
  157. .mem_buffer = NULL,
  158. };
  159. model.ctx = ggml_init(params);
  160. if (!model.ctx) {
  161. fprintf(stderr, "%s: ggml_init() failed\n", __func__);
  162. return false;
  163. }
  164. }
  165. // prepare memory for the weights
  166. {
  167. const auto & hparams = model.hparams;
  168. const int n_embd = hparams.n_embd;
  169. const int n_layer = hparams.n_layer;
  170. const int n_ctx = hparams.n_ctx;
  171. const int n_vocab = hparams.n_vocab;
  172. model.layers.resize(n_layer);
  173. model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
  174. model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  175. model.output = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
  176. // map by name
  177. model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
  178. model.tensors["norm.weight"] = model.norm;
  179. model.tensors["output.weight"] = model.output;
  180. for (int i = 0; i < n_layer; ++i) {
  181. auto & layer = model.layers[i];
  182. layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  183. layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  184. layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  185. layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  186. layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  187. layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  188. layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
  189. layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
  190. layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
  191. // map by name
  192. model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
  193. model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
  194. model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
  195. model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
  196. model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
  197. model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
  198. model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
  199. model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
  200. model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
  201. }
  202. }
  203. // key + value memory
  204. {
  205. const auto & hparams = model.hparams;
  206. const int n_embd = hparams.n_embd;
  207. const int n_layer = hparams.n_layer;
  208. const int n_ctx = hparams.n_ctx;
  209. const int n_mem = n_layer*n_ctx;
  210. const int n_elements = n_embd*n_mem;
  211. model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
  212. model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
  213. const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
  214. printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
  215. }
  216. // load weights
  217. {
  218. int n_tensors = 0;
  219. size_t total_size = 0;
  220. printf("%s: ", __func__);
  221. while (true) {
  222. int32_t n_dims;
  223. int32_t length;
  224. int32_t ftype;
  225. fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
  226. fin.read(reinterpret_cast<char *>(&length), sizeof(length));
  227. fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
  228. if (fin.eof()) {
  229. break;
  230. }
  231. int32_t nelements = 1;
  232. int32_t ne[2] = { 1, 1 };
  233. for (int i = 0; i < n_dims; ++i) {
  234. fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
  235. nelements *= ne[i];
  236. }
  237. std::string name(length, 0);
  238. fin.read(&name[0], length);
  239. if (model.tensors.find(name.data()) == model.tensors.end()) {
  240. fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
  241. return false;
  242. }
  243. auto tensor = model.tensors[name.data()];
  244. if (ggml_nelements(tensor) != nelements) {
  245. fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
  246. return false;
  247. }
  248. if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
  249. fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
  250. __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
  251. return false;
  252. }
  253. if (0) {
  254. static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
  255. printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
  256. }
  257. size_t bpe = 0;
  258. switch (ftype) {
  259. case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
  260. case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
  261. case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
  262. case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
  263. default:
  264. {
  265. fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
  266. return false;
  267. }
  268. };
  269. if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
  270. fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
  271. __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
  272. return false;
  273. }
  274. fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
  275. //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
  276. total_size += ggml_nbytes(tensor);
  277. if (++n_tensors % 8 == 0) {
  278. printf(".");
  279. fflush(stdout);
  280. }
  281. }
  282. printf(" done\n");
  283. printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
  284. }
  285. fin.close();
  286. return true;
  287. }
  288. // evaluate the transformer
  289. //
  290. // - model: the model
  291. // - n_threads: number of threads to use
  292. // - n_past: the context size so far
  293. // - embd_inp: the embeddings of the tokens in the context
  294. // - embd_w: the predicted logits for the next token
  295. //
  296. // The GPT-J model requires about 16MB of memory per input token.
  297. //
  298. bool llama_eval(
  299. const llama_model & model,
  300. const int n_threads,
  301. const int n_past,
  302. const std::vector<gpt_vocab::id> & embd_inp,
  303. std::vector<float> & embd_w,
  304. size_t & mem_per_token) {
  305. const int N = embd_inp.size();
  306. const auto & hparams = model.hparams;
  307. const int n_embd = hparams.n_embd;
  308. const int n_layer = hparams.n_layer;
  309. const int n_ctx = hparams.n_ctx;
  310. const int n_head = hparams.n_head;
  311. const int n_vocab = hparams.n_vocab;
  312. const int n_rot = hparams.n_rot;
  313. const int d_key = n_embd/n_head;
  314. static size_t buf_size = 256u*1024*1024;
  315. static void * buf = malloc(buf_size);
  316. if (mem_per_token > 0 && mem_per_token*N > buf_size) {
  317. const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
  318. //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
  319. // reallocate
  320. buf_size = buf_size_new;
  321. buf = realloc(buf, buf_size);
  322. if (buf == nullptr) {
  323. fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
  324. return false;
  325. }
  326. }
  327. struct ggml_init_params params = {
  328. .mem_size = buf_size,
  329. .mem_buffer = buf,
  330. };
  331. struct ggml_context * ctx0 = ggml_init(params);
  332. struct ggml_cgraph gf = { .n_threads = n_threads };
  333. struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
  334. memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
  335. struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
  336. for (int il = 0; il < n_layer; ++il) {
  337. struct ggml_tensor * inpSA = inpL;
  338. struct ggml_tensor * cur;
  339. // norm
  340. {
  341. cur = ggml_norm(ctx0, inpL);
  342. // cur = attention_norm*cur
  343. cur = ggml_mul(ctx0,
  344. ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
  345. cur);
  346. }
  347. // self-attention
  348. {
  349. struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
  350. struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
  351. struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
  352. // store key and value to memory
  353. if (N >= 1) {
  354. struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
  355. struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
  356. ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
  357. ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
  358. }
  359. // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
  360. struct ggml_tensor * Q =
  361. ggml_permute(ctx0,
  362. ggml_rope(ctx0,
  363. ggml_cpy(ctx0,
  364. Qcur,
  365. ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
  366. n_past, n_rot, 0),
  367. 0, 2, 1, 3);
  368. // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
  369. struct ggml_tensor * K =
  370. ggml_permute(ctx0,
  371. ggml_rope(ctx0,
  372. ggml_reshape_3d(ctx0,
  373. ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
  374. n_embd/n_head, n_head, n_past + N),
  375. n_past, n_rot, 1),
  376. 0, 2, 1, 3);
  377. // K * Q
  378. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
  379. // KQ_scaled = KQ / sqrt(n_embd/n_head)
  380. struct ggml_tensor * KQ_scaled =
  381. ggml_scale(ctx0,
  382. KQ,
  383. ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
  384. );
  385. // KQ_masked = mask_past(KQ_scaled)
  386. struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
  387. // KQ = soft_max(KQ_masked)
  388. struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
  389. // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
  390. struct ggml_tensor * V_trans =
  391. ggml_permute(ctx0,
  392. ggml_reshape_3d(ctx0,
  393. ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
  394. n_embd/n_head, n_head, n_past + N),
  395. 1, 2, 0, 3);
  396. // KQV = transpose(V) * KQ_soft_max
  397. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
  398. // KQV_merged = KQV.permute(0, 2, 1, 3)
  399. struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  400. // cur = KQV_merged.contiguous().view(n_embd, N)
  401. cur = ggml_cpy(ctx0,
  402. KQV_merged,
  403. ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
  404. // projection (no bias)
  405. cur = ggml_mul_mat(ctx0,
  406. model.layers[il].wo,
  407. cur);
  408. }
  409. struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
  410. // feed-forward network
  411. {
  412. // norm
  413. {
  414. cur = ggml_norm(ctx0, inpFF);
  415. // cur = ffn_norm*cur
  416. cur = ggml_mul(ctx0,
  417. ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
  418. cur);
  419. }
  420. struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
  421. model.layers[il].w3,
  422. cur);
  423. cur = ggml_mul_mat(ctx0,
  424. model.layers[il].w1,
  425. cur);
  426. // SILU activation
  427. cur = ggml_silu(ctx0, cur);
  428. cur = ggml_mul(ctx0, cur, tmp);
  429. cur = ggml_mul_mat(ctx0,
  430. model.layers[il].w2,
  431. cur);
  432. }
  433. cur = ggml_add(ctx0, cur, inpFF);
  434. // input for next layer
  435. inpL = cur;
  436. }
  437. // norm
  438. {
  439. inpL = ggml_norm(ctx0, inpL);
  440. // inpL = norm*inpL
  441. inpL = ggml_mul(ctx0,
  442. ggml_repeat(ctx0, model.norm, inpL),
  443. inpL);
  444. }
  445. // lm_head
  446. {
  447. inpL = ggml_mul_mat(ctx0, model.output, inpL);
  448. }
  449. // logits -> probs
  450. //inpL = ggml_soft_max(ctx0, inpL);
  451. // run the computation
  452. ggml_build_forward_expand(&gf, inpL);
  453. ggml_graph_compute (ctx0, &gf);
  454. //if (n_past%100 == 0) {
  455. // ggml_graph_print (&gf);
  456. // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
  457. //}
  458. //embd_w.resize(n_vocab*N);
  459. //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
  460. // return result for just the last token
  461. embd_w.resize(n_vocab);
  462. memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
  463. if (mem_per_token == 0) {
  464. mem_per_token = ggml_used_mem(ctx0)/N;
  465. }
  466. //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
  467. ggml_free(ctx0);
  468. return true;
  469. }
  470. int main(int argc, char ** argv) {
  471. const int64_t t_main_start_us = ggml_time_us();
  472. gpt_params params;
  473. params.model = "models/llama-7B/ggml-model.bin";
  474. if (gpt_params_parse(argc, argv, params) == false) {
  475. return 1;
  476. }
  477. if (params.seed < 0) {
  478. params.seed = time(NULL);
  479. }
  480. printf("%s: seed = %d\n", __func__, params.seed);
  481. std::mt19937 rng(params.seed);
  482. if (params.prompt.empty()) {
  483. params.prompt = gpt_random_prompt(rng);
  484. }
  485. int64_t t_load_us = 0;
  486. gpt_vocab vocab;
  487. llama_model model;
  488. // load the model
  489. {
  490. const int64_t t_start_us = ggml_time_us();
  491. if (!llama_model_load(params.model, model, vocab, 512)) { // TODO: set context from user input ??
  492. fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
  493. return 1;
  494. }
  495. t_load_us = ggml_time_us() - t_start_us;
  496. }
  497. int n_past = 0;
  498. int64_t t_sample_us = 0;
  499. int64_t t_predict_us = 0;
  500. std::vector<float> logits;
  501. // tokenize the prompt
  502. std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
  503. params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
  504. printf("\n");
  505. printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
  506. printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
  507. for (int i = 0; i < (int) embd_inp.size(); i++) {
  508. printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
  509. }
  510. printf("\n");
  511. printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
  512. printf("\n\n");
  513. std::vector<gpt_vocab::id> embd;
  514. // determine the required inference memory per token:
  515. size_t mem_per_token = 0;
  516. llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
  517. for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
  518. // predict
  519. if (embd.size() > 0) {
  520. const int64_t t_start_us = ggml_time_us();
  521. if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
  522. printf("Failed to predict\n");
  523. return 1;
  524. }
  525. t_predict_us += ggml_time_us() - t_start_us;
  526. }
  527. n_past += embd.size();
  528. embd.clear();
  529. if (i >= embd_inp.size()) {
  530. // sample next token
  531. const int top_k = params.top_k;
  532. const float top_p = params.top_p;
  533. const float temp = params.temp;
  534. const int n_vocab = model.hparams.n_vocab;
  535. gpt_vocab::id id = 0;
  536. {
  537. const int64_t t_start_sample_us = ggml_time_us();
  538. id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
  539. t_sample_us += ggml_time_us() - t_start_sample_us;
  540. }
  541. // add it to the context
  542. embd.push_back(id);
  543. } else {
  544. // if here, it means we are still processing the input prompt
  545. for (int k = i; k < embd_inp.size(); k++) {
  546. embd.push_back(embd_inp[k]);
  547. if (embd.size() > params.n_batch) {
  548. break;
  549. }
  550. }
  551. i += embd.size() - 1;
  552. }
  553. // display text
  554. for (auto id : embd) {
  555. printf("%s", vocab.id_to_token[id].c_str());
  556. }
  557. fflush(stdout);
  558. // end of text token
  559. if (embd.back() == 2) {
  560. printf(" [end of text]\n");
  561. break;
  562. }
  563. }
  564. // report timing
  565. {
  566. const int64_t t_main_end_us = ggml_time_us();
  567. printf("\n\n");
  568. printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
  569. printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
  570. printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
  571. printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
  572. printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
  573. }
  574. ggml_free(model.ctx);
  575. return 0;
  576. }