Przeglądaj źródła

CUDA: use FP32 arithmetic for conv2d (#15683)

Johannes Gäßler 4 miesięcy temu
rodzic
commit
38ad381f9f
1 zmienionych plików z 4 dodań i 10 usunięć
  1. 4 10
      ggml/src/ggml-cuda/conv2d.cu

+ 4 - 10
ggml/src/ggml-cuda/conv2d.cu

@@ -82,7 +82,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
     int64_t n, c_out, out_y, out_x;
     int64_t n, c_out, out_y, out_x;
     Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
     Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
 
 
-    T acc = 0;
+    float acc = 0.0f;
 
 
     for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
     for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
         kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
         kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
@@ -93,21 +93,15 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
             for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
             for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
                 const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
                 const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
 
 
-                T input_val;
-                if (std::is_same<T, half>::value) {
-                    input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
-                } else {
-                    input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
-                }
-
-                T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+                const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+                const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
                 acc += (input_val * kernel_val);
                 acc += (input_val * kernel_val);
             }
             }
         }
         }
     }
     }
 
 
     // [N, OC, OH, OW]
     // [N, OC, OH, OW]
-    output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;
+    output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
 }
 }
 
 
 template <typename T>
 template <typename T>