main.cpp 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888
  1. #include "ggml.h"
  2. #include "utils.h"
  3. #include <cassert>
  4. #include <cmath>
  5. #include <cstdio>
  6. #include <cstring>
  7. #include <fstream>
  8. #include <map>
  9. #include <string>
  10. #include <vector>
  // Number of part files (fname, fname.1, fname.2, ...) a model is stored in,
  // looked up by the embedding width n_embd read from the model header.
  // 4096 is the width of the LLaMA 7B defaults in llama_hparams (a single file).
  static const std::map<int, int> LLAMA_N_PARTS{
      {4096, 1},
      {5120, 2},
      {6656, 4},
      {8192, 8},
  };
  // Model hyper-parameters, defaulting to LLaMA 7B. llama_model_load overwrites
  // every field except n_ctx from the model file header; n_ctx is supplied by
  // the caller instead of being read from the file.
  struct llama_hparams {
      int32_t n_vocab{32000};
      int32_t n_ctx{512};     // context length (caller-provided)
      int32_t n_embd{4096};
      int32_t n_mult{256};    // n_ff is rounded up to a multiple of this
      int32_t n_head{32};
      int32_t n_layer{32};
      int32_t n_rot{64};
      int32_t f16{1};         // weight type: 0 = f32, 1 = f16, 2 = q4_0, 3 = q4_1
  };
  // Weights of one transformer layer. The tensors are allocated in the model's
  // ggml context by llama_model_load; these members are non-owning handles and
  // start out null so a default-initialized layer never holds garbage pointers.
  struct llama_layer {
      // normalization applied before self-attention
      struct ggml_tensor * attention_norm = nullptr;

      // attention projections
      struct ggml_tensor * wq = nullptr;
      struct ggml_tensor * wk = nullptr;
      struct ggml_tensor * wv = nullptr;
      struct ggml_tensor * wo = nullptr;

      // normalization applied before the feed-forward network
      struct ggml_tensor * ffn_norm = nullptr;

      // feed-forward network
      struct ggml_tensor * w1 = nullptr;
      struct ggml_tensor * w2 = nullptr;
      struct ggml_tensor * w3 = nullptr;
  };
  44. struct llama_model {
  45. llama_hparams hparams;
  46. struct ggml_tensor * tok_embeddings;
  47. struct ggml_tensor * norm;
  48. struct ggml_tensor * output;
  49. std::vector<llama_layer> layers;
  50. // key + value memory
  51. struct ggml_tensor * memory_k;
  52. struct ggml_tensor * memory_v;
  53. //
  54. struct ggml_context * ctx;
  55. std::map<std::string, struct ggml_tensor *> tensors;
  56. };
  57. // load the model's weights from a file
  58. bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
  59. printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
  60. auto fin = std::ifstream(fname, std::ios::binary);
  61. if (!fin) {
  62. fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
  63. return false;
  64. }
  65. // verify magic
  66. {
  67. uint32_t magic;
  68. fin.read((char *) &magic, sizeof(magic));
  69. if (magic != 0x67676d6c) {
  70. fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
  71. return false;
  72. }
  73. }
  74. int n_ff = 0;
  75. int n_parts = 0;
  76. // load hparams
  77. {
  78. auto & hparams = model.hparams;
  79. fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
  80. //fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
  81. fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
  82. fin.read((char *) &hparams.n_mult, sizeof(hparams.n_mult));
  83. fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
  84. fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
  85. fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
  86. fin.read((char *) &hparams.f16, sizeof(hparams.f16));
  87. hparams.n_ctx = n_ctx;
  88. n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
  89. n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
  90. printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
  91. printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
  92. printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
  93. printf("%s: n_mult = %d\n", __func__, hparams.n_mult);
  94. printf("%s: n_head = %d\n", __func__, hparams.n_head);
  95. printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
  96. printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
  97. printf("%s: f16 = %d\n", __func__, hparams.f16);
  98. printf("%s: n_ff = %d\n", __func__, n_ff);
  99. printf("%s: n_parts = %d\n", __func__, n_parts);
  100. }
  101. // load vocab
  102. {
  103. const int32_t n_vocab = model.hparams.n_vocab;
  104. if (n_vocab != model.hparams.n_vocab) {
  105. fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
  106. __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
  107. return false;
  108. }
  109. std::string word;
  110. for (int i = 0; i < n_vocab; i++) {
  111. uint32_t len;
  112. fin.read((char *) &len, sizeof(len));
  113. word.resize(len);
  114. fin.read((char *) word.data(), len);
  115. vocab.token_to_id[word] = i;
  116. vocab.id_to_token[i] = word;
  117. //if (i < 30000) {
  118. // printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
  119. //}
  120. }
  121. }
  122. // for the big tensors, we have the option to store the data in 16-bit floats or quantized
  123. // in order to save memory and also to speed up the computation
  124. ggml_type wtype = GGML_TYPE_COUNT;
  125. switch (model.hparams.f16) {
  126. case 0: wtype = GGML_TYPE_F32; break;
  127. case 1: wtype = GGML_TYPE_F16; break;
  128. case 2: wtype = GGML_TYPE_Q4_0; break;
  129. case 3: wtype = GGML_TYPE_Q4_1; break;
  130. default:
  131. {
  132. fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
  133. __func__, fname.c_str(), model.hparams.f16);
  134. return false;
  135. }
  136. }
  137. const ggml_type wtype2 = GGML_TYPE_F32;
  138. auto & ctx = model.ctx;
  139. size_t ctx_size = 0;
  140. {
  141. const auto & hparams = model.hparams;
  142. const int n_embd = hparams.n_embd;
  143. const int n_layer = hparams.n_layer;
  144. const int n_ctx = hparams.n_ctx;
  145. const int n_vocab = hparams.n_vocab;
  146. ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // tok_embeddings
  147. ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
  148. ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // output
  149. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
  150. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
  151. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
  152. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
  153. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
  154. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
  155. ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
  156. ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
  157. ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
  158. ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
  159. ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
  160. ctx_size += (5 + 10*n_layer)*256; // object overhead
  161. printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
  162. }
  163. // create the ggml context
  164. {
  165. struct ggml_init_params params = {
  166. .mem_size = ctx_size,
  167. .mem_buffer = NULL,
  168. };
  169. model.ctx = ggml_init(params);
  170. if (!model.ctx) {
  171. fprintf(stderr, "%s: ggml_init() failed\n", __func__);
  172. return false;
  173. }
  174. }
  175. // prepare memory for the weights
  176. {
  177. const auto & hparams = model.hparams;
  178. const int n_embd = hparams.n_embd;
  179. const int n_layer = hparams.n_layer;
  180. const int n_ctx = hparams.n_ctx;
  181. const int n_vocab = hparams.n_vocab;
  182. model.layers.resize(n_layer);
  183. model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
  184. model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  185. model.output = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
  186. // map by name
  187. model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
  188. model.tensors["norm.weight"] = model.norm;
  189. model.tensors["output.weight"] = model.output;
  190. for (int i = 0; i < n_layer; ++i) {
  191. auto & layer = model.layers[i];
  192. layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  193. layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  194. layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  195. layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  196. layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  197. layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  198. layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
  199. layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
  200. layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
  201. // map by name
  202. model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
  203. model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
  204. model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
  205. model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
  206. model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
  207. model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
  208. model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
  209. model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
  210. model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
  211. }
  212. }
  213. // key + value memory
  214. {
  215. const auto & hparams = model.hparams;
  216. const int n_embd = hparams.n_embd;
  217. const int n_layer = hparams.n_layer;
  218. const int n_ctx = hparams.n_ctx;
  219. const int n_mem = n_layer*n_ctx;
  220. const int n_elements = n_embd*n_mem;
  221. model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
  222. model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
  223. const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
  224. printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
  225. }
  226. const size_t file_offset = fin.tellg();
  227. fin.close();
  228. std::vector<uint8_t> tmp;
  229. for (int i = 0; i < n_parts; ++i) {
  230. const int part_id = i;
  231. //const int part_id = n_parts - i - 1;
  232. std::string fname_part = fname;
  233. if (i > 0) {
  234. fname_part += "." + std::to_string(i);
  235. }
  236. printf("%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
  237. fin = std::ifstream(fname_part, std::ios::binary);
  238. fin.seekg(file_offset);
  239. // load weights
  240. {
  241. int n_tensors = 0;
  242. size_t total_size = 0;
  243. printf("%s: ", __func__);
  244. while (true) {
  245. int32_t n_dims;
  246. int32_t length;
  247. int32_t ftype;
  248. fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
  249. fin.read(reinterpret_cast<char *>(&length), sizeof(length));
  250. fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
  251. if (fin.eof()) {
  252. break;
  253. }
  254. int32_t nelements = 1;
  255. int32_t ne[2] = { 1, 1 };
  256. for (int i = 0; i < n_dims; ++i) {
  257. fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
  258. nelements *= ne[i];
  259. }
  260. std::string name(length, 0);
  261. fin.read(&name[0], length);
  262. if (model.tensors.find(name.data()) == model.tensors.end()) {
  263. fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
  264. return false;
  265. }
  266. // split_type = 0: split by columns
  267. // split_type = 1: split by rows
  268. int split_type = 0;
  269. // split_type = 0:
  270. // regex:
  271. // - tok_embeddings.*
  272. // - layers.*.attention.wo.weight
  273. // - layers.*.feed_forward.w2.weight
  274. // split_type = 1:
  275. // regex:
  276. // - output.*
  277. // - layers.*.attention.wq.weight
  278. // - layers.*.attention.wk.weight
  279. // - layers.*.attention.wv.weight
  280. // - layers.*.feed_forward.w1.weight
  281. // - layers.*.feed_forward.w3.weight
  282. if (name.find("tok_embeddings") != std::string::npos) {
  283. split_type = 0;
  284. } else if (name.find("layers") != std::string::npos) {
  285. if (name.find("attention.wo.weight") != std::string::npos) {
  286. split_type = 0;
  287. } else if (name.find("feed_forward.w2.weight") != std::string::npos) {
  288. split_type = 0;
  289. } else {
  290. split_type = 1;
  291. }
  292. } else if (name.find("output") != std::string::npos) {
  293. split_type = 1;
  294. }
  295. auto tensor = model.tensors[name.data()];
  296. if (n_dims == 1) {
  297. if (ggml_nelements(tensor) != nelements) {
  298. fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
  299. return false;
  300. }
  301. } else {
  302. if (ggml_nelements(tensor)/n_parts != nelements) {
  303. fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
  304. return false;
  305. }
  306. }
  307. if (n_dims == 1) {
  308. if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
  309. fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
  310. __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
  311. return false;
  312. }
  313. } else {
  314. if (split_type == 0) {
  315. if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
  316. fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
  317. __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
  318. return false;
  319. }
  320. } else {
  321. if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
  322. fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
  323. __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
  324. return false;
  325. }
  326. }
  327. }
  328. if (0) {
  329. static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
  330. printf("%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
  331. }
  332. size_t bpe = 0;
  333. switch (ftype) {
  334. case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
  335. case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
  336. case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
  337. case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
  338. default:
  339. {
  340. fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
  341. return false;
  342. }
  343. };
  344. if (n_dims == 1 || n_parts == 1) {
  345. if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
  346. fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
  347. __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
  348. return false;
  349. }
  350. if (part_id == 0) {
  351. fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
  352. } else {
  353. fin.seekg(ggml_nbytes(tensor), std::ios::cur);
  354. }
  355. total_size += ggml_nbytes(tensor);
  356. } else {
  357. if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
  358. fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
  359. __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
  360. return false;
  361. }
  362. if (split_type == 0) {
  363. const int np0 = ne[0];
  364. const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
  365. assert(row_size == tensor->nb[1]);
  366. for (int i1 = 0; i1 < ne[1]; ++i1) {
  367. const size_t offset_row = i1*row_size;
  368. const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
  369. fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
  370. }
  371. } else {
  372. const int np1 = ne[1];
  373. const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
  374. for (int i1 = 0; i1 < ne[1]; ++i1) {
  375. const size_t offset_row = (i1 + part_id*np1)*row_size;
  376. fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
  377. }
  378. }
  379. total_size += ggml_nbytes(tensor)/n_parts;
  380. }
  381. //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
  382. if (++n_tensors % 8 == 0) {
  383. printf(".");
  384. fflush(stdout);
  385. }
  386. }
  387. printf(" done\n");
  388. printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
  389. }
  390. fin.close();
  391. }
  392. return true;
  393. }
  394. // evaluate the transformer
  395. //
  396. // - model: the model
  397. // - n_threads: number of threads to use
  398. // - n_past: the context size so far
  399. // - embd_inp: the embeddings of the tokens in the context
  400. // - embd_w: the predicted logits for the next token
  401. //
  402. // The GPT-J model requires about 16MB of memory per input token.
  403. //
  404. bool llama_eval(
  405. const llama_model & model,
  406. const int n_threads,
  407. const int n_past,
  408. const std::vector<gpt_vocab::id> & embd_inp,
  409. std::vector<float> & embd_w,
  410. size_t & mem_per_token) {
  411. const int N = embd_inp.size();
  412. const auto & hparams = model.hparams;
  413. const int n_embd = hparams.n_embd;
  414. const int n_layer = hparams.n_layer;
  415. const int n_ctx = hparams.n_ctx;
  416. const int n_head = hparams.n_head;
  417. const int n_vocab = hparams.n_vocab;
  418. const int n_rot = hparams.n_embd/hparams.n_head;
  419. const int d_key = n_embd/n_head;
  420. static size_t buf_size = 512u*1024*1024;
  421. static void * buf = malloc(buf_size);
  422. if (mem_per_token > 0 && mem_per_token*N > buf_size) {
  423. const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
  424. //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
  425. // reallocate
  426. buf_size = buf_size_new;
  427. buf = realloc(buf, buf_size);
  428. if (buf == nullptr) {
  429. fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
  430. return false;
  431. }
  432. }
  433. struct ggml_init_params params = {
  434. .mem_size = buf_size,
  435. .mem_buffer = buf,
  436. };
  437. struct ggml_context * ctx0 = ggml_init(params);
  438. struct ggml_cgraph gf = { .n_threads = n_threads };
  439. struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
  440. memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
  441. struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
  442. for (int il = 0; il < n_layer; ++il) {
  443. struct ggml_tensor * inpSA = inpL;
  444. struct ggml_tensor * cur;
  445. // norm
  446. {
  447. cur = ggml_norm(ctx0, inpL);
  448. // cur = attention_norm*cur
  449. cur = ggml_mul(ctx0,
  450. ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
  451. cur);
  452. }
  453. // self-attention
  454. {
  455. struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
  456. struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
  457. struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
  458. // store key and value to memory
  459. if (N >= 1) {
  460. struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
  461. struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
  462. ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
  463. ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
  464. }
  465. // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
  466. struct ggml_tensor * Q =
  467. ggml_permute(ctx0,
  468. ggml_rope(ctx0,
  469. ggml_cpy(ctx0,
  470. Qcur,
  471. ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
  472. n_past, n_rot, 0),
  473. 0, 2, 1, 3);
  474. // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
  475. struct ggml_tensor * K =
  476. ggml_permute(ctx0,
  477. ggml_rope(ctx0,
  478. ggml_reshape_3d(ctx0,
  479. ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
  480. n_embd/n_head, n_head, n_past + N),
  481. n_past, n_rot, 1),
  482. 0, 2, 1, 3);
  483. // K * Q
  484. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
  485. // KQ_scaled = KQ / sqrt(n_embd/n_head)
  486. struct ggml_tensor * KQ_scaled =
  487. ggml_scale(ctx0,
  488. KQ,
  489. ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
  490. );
  491. // KQ_masked = mask_past(KQ_scaled)
  492. struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
  493. // KQ = soft_max(KQ_masked)
  494. struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
  495. // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
  496. struct ggml_tensor * V_trans =
  497. ggml_permute(ctx0,
  498. ggml_reshape_3d(ctx0,
  499. ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
  500. n_embd/n_head, n_head, n_past + N),
  501. 1, 2, 0, 3);
  502. // KQV = transpose(V) * KQ_soft_max
  503. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
  504. // KQV_merged = KQV.permute(0, 2, 1, 3)
  505. struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  506. // cur = KQV_merged.contiguous().view(n_embd, N)
  507. cur = ggml_cpy(ctx0,
  508. KQV_merged,
  509. ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
  510. // projection (no bias)
  511. cur = ggml_mul_mat(ctx0,
  512. model.layers[il].wo,
  513. cur);
  514. }
  515. struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
  516. // feed-forward network
  517. {
  518. // norm
  519. {
  520. cur = ggml_norm(ctx0, inpFF);
  521. // cur = ffn_norm*cur
  522. cur = ggml_mul(ctx0,
  523. ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
  524. cur);
  525. }
  526. struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
  527. model.layers[il].w3,
  528. cur);
  529. cur = ggml_mul_mat(ctx0,
  530. model.layers[il].w1,
  531. cur);
  532. // SILU activation
  533. cur = ggml_silu(ctx0, cur);
  534. cur = ggml_mul(ctx0, cur, tmp);
  535. cur = ggml_mul_mat(ctx0,
  536. model.layers[il].w2,
  537. cur);
  538. }
  539. cur = ggml_add(ctx0, cur, inpFF);
  540. // input for next layer
  541. inpL = cur;
  542. }
  543. // norm
  544. {
  545. inpL = ggml_norm(ctx0, inpL);
  546. // inpL = norm*inpL
  547. inpL = ggml_mul(ctx0,
  548. ggml_repeat(ctx0, model.norm, inpL),
  549. inpL);
  550. }
  551. // lm_head
  552. {
  553. inpL = ggml_mul_mat(ctx0, model.output, inpL);
  554. }
  555. // logits -> probs
  556. //inpL = ggml_soft_max(ctx0, inpL);
  557. // run the computation
  558. ggml_build_forward_expand(&gf, inpL);
  559. ggml_graph_compute (ctx0, &gf);
  560. //if (n_past%100 == 0) {
  561. // ggml_graph_print (&gf);
  562. // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
  563. //}
  564. //embd_w.resize(n_vocab*N);
  565. //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
  566. // return result for just the last token
  567. embd_w.resize(n_vocab);
  568. memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
  569. if (mem_per_token == 0) {
  570. mem_per_token = ggml_used_mem(ctx0)/N;
  571. }
  572. //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
  573. ggml_free(ctx0);
  574. return true;
  575. }
  576. int main(int argc, char ** argv) {
  577. const int64_t t_main_start_us = ggml_time_us();
  578. gpt_params params;
  579. params.model = "models/llama-7B/ggml-model.bin";
  580. if (gpt_params_parse(argc, argv, params) == false) {
  581. return 1;
  582. }
  583. if (params.seed < 0) {
  584. params.seed = time(NULL);
  585. }
  586. printf("%s: seed = %d\n", __func__, params.seed);
  587. std::mt19937 rng(params.seed);
  588. if (params.prompt.empty()) {
  589. params.prompt = gpt_random_prompt(rng);
  590. }
  591. // params.prompt = R"(// this function checks if the number n is prime
  592. //bool is_prime(int n) {)";
  593. int64_t t_load_us = 0;
  594. gpt_vocab vocab;
  595. llama_model model;
  596. // load the model
  597. {
  598. const int64_t t_start_us = ggml_time_us();
  599. if (!llama_model_load(params.model, model, vocab, 512)) { // TODO: set context from user input ??
  600. fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
  601. return 1;
  602. }
  603. t_load_us = ggml_time_us() - t_start_us;
  604. }
  605. int n_past = 0;
  606. int64_t t_sample_us = 0;
  607. int64_t t_predict_us = 0;
  608. std::vector<float> logits;
  609. // tokenize the prompt
  610. std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
  611. params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
  612. printf("\n");
  613. printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
  614. printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
  615. for (int i = 0; i < (int) embd_inp.size(); i++) {
  616. printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
  617. }
  618. printf("\n");
  619. printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
  620. printf("\n\n");
  621. std::vector<gpt_vocab::id> embd;
  622. // determine the required inference memory per token:
  623. size_t mem_per_token = 0;
  624. llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
  625. int last_n_size = params.repeat_last_n;
  626. std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
  627. std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
  628. for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
  629. // predict
  630. if (embd.size() > 0) {
  631. const int64_t t_start_us = ggml_time_us();
  632. if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
  633. printf("Failed to predict\n");
  634. return 1;
  635. }
  636. t_predict_us += ggml_time_us() - t_start_us;
  637. }
  638. n_past += embd.size();
  639. embd.clear();
  640. if (i >= embd_inp.size()) {
  641. // sample next token
  642. const float top_p = params.top_p;
  643. const float temp = params.temp;
  644. const float repeat_penalty = params.repeat_penalty;
  645. const int n_vocab = model.hparams.n_vocab;
  646. gpt_vocab::id id = 0;
  647. {
  648. const int64_t t_start_sample_us = ggml_time_us();
  649. id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
  650. last_n_tokens.erase(last_n_tokens.begin());
  651. last_n_tokens.push_back(id);
  652. t_sample_us += ggml_time_us() - t_start_sample_us;
  653. }
  654. // add it to the context
  655. embd.push_back(id);
  656. } else {
  657. // if here, it means we are still processing the input prompt
  658. for (int k = i; k < embd_inp.size(); k++) {
  659. embd.push_back(embd_inp[k]);
  660. last_n_tokens.erase(last_n_tokens.begin());
  661. last_n_tokens.push_back(embd_inp[k]);
  662. if (embd.size() > params.n_batch) {
  663. break;
  664. }
  665. }
  666. i += embd.size() - 1;
  667. }
  668. // display text
  669. for (auto id : embd) {
  670. printf("%s", vocab.id_to_token[id].c_str());
  671. }
  672. fflush(stdout);
  673. // end of text token
  674. if (embd.back() == 2) {
  675. printf(" [end of text]\n");
  676. break;
  677. }
  678. }
  679. // report timing
  680. {
  681. const int64_t t_main_end_us = ggml_time_us();
  682. printf("\n\n");
  683. printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
  684. printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
  685. printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
  686. printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
  687. printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
  688. }
  689. ggml_free(model.ctx);
  690. return 0;
  691. }