Răsfoiți Sursa

cuda : add missing F32<->I32 entries in ggml_cuda_cpy_fn (#16060)

Sigbjørn Skjæret 4 luni în urmă
părinte
comite
ad6bd9083b
1 a modificat fișierele cu 4 adăugiri și 0 ștergeri
  1. 4 0
      ggml/src/ggml-cuda/cpy.cu

+ 4 - 0
ggml/src/ggml-cuda/cpy.cu

@@ -441,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));