Просмотр исходного кода

ggml : sync (im2col, GPU conv, 32-bit arm compat) (#4060)

ggml-ci
Georgi Gerganov 2 лет назад
Родитель
Сommit
3d68f364f1
8 измененных файлов с 684 добавлено и 838 удалено
  1. 104 2
      ggml-cuda.cu
  2. 0 6
      ggml-impl.h
  3. 1 1
      ggml-metal.h
  4. 90 16
      ggml-metal.m
  5. 107 1
      ggml-metal.metal
  6. 168 73
      ggml-quants.c
  7. 201 733
      ggml.c
  8. 13 6
      ggml.h

+ 104 - 2
ggml-cuda.cu

@@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
     *dsti = __float2half(*xi);
 }
 }
 
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 }
 
 
+static  __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 }
 
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 }
 
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N,  int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N,  KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 #define MAX_CUDA_BUFFERS 256
 
 
@@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_as = 0;
         size_t dst_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
 
@@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
     (void) src1_dd;
 }
 }
 
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 }
 
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src0;
     (void) src1;
     (void) src1;
@@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         return false;
         return false;
     }
     }
 
 
+    if (tensor->op == GGML_OP_MUL_MAT) {
+        if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+#endif
+            return false;
+        }
+    }
+
     switch (tensor->op) {
     switch (tensor->op) {
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT:
             func = ggml_cuda_repeat;
             func = ggml_cuda_repeat;
@@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             func = ggml_cuda_alibi;
             break;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
         default:
             return false;
             return false;
     }
     }

+ 0 - 6
ggml-impl.h

@@ -39,12 +39,6 @@ extern "C" {
 #endif
 #endif
 #endif
 #endif
 
 
-#undef MIN
-#undef MAX
-
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
 // 16-bit float
 // 16-bit float
 // on Arm, we use __fp16
 // on Arm, we use __fp16
 // on x86, we use uint16_t
 // on x86, we use uint16_t

+ 1 - 1
ggml-metal.h

@@ -26,7 +26,7 @@
 #include <stdbool.h>
 #include <stdbool.h>
 
 
 // max memory buffers that can be mapped to the device
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 
 struct ggml_tensor;
 struct ggml_tensor;

+ 90 - 16
ggml-metal.m

@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -126,7 +128,7 @@ struct ggml_metal_context {
 // MSL code
 // MSL code
 // TODO: move the contents here when ready
 // TODO: move the contents here when ready
 //       for now it is easier to work in a separate file
 //       for now it is easier to work in a separate file
-static NSString * const msl_library_source = @"see metal.metal";
+//static NSString * const msl_library_source = @"see metal.metal";
 
 
 // Here to assist with NSBundle Path Hack
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
 @interface GGMLMetalClass : NSObject
@@ -142,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
     ggml_metal_log_user_data = user_data;
     ggml_metal_log_user_data = user_data;
 }
 }
 
 
-static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     if (ggml_metal_log_callback != NULL) {
     if (ggml_metal_log_callback != NULL) {
         va_list args;
         va_list args;
         va_start(args, format);
         va_start(args, format);
@@ -210,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         } else {
         } else {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
 
-            NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            NSString * sourcePath;
+            NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+            if (ggmlMetalPathResources) {
+                sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+            } else {
+                sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            }
             if (sourcePath == nil) {
             if (sourcePath == nil) {
                 GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
                 GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
                 sourcePath = @"ggml-metal.metal";
                 sourcePath = @"ggml-metal.metal";
@@ -281,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -311,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -329,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
     for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
         if ([ctx->device supportsFamily:i]) {
         if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
             break;
             break;
         }
         }
     }
     }
@@ -380,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -410,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -467,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
 
     const int64_t tsize = ggml_nbytes(t);
     const int64_t tsize = ggml_nbytes(t);
 
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -567,7 +584,7 @@ bool ggml_metal_add_buffer(
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
+            GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
         } else {
         } else {
             GGML_METAL_LOG_INFO("\n");
             GGML_METAL_LOG_INFO("\n");
         }
         }
@@ -1024,7 +1041,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                         } break;
@@ -1133,6 +1150,7 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                     case GGML_TYPE_F32:
                                         {
                                         {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             nrows = 4;
                                             nrows = 4;
                                         } break;
                                         } break;
@@ -1140,13 +1158,18 @@ void ggml_metal_graph_compute(
                                         {
                                         {
                                             nth0 = 32;
                                             nth0 = 32;
                                             nth1 = 1;
                                             nth1 = 1;
-                                            if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                                nrows = ne11;
+                                            if (src1t == GGML_TYPE_F32) {
+                                                if (ne11 * ne12 < 4) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                                    nrows = ne11;
+                                                } else {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                    nrows = 4;
+                                                }
                                             } else {
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                                 nrows = 4;
                                                 nrows = 4;
                                             }
                                             }
                                         } break;
                                         } break;
@@ -1336,7 +1359,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
 
                             const int64_t nrows = ggml_nrows(src0);
                             const int64_t nrows = ggml_nrows(src0);
 
 
@@ -1355,7 +1378,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
                             [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
 
 
                             const int64_t nrows = ggml_nrows(src0);
                             const int64_t nrows = ggml_nrows(src0);
 
 
@@ -1410,8 +1433,7 @@ void ggml_metal_graph_compute(
                             const int n_past     = ((int32_t *) dst->op_params)[0];
                             const int n_past     = ((int32_t *) dst->op_params)[0];
                             const int n_dims     = ((int32_t *) dst->op_params)[1];
                             const int n_dims     = ((int32_t *) dst->op_params)[1];
                             const int mode       = ((int32_t *) dst->op_params)[2];
                             const int mode       = ((int32_t *) dst->op_params)[2];
-                            // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                            const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
 
 
                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                             memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
                             memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
@@ -1459,6 +1481,58 @@ void ggml_metal_graph_compute(
 
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                         } break;
+                    case GGML_OP_IM2COL:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                            GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                            const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                            const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                            const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                            const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                            const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                            const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                            const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                            const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                            const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                            const int32_t IH = is_2D ? src1->ne[1] : 1;
+                            const int32_t IW =         src1->ne[0];
+
+                            const int32_t KH = is_2D ? src0->ne[1] : 1;
+                            const int32_t KW =         src0->ne[0];
+
+                            const int32_t OH = is_2D ? dst->ne[2] : 1;
+                            const int32_t OW =         dst->ne[1];
+
+                            const int32_t CHW = IC * KH * KW;
+
+                            const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                            const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
+                            [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
+                            [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
+                            [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
+                            [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
+                            [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
+                            [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
+                            [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
+                            [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
+                            [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
                     case GGML_OP_CONT:

+ 107 - 1
ggml-metal.metal

@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
         constant   int64_t & ne0,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
 
     const int64_t r0 = tgpig.x;
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
     }
     }
 }
 }
 
 
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
 kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src0,
         device const  char * src1,
         device const  char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
 
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device       half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device const half * src0,
         device       half * dst,
         device       half * dst,

+ 168 - 73
ggml-quants.c

@@ -14,26 +14,6 @@
 //
 //
 #include <arm_neon.h>
 #include <arm_neon.h>
 
 
-#if !defined(__aarch64__)
-inline static int32_t vaddvq_s16(int16x8_t v) {
-    return
-        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-    return vcombine_s16(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-#endif
-
 #else
 #else
 
 
 #ifdef __wasm_simd128__
 #ifdef __wasm_simd128__
@@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #include <intrin.h>
 #else
 #else
-#if !defined(__riscv) && !defined(__s390__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
 #include <immintrin.h>
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
 #endif
 #endif
 #endif
 #endif
 #endif
 #endif
+#endif
 
 
 #ifdef __riscv_v_intrinsic
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
 #include <riscv_vector.h>
@@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 
 
 #undef MIN
 #undef MIN
 #undef MAX
 #undef MAX
+
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
 
@@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 
 #if defined(__ARM_NEON)
 #if defined(__ARM_NEON)
-
 #if !defined(__aarch64__)
 #if !defined(__aarch64__)
 
 
+// 64-bit compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
 }
@@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
     return res;
 }
 }
 
 
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+
 #endif
 #endif
 #endif
 #endif
 
 
@@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t  vzero = vdupq_n_s32(0);
     const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 #endif
 
 
-    int8x16x2_t q2bytes;
+    ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
     uint8_t aux[16];
 
 
     float sum = 0;
     float sum = 0;
@@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         vst1q_u8(aux, scales);
         vst1q_u8(aux, scales);
 
 
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
-        const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 #endif
 
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
-        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         MULTIPLY_ACCUM_WITH_SCALE((index));
         MULTIPLY_ACCUM_WITH_SCALE((index));
@@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         for (int j = 0; j < QK_K/128; ++j) {
         for (int j = 0; j < QK_K/128; ++j) {
 
 
-            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
 
 
-            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
             MULTIPLY_ACCUM_WITH_SCALE(0);
             MULTIPLY_ACCUM_WITH_SCALE(0);
@@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t  vzero = vdupq_n_s32(0);
     const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 #endif
 
 
-    int8x16x4_t q2bytes;
+    ggml_int8x16x4_t q2bytes;
 
 
     uint32_t aux32[2];
     uint32_t aux32[2];
     const uint8_t * scales = (const uint8_t *)aux32;
     const uint8_t * scales = (const uint8_t *)aux32;
@@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         const uint8x16_t q2bits = vld1q_u8(q2);
         const uint8x16_t q2bits = vld1q_u8(q2);
 
 
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
 
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3 = vshlq_n_u8(m0, 3);
     const uint8x16_t m3 = vshlq_n_u8(m0, 3);
     const int8_t m32 = 32;
     const int8_t m32 = 32;
 
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
 
     float sum = 0;
     float sum = 0;
 
 
@@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].hmask;
         const uint8_t * restrict qh = x[i].hmask;
         const int8_t  * restrict q8 = y[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
 
         int32_t isum = 0;
         int32_t isum = 0;
 
 
@@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         for (int j = 0; j < QK_K/128; ++j) {
         for (int j = 0; j < QK_K/128; ++j) {
 
 
-            const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
-            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
-            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
 
 
             q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
             q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
             q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
             q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh  = vdupq_n_u8(4);
     const uint8x16_t mh  = vdupq_n_u8(4);
 
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
 
     uint16_t aux16[2];
     uint16_t aux16[2];
     int8_t * scales = (int8_t *)aux16;
     int8_t * scales = (int8_t *)aux16;
@@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     for (int i = 0; i < nb; ++i) {
     for (int i = 0; i < nb; ++i) {
 
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
 
         const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
         const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
         const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
         const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
 
 
         const uint16_t a = *(const uint16_t *)x[i].scales;
         const uint16_t a = *(const uint16_t *)x[i].scales;
         aux16[0] = a & 0x0f0f;
         aux16[0] = a & 0x0f0f;
@@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 #endif
 
 
-    int8x16x2_t q4bytes;
-    int8x16x2_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
 
 
     float sumf = 0;
     float sumf = 0;
 
 
@@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         for (int j = 0; j < QK_K/64; ++j) {
         for (int j = 0; j < QK_K/64; ++j) {
 
 
-            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
 
 
 #ifdef __ARM_FEATURE_DOTPROD
 #ifdef __ARM_FEATURE_DOTPROD
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
 
             const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
             const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
             sumi1 += vaddvq_s32(p1) * scales[2*j+0];
             sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
 
@@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
 
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
 #else
 #else
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
             const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
             const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                            vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
                                            vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
             sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
             sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
 
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
             const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
             const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     float sumf = 0;
     float sumf = 0;
 
 
-    int8x16x2_t q4bytes;
-    int8x16x4_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x4_t q8bytes;
 
 
     float sum_mins = 0.f;
     float sum_mins = 0.f;
 
 
@@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         const float d = y[i].d * (float)x[i].d[0];
         const float d = y[i].d * (float)x[i].d[0];
 
 
-        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+        const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
 
 
 #ifdef __ARM_FEATURE_DOTPROD
 #ifdef __ARM_FEATURE_DOTPROD
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
 
@@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
 
 #else
 #else
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 #endif
 
 
-    int8x16x4_t q5bytes;
+    ggml_int8x16x4_t q5bytes;
 
 
     float sumf = 0;
     float sumf = 0;
 
 
@@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].qh;
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
 
-        uint8x16x4_t q5h;
+        ggml_uint8x16x4_t q5h;
 
 
         int32_t sumi = 0;
         int32_t sumi = 0;
 
 
         for (int j = 0; j < QK_K/64; ++j) {
         for (int j = 0; j < QK_K/64; ++j) {
 
 
-            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
-            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
 
             q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
             q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 #endif
 
 
-    int8x16x4_t q5bytes;
-    uint8x16x4_t q5h;
+    ggml_int8x16x4_t q5bytes;
+    ggml_uint8x16x4_t q5h;
 
 
     float sumf = 0;
     float sumf = 0;
 
 
@@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         const uint8x8_t qhbits = vld1_u8(qh);
         const uint8x8_t qhbits = vld1_u8(qh);
 
 
-        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
 
         const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
         const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
         q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
         q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     const uint8x16_t mone = vdupq_n_u8(3);
     const uint8x16_t mone = vdupq_n_u8(3);
 
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
 
     for (int i = 0; i < nb; ++i) {
     for (int i = 0; i < nb; ++i) {
 
 
@@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         const int8_t * restrict scale = x[i].scales;
         const int8_t * restrict scale = x[i].scales;
 
 
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
         const int8x16_t scales = vld1q_s8(scale);
-        const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
 
 
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         for (int j = 0; j < QK_K/128; ++j) {
         for (int j = 0; j < QK_K/128; ++j) {
 
 
-            uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
-            uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
-            int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
 
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
             q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             scale += 2;
             scale += 2;
 #endif
 #endif
 
 
-            q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
 
             shifted = vshrq_n_u8(qhbits.val[0], 4);
             shifted = vshrq_n_u8(qhbits.val[0], 4);
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     const uint8x16_t mone = vdupq_n_u8(3);
     const uint8x16_t mone = vdupq_n_u8(3);
 
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
 
     for (int i = 0; i < nb; ++i) {
     for (int i = 0; i < nb; ++i) {
 
 
@@ -7002,9 +7097,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 
         int32_t isum = 0;
         int32_t isum = 0;
 
 
-        uint8x16_t   qhbits = vld1q_u8(qh);
-        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
-        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        uint8x16_t qhbits = vld1q_u8(qh);
+        ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+        ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
 
         q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
         q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
         uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
         uint8x16_t shifted = vshrq_n_u8(qhbits, 2);

Разница между файлами не показана из-за своего большого размера
+ 201 - 733
ggml.c


+ 13 - 6
ggml.h

@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0, // internal
-        GGML_OP_CONV_2D_STAGE_1, // internal
+        GGML_OP_IM2COL,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D,
@@ -1403,6 +1398,18 @@ extern "C" {
             float                 min,
             float                 min,
             float                 max);
             float                 max);
 
 
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1,
+            bool                 is_2D);
+
     GGML_API struct ggml_tensor * ggml_conv_1d(
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,

Некоторые файлы не были показаны из-за большого количества измененных файлов