Bläddra i källkod

ggml : sync (im2col, GPU conv, 32-bit arm compat) (#4060)

ggml-ci
Georgi Gerganov 2 år sedan
förälder
incheckning
3d68f364f1
8 ändrade filer med 684 tillägg och 838 borttagningar
  1. 104 2
      ggml-cuda.cu
  2. 0 6
      ggml-impl.h
  3. 1 1
      ggml-metal.h
  4. 90 16
      ggml-metal.m
  5. 107 1
      ggml-metal.metal
  6. 168 73
      ggml-quants.c
  7. 201 733
      ggml.c
  8. 13 6
      ggml.h

+ 104 - 2
ggml-cuda.cu

@@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static  __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N,  int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N,  KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
@@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         return false;
     }
 
+    if (tensor->op == GGML_OP_MUL_MAT) {
+        if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+#endif
+            return false;
+        }
+    }
+
     switch (tensor->op) {
         case GGML_OP_REPEAT:
             func = ggml_cuda_repeat;
@@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }

+ 0 - 6
ggml-impl.h

@@ -39,12 +39,6 @@ extern "C" {
 #endif
 #endif
 
-#undef MIN
-#undef MAX
-
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t

+ 1 - 1
ggml-metal.h

@@ -26,7 +26,7 @@
 #include <stdbool.h>
 
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;

+ 90 - 16
ggml-metal.m

@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -126,7 +128,7 @@ struct ggml_metal_context {
 // MSL code
 // TODO: move the contents here when ready
 //       for now it is easier to work in a separate file
-static NSString * const msl_library_source = @"see metal.metal";
+//static NSString * const msl_library_source = @"see metal.metal";
 
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
@@ -142,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
     ggml_metal_log_user_data = user_data;
 }
 
-static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     if (ggml_metal_log_callback != NULL) {
         va_list args;
         va_start(args, format);
@@ -210,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         } else {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
-            NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            NSString * sourcePath;
+            NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+            if (ggmlMetalPathResources) {
+                sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+            } else {
+                sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            }
             if (sourcePath == nil) {
                 GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
                 sourcePath = @"ggml-metal.metal";
@@ -281,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -311,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -329,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
         if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
             break;
         }
     }
@@ -380,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -410,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -467,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -567,7 +584,7 @@ bool ggml_metal_add_buffer(
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
+            GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
         } else {
             GGML_METAL_LOG_INFO("\n");
         }
@@ -1024,7 +1041,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1133,6 +1150,7 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                         {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             nrows = 4;
                                         } break;
@@ -1140,13 +1158,18 @@ void ggml_metal_graph_compute(
                                         {
                                             nth0 = 32;
                                             nth1 = 1;
-                                            if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                                nrows = ne11;
+                                            if (src1t == GGML_TYPE_F32) {
+                                                if (ne11 * ne12 < 4) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                                    nrows = ne11;
+                                                } else {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                    nrows = 4;
+                                                }
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                                 nrows = 4;
                                             }
                                         } break;
@@ -1336,7 +1359,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1355,7 +1378,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1410,8 +1433,7 @@ void ggml_metal_graph_compute(
                             const int n_past     = ((int32_t *) dst->op_params)[0];
                             const int n_dims     = ((int32_t *) dst->op_params)[1];
                             const int mode       = ((int32_t *) dst->op_params)[2];
-                            // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                            const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
 
                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                             memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
@@ -1459,6 +1481,58 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_IM2COL:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                            GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                            const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                            const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                            const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                            const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                            const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                            const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                            const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                            const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                            const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                            const int32_t IH = is_2D ? src1->ne[1] : 1;
+                            const int32_t IW =         src1->ne[0];
+
+                            const int32_t KH = is_2D ? src0->ne[1] : 1;
+                            const int32_t KW =         src0->ne[0];
+
+                            const int32_t OH = is_2D ? dst->ne[2] : 1;
+                            const int32_t OW =         dst->ne[1];
+
+                            const int32_t CHW = IC * KH * KW;
+
+                            const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                            const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
+                            [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
+                            [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
+                            [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
+                            [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
+                            [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
+                            [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
+                            [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
+                            [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
+                            [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:

+ 107 - 1
ggml-metal.metal

@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device       half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,

+ 168 - 73
ggml-quants.c

@@ -14,26 +14,6 @@
 //
 #include <arm_neon.h>
 
-#if !defined(__aarch64__)
-inline static int32_t vaddvq_s16(int16x8_t v) {
-    return
-        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-    return vcombine_s16(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-#endif
-
 #else
 
 #ifdef __wasm_simd128__
@@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
-#if !defined(__riscv) && !defined(__s390__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
 #endif
+#endif
 
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
@@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 
 #undef MIN
 #undef MAX
+
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
@@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 #if defined(__ARM_NEON)
-
 #if !defined(__aarch64__)
 
+// 64-bit compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
@@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
 }
 
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+
 #endif
 #endif
 
@@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q2bytes;
+    ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
 
     float sum = 0;
@@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         vst1q_u8(aux, scales);
 
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
-        const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
-        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         MULTIPLY_ACCUM_WITH_SCALE((index));
@@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
 
-            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
             MULTIPLY_ACCUM_WITH_SCALE(0);
@@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q2bytes;
+    ggml_int8x16x4_t q2bytes;
 
     uint32_t aux32[2];
     const uint8_t * scales = (const uint8_t *)aux32;
@@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x16_t q2bits = vld1q_u8(q2);
 
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3 = vshlq_n_u8(m0, 3);
     const int8_t m32 = 32;
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
     float sum = 0;
 
@@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].hmask;
         const int8_t  * restrict q8 = y[i].qs;
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
         int32_t isum = 0;
 
@@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
-            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
-            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
             q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh  = vdupq_n_u8(4);
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
     uint16_t aux16[2];
     int8_t * scales = (int8_t *)aux16;
@@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     for (int i = 0; i < nb; ++i) {
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
         const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
         const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
 
         const uint16_t a = *(const uint16_t *)x[i].scales;
         aux16[0] = a & 0x0f0f;
@@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q4bytes;
-    int8x16x2_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
 
     float sumf = 0;
 
@@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
 
 #ifdef __ARM_FEATURE_DOTPROD
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
             const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
             sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
@@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
 #else
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
             const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                            vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
             sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
             const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     float sumf = 0;
 
-    int8x16x2_t q4bytes;
-    int8x16x4_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x4_t q8bytes;
 
     float sum_mins = 0.f;
 
@@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         const float d = y[i].d * (float)x[i].d[0];
 
-        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+        const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
 
 #ifdef __ARM_FEATURE_DOTPROD
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
@@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
 #else
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q5bytes;
+    ggml_int8x16x4_t q5bytes;
 
     float sumf = 0;
 
@@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
-        uint8x16x4_t q5h;
+        ggml_uint8x16x4_t q5h;
 
         int32_t sumi = 0;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
-            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q5bytes;
-    uint8x16x4_t q5h;
+    ggml_int8x16x4_t q5bytes;
+    ggml_uint8x16x4_t q5h;
 
     float sumf = 0;
 
@@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x8_t qhbits = vld1_u8(qh);
 
-        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
         q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
     for (int i = 0; i < nb; ++i) {
 
@@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         const int8_t * restrict scale = x[i].scales;
 
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
-        const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
 
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
-            uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
-            int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             scale += 2;
 #endif
 
-            q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             shifted = vshrq_n_u8(qhbits.val[0], 4);
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
     for (int i = 0; i < nb; ++i) {
 
@@ -7002,9 +7097,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         int32_t isum = 0;
 
-        uint8x16_t   qhbits = vld1q_u8(qh);
-        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
-        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        uint8x16_t qhbits = vld1q_u8(qh);
+        ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+        ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
         uint8x16_t shifted = vshrq_n_u8(qhbits, 2);

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 201 - 733
ggml.c


+ 13 - 6
ggml.h

@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0, // internal
-        GGML_OP_CONV_2D_STAGE_1, // internal
+        GGML_OP_IM2COL,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -1403,6 +1398,18 @@ extern "C" {
             float                 min,
             float                 max);
 
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1,
+            bool                 is_2D);
+
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

Vissa filer visades inte eftersom för många filer har ändrats