ggml-aarch64.c 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #define GGML_COMMON_DECL_C
  2. #include "ggml-common.h"
  3. #include "ggml-aarch64.h"
  4. #include "ggml-impl.h"
  5. #include "ggml-quants.h"
  6. #include <assert.h>
  7. #define UNUSED GGML_UNUSED
  8. static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
  9. block_q4_0x4 out;
  10. for (int i = 0; i < 4; i++) {
  11. out.d[i] = in[i].d;
  12. }
  13. const int end = QK4_0 * 2 / blck_size_interleave;
  14. if (blck_size_interleave == 8) {
  15. const uint64_t xor_mask = 0x8888888888888888ULL;
  16. for (int i = 0; i < end; ++i) {
  17. int src_id = i % 4;
  18. int src_offset = (i / 4) * blck_size_interleave;
  19. int dst_offset = i * blck_size_interleave;
  20. uint64_t elems;
  21. // Using memcpy to avoid unaligned memory accesses
  22. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  23. elems ^= xor_mask;
  24. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  25. }
  26. } else if (blck_size_interleave == 4) {
  27. const uint32_t xor_mask = 0x88888888;
  28. for (int i = 0; i < end; ++i) {
  29. int src_id = i % 4;
  30. int src_offset = (i / 4) * blck_size_interleave;
  31. int dst_offset = i * blck_size_interleave;
  32. uint32_t elems;
  33. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
  34. elems ^= xor_mask;
  35. memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
  36. }
  37. } else {
  38. GGML_ASSERT(false);
  39. }
  40. return out;
  41. }
  42. // interleave 8 block_q4_0s in blocks of blck_size_interleave
  43. // returns an interleaved block_q4_0x8
  44. // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
  45. // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
  46. static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
  47. block_q4_0x8 out;
  48. for (int i = 0; i < 8; i++) {
  49. out.d[i] = in[i].d;
  50. }
  51. const int end = QK4_0 * 4 / blck_size_interleave;
  52. const uint64_t xor_mask = 0x8888888888888888ULL;
  53. for (int i = 0; i < end; ++i) {
  54. int src_id = i % 8;
  55. int src_offset = (i / 8) * blck_size_interleave;
  56. int dst_offset = i * blck_size_interleave;
  57. uint64_t elems;
  58. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  59. elems ^= xor_mask;
  60. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  61. }
  62. return out;
  63. }
  64. static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
  65. assert(n_per_row % QK4_0 == 0);
  66. const int nb = n_per_row / QK4_0;
  67. void * out_ptr = NULL;
  68. if (nrows_interleaved == 8) {
  69. out_ptr = (block_q4_0x8 *) dst;
  70. }
  71. else if (nrows_interleaved == 4) {
  72. out_ptr = (block_q4_0x4 *) dst;
  73. }
  74. assert(nrows_interleaved <= 8);
  75. block_q4_0 dst_tmp[8];
  76. for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
  77. for (int64_t x = 0; x < nb; x++) {
  78. for (int i = 0; i < nrows_interleaved; i++ ) {
  79. quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
  80. }
  81. if (nrows_interleaved == 8) {
  82. *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave);
  83. out_ptr = (block_q4_0x8 *) out_ptr + 1;
  84. }
  85. else if (nrows_interleaved == 4) {
  86. *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave);
  87. out_ptr = (block_q4_0x4 *) out_ptr + 1;
  88. }
  89. }
  90. }
  91. return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
  92. }
  // Quantize to q4_0 with 4 rows packed per output block and an interleave
  // chunk size of 4. quant_weights is accepted but not used.
  // Returns whatever quantize_q4_0_nr_bl returns for these settings.
  size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
      const int rows_per_block = 4;
      const int chunk_size     = 4;

      UNUSED(quant_weights);

      return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, rows_per_block, chunk_size);
  }
  // Quantize to q4_0 with 4 rows packed per output block and an interleave
  // chunk size of 8. quant_weights is accepted but not used.
  // Returns whatever quantize_q4_0_nr_bl returns for these settings.
  size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
      const int rows_per_block = 4;
      const int chunk_size     = 8;

      UNUSED(quant_weights);

      return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, rows_per_block, chunk_size);
  }
  // Quantize to q4_0 with 8 rows packed per output block and an interleave
  // chunk size of 8. quant_weights is accepted but not used.
  // Returns whatever quantize_q4_0_nr_bl returns for these settings.
  size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
      const int rows_per_block = 8;
      const int chunk_size     = 8;

      UNUSED(quant_weights);

      return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, rows_per_block, chunk_size);
  }