|
|
@@ -579,13 +579,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
|
|
|
|
|
|
const src_t * x = (const src_t *) vx;
|
|
|
|
|
|
- if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
|
|
|
- y[i] = __bfloat162float(x[i]);
|
|
|
- } else if constexpr (std::is_same_v<dst_t, nv_bfloat16> && std::is_same_v<src_t, half>) {
|
|
|
- y[i] = (float)x[i];
|
|
|
- } else {
|
|
|
- y[i] = x[i];
|
|
|
- }
|
|
|
+ y[i] = float(x[i]);
|
|
|
}
|
|
|
|
|
|
template <typename src_t, typename dst_t>
|