|
|
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|
|
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
|
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
break;
|
|
|
+ case GGML_TYPE_I32:
|
|
|
+ get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
|
|
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
case GGML_TYPE_BF16:
|
|
|
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
|
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
@@ -210,6 +214,10 @@ void get_rows_cuda(
|
|
|
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
|
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
break;
|
|
|
+ case GGML_TYPE_I32:
|
|
|
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
|
|
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
+ break;
|
|
|
case GGML_TYPE_F16:
|
|
|
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
|
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|