@@ -372,6 +372,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_SET_I32,
+    GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -940,6 +942,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@@ -1159,6 +1163,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                         return false;
                 };
             }
+        case GGML_OP_SET:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_I32:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
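
For reference, GGML_OP_SET nodes are produced by ggml_set()/ggml_set_inplace() in ggml.c, which pack the destination view strides, the byte offset and the inplace flag into op_params; the encode path in the next hunk reads those five values back out, so the check above only needs to look at op->src[0]->type. Below is a minimal sketch of such a call, assuming the ggml_set() signature from ggml.h; the function name, tensor shapes and buffer size are illustrative only and are not part of this patch.

#include "ggml.h"

// Sketch only (not part of this patch): build a graph node that overwrites
// rows 3..4 of `a` with `b`. The nb1/nb2/nb3/offset arguments end up in
// op_params[0..3] and the inplace flag in op_params[4], which is exactly
// what the Metal encode path below consumes.
static void example_ggml_set(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8); // destination
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 2); // rows to write

    // nb1/nb2/nb3 are the byte strides of the destination view,
    // 3*a->nb[1] is the byte offset of the first overwritten row
    struct ggml_tensor * r = ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 3*a->nb[1]);

    (void) r; // r is `a` with rows 3..4 replaced by `b` once the graph is computed
    ggml_free(ctx);
}
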
@@ -3824,6 +3838,68 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case GGML_OP_SET:
+            {
+                GGML_ASSERT(ggml_are_same_shape(src0, dst));
+                GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+                // src0 and dst as viewed during set
+                const size_t dst_nb0 = ggml_element_size(src0);
+
+                const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
+                const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
+                const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
+                const size_t offset  = ((int32_t *) dst->op_params)[3];
+                const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
+                }
+
+                const int im0 = (ne10 == 0 ? 0 : ne10-1);
+                const int im1 = (ne11 == 0 ? 0 : ne11-1);
+                const int im2 = (ne12 == 0 ? 0 : ne12-1);
+                const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+                GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32:
+                        GGML_ASSERT(nb10 == sizeof(float));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
+                    case GGML_TYPE_I32:
+                        GGML_ASSERT(nb10 == sizeof(int32_t));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_set args = {
+                    /*.ne10    =*/ ne10,
+                    /*.ne11    =*/ ne11,
+                    /*.ne12    =*/ ne12,
+                    /*.nb10    =*/ nb10,
+                    /*.nb11    =*/ nb11,
+                    /*.nb12    =*/ nb12,
+                    /*.nb13    =*/ nb13,
+                    /*.nb1     =*/ dst_nb1,
+                    /*.nb2     =*/ dst_nb2,
+                    /*.nb3     =*/ dst_nb3,
+                    /*.offs    =*/ offset,
+                    /*.inplace =*/ inplace,
+                };
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_POOL_2D:
             {
                 GGML_ASSERT(ggml_is_contiguous(src0));
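
The ggml_metal_kargs_set structure passed via setBytes at index 0 is defined outside this excerpt (in ggml-metal-impl.h, shared with the shader source). A layout consistent with the initializer above might look like the sketch below: the field order and names are taken from the /*.field =*/ comments in the hunk, while the integer widths are an assumption and would have to match the Metal-side definition byte for byte.

#include <stdbool.h>
#include <stdint.h>

// Assumed host/shader-shared argument block for the new set kernels.
// One threadgroup is launched per (i11, i12, i13) row of src1, with up to
// `nth` threads sweeping the ne10 elements of that row.
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;     // dst_nb1: destination view strides taken from op_params
    uint64_t nb2;     // dst_nb2
    uint64_t nb3;     // dst_nb3
    uint64_t offs;    // byte offset of the view inside dst
    bool     inplace; // if set, dst aliases src0 and the host-side memcpy is skipped
} ggml_metal_kargs_set;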