test-rope.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. #include "ggml.h"
  2. #include <cmath>
  3. #include <cstdio>
  4. #include <cstdlib>
  5. #include <cassert>
  6. #include <vector>
  // Suppress warnings about implicit numeric conversions:
  // MSVC C4244/C4267 ("possible loss of data") and GCC/Clang -Wdouble-promotion
  // (float values promoted to double).
  #ifdef _MSC_VER
  #pragma warning(disable: 4244 4267) // possible loss of data
  #endif
  #ifdef __GNUC__
  #pragma GCC diagnostic ignored "-Wdouble-promotion"
  #endif

  // Neither of these is referenced by the code in this file.
  // NOTE(review): they look like leftovers shared with other ggml tests --
  // confirm nothing depends on them before removing.
  #define MAX_NARGS 3
  #define GGML_SILU_FP16

  // Local MIN/MAX helpers; any earlier definitions are dropped first.
  // These are plain macros, so an argument may be evaluated twice.
  #undef MIN
  #undef MAX
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))

  //
  // logging
  //
  // GGML_PRINT always forwards to printf. The GGML_PRINT_DEBUG* variants are
  // enabled by the compile-time level GGML_DEBUG (>= 1, >= 5 and >= 10
  // respectively) and otherwise expand to nothing. An undefined GGML_DEBUG
  // evaluates to 0 inside #if, which disables all three.
  #define GGML_PRINT(...) printf(__VA_ARGS__)

  #if GGML_DEBUG >= 1
  #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
  #else
  #define GGML_PRINT_DEBUG(...)
  #endif

  #if GGML_DEBUG >= 5
  #define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
  #else
  #define GGML_PRINT_DEBUG_5(...)
  #endif

  #if GGML_DEBUG >= 10
  #define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
  #else
  #define GGML_PRINT_DEBUG_10(...)
  #endif
  // Returns a pseudo-random float in [0, 1] (both ends reachable after float
  // rounding), taken from the C library rand() stream, so the value sequence
  // is determined by the current srand() seed.
  static float frand(void) {
      const float numerator = (float) rand();
      return numerator / (float) RAND_MAX;
  }
  // Returns rand() % n, i.e. a pseudo-random int in [0, n) for n > 0.
  // For n == 0 it returns 0 without drawing from the rand() stream.
  static int irand(int n) {
      return n != 0 ? rand() % n : 0;
  }
  45. static void get_random_dims(int64_t * dims, int ndims) {
  46. dims[0] = dims[1] = dims[2] = dims[3] = 1;
  47. for (int i = 0; i < ndims; i++) {
  48. dims[i] = 1 + irand(4);
  49. }
  50. }
  51. static struct ggml_tensor * get_random_tensor_f32(
  52. struct ggml_context * ctx0,
  53. int ndims,
  54. const int64_t ne[],
  55. float fmin,
  56. float fmax) {
  57. struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
  58. switch (ndims) {
  59. case 1:
  60. for (int i0 = 0; i0 < ne[0]; i0++) {
  61. ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
  62. }
  63. break;
  64. case 2:
  65. for (int i1 = 0; i1 < ne[1]; i1++) {
  66. for (int i0 = 0; i0 < ne[0]; i0++) {
  67. ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
  68. }
  69. }
  70. break;
  71. case 3:
  72. for (int i2 = 0; i2 < ne[2]; i2++) {
  73. for (int i1 = 0; i1 < ne[1]; i1++) {
  74. for (int i0 = 0; i0 < ne[0]; i0++) {
  75. ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
  76. }
  77. }
  78. }
  79. break;
  80. case 4:
  81. for (int i3 = 0; i3 < ne[3]; i3++) {
  82. for (int i2 = 0; i2 < ne[2]; i2++) {
  83. for (int i1 = 0; i1 < ne[1]; i1++) {
  84. for (int i0 = 0; i0 < ne[0]; i0++) {
  85. ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
  86. }
  87. }
  88. }
  89. }
  90. break;
  91. default:
  92. assert(false);
  93. };
  94. return result;
  95. }
  96. static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
  97. struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
  98. if (plan.work_size > 0) {
  99. buf.resize(plan.work_size);
  100. plan.work_data = buf.data();
  101. }
  102. ggml_graph_compute(graph, &plan);
  103. }
  104. int main(int /*argc*/, const char ** /*argv*/) {
  105. struct ggml_init_params params = {
  106. /* .mem_size = */ 128*1024*1024,
  107. /* .mem_buffer = */ NULL,
  108. /* .no_alloc = */ false,
  109. };
  110. std::vector<uint8_t> work_buffer;
  111. struct ggml_context * ctx0 = ggml_init(params);
  112. struct ggml_tensor * x;
  113. // rope f32
  114. for (int m = 0; m < 3; ++m) {
  115. const int ndims = 4;
  116. const int64_t n_rot = 128;
  117. const int64_t ne[4] = { 2*n_rot, 32, 73, 1 };
  118. const int n_past_0 = 100;
  119. const int n_past_2 = 33;
  120. struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
  121. struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
  122. struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
  123. for (int i = 0; i < ne[2]; ++i) {
  124. ((int32_t *) p0->data)[i] = n_past_0 + i;
  125. ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
  126. ((int32_t *) p2->data)[i] = n_past_2 + i;
  127. }
  128. // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
  129. const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;
  130. x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
  131. // 100, 101, 102, ..., 172
  132. struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
  133. // -67, -67, -67, ..., -67
  134. struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
  135. // 33, 34, 35, ..., 105
  136. struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
  137. ggml_cgraph * gf = ggml_new_graph(ctx0);
  138. ggml_build_forward_expand(gf, r0);
  139. ggml_build_forward_expand(gf, r1);
  140. ggml_build_forward_expand(gf, r2);
  141. ggml_graph_compute_helper(work_buffer, gf, 4);
  142. // check that r1 and r2 are the same
  143. {
  144. double sum0 = 0.0f;
  145. double sum1 = 0.0f;
  146. double diff = 0.0f;
  147. const float * r1_data = (float *) r1->data;
  148. const float * r2_data = (float *) r2->data;
  149. const int n_elements = ggml_nelements(r1);
  150. for (int i = 0; i < n_elements; ++i) {
  151. sum0 += fabs(r1_data[i]);
  152. sum1 += fabs(r2_data[i]);
  153. diff += fabs(r1_data[i] - r2_data[i]);
  154. //if (fabs(r1_data[i] - r2_data[i]) > 0.0001f) {
  155. // printf("%d: %f %f\n", i, r1_data[i], r2_data[i]);
  156. // printf("diff: %f\n", fabs(r1_data[i] - r2_data[i]));
  157. //}
  158. }
  159. //for (int i = 4096; i < 4096 + 128; ++i) {
  160. // printf("%f %f\n", r1_data[i], r2_data[i]);
  161. //}
  162. printf("mode: %d\n", mode);
  163. printf("sum0: %f\n", sum0);
  164. printf("sum1: %f\n", sum1);
  165. printf("diff: %f\n", diff);
  166. printf("rel err: %f\n", diff / sum0);
  167. printf("rel err: %f\n", diff / sum1);
  168. GGML_ASSERT(diff / sum0 < 0.0001f);
  169. GGML_ASSERT(diff / sum1 < 0.0001f);
  170. }
  171. }
  172. ggml_free(ctx0);
  173. return 0;
  174. }