|
|
@@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal
|
|
|
*dst_h = __float2half(*src_f);
|
|
|
}
|
|
|
|
|
|
// Specialization of set_rows_1 for float -> bfloat16: converts the single
// value read from src_f and stores the result at dst_b.
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
    // Use the explicit conversion intrinsic (round-to-nearest-even), in line
    // with the __float2half call used by the float -> half specialization,
    // rather than depending on nv_bfloat16's implicit converting constructor.
    *dst_b = __float2bfloat16(*src_f);
}
|
|
|
+
|
|
|
template<>
|
|
|
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
|
|
|
*dst_f = *src_f;
|
|
|
@@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
nb1, nb2, nb3,
|
|
|
stream
|
|
|
);
|
|
|
+ } else if (dst->type == GGML_TYPE_BF16) {
|
|
|
+ set_rows_cuda(
|
|
|
+ src0_d, src1_d, (nv_bfloat16*)dst->data,
|
|
|
+ ne00, ne01, ne02, ne03,
|
|
|
+ ne10, ne11, ne12, ne13,
|
|
|
+ nb01, nb02, nb03,
|
|
|
+ nb10, nb11, nb12,
|
|
|
+ nb1, nb2, nb3,
|
|
|
+ stream
|
|
|
+ );
|
|
|
} else {
|
|
|
GGML_ABORT("unsupported type");
|
|
|
}
|