|
|
@@ -1,4 +1,5 @@
|
|
|
#include "conv2d.cuh"
|
|
|
+#include "convert.cuh"
|
|
|
|
|
|
struct conv_params {
|
|
|
const int64_t IW, IH;
|
|
|
@@ -94,8 +95,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
|
|
|
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
|
|
|
|
|
|
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
|
|
|
- const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
|
|
|
- acc += (input_val * kernel_val);
|
|
|
+ const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
|
|
|
+ acc += (input_val * ggml_cuda_cast<float>(kernel_val));
|
|
|
}
|
|
|
}
|
|
|
}
|