|
|
@@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
|
|
|
dst[index] = result;
|
|
|
}
|
|
|
|
|
|
+// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
|
|
+// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
|
|
+static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
|
|
|
+ const int nb00, const int nb01, const int nb02, const int nb03,
|
|
|
+ const int ne00_src, const int ne01_src,
|
|
|
+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
|
|
+ const float sf0, const float sf1, const float sf2, const float sf3,
|
|
|
+ const float pixel_offset) {
|
|
|
+ const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
|
|
+ const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
|
|
+
|
|
|
+ if (index >= dst_total_elements) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int i10_dst = index % ne10_dst;
|
|
|
+ const int i11_dst = (index / ne10_dst) % ne11_dst;
|
|
|
+ const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
|
|
+ const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
|
|
+
|
|
|
+ const int i02_src = (int)(i12_dst / sf2);
|
|
|
+ const int i03_src = (int)(i13_dst / sf3);
|
|
|
+
|
|
|
+ const float y = ((float)i11_dst + pixel_offset) / sf1;
|
|
|
+ const float x = ((float)i10_dst + pixel_offset) / sf0;
|
|
|
+
|
|
|
+ // support and invscale, minimum 1 pixel for bilinear
|
|
|
+ const float support1 = max(1.0f / sf1, 1.0f);
|
|
|
+ const float invscale1 = 1.0f / support1;
|
|
|
+ const float support0 = max(1.0f / sf0, 1.0f);
|
|
|
+ const float invscale0 = 1.0f / support0;
|
|
|
+
|
|
|
+ // the range of source pixels that contribute
|
|
|
+ const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
|
|
|
+ const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
|
|
|
+ const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
|
|
|
+ const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
|
|
|
+
|
|
|
+ // bilinear filter with antialiasing
|
|
|
+ float val = 0.0f;
|
|
|
+ float total_weight = 0.0f;
|
|
|
+
|
|
|
+ auto triangle_filter = [](float x) -> float {
|
|
|
+ return max(1.0f - fabsf(x), 0.0f);
|
|
|
+ };
|
|
|
+
|
|
|
+ for (int64_t sy = y_min; sy < y_max; sy++) {
|
|
|
+ const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
|
|
+
|
|
|
+ for (int64_t sx = x_min; sx < x_max; sx++) {
|
|
|
+ const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
|
|
+ const float weight = weight_x * weight_y;
|
|
|
+
|
|
|
+ if (weight <= 0.0f) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
|
|
|
+ val += pixel * weight;
|
|
|
+ total_weight += weight;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (total_weight > 0.0f) {
|
|
|
+ val /= total_weight;
|
|
|
+ }
|
|
|
+
|
|
|
+ dst[index] = val;
|
|
|
+}
|
|
|
+
|
|
|
namespace bicubic_interpolation {
|
|
|
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
|
|
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
|
|
|
@@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
|
|
|
const int ne00_src, const int ne01_src,
|
|
|
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
|
|
const float sf0, const float sf1, const float sf2, const float sf3,
|
|
|
- const float pixel_offset, cudaStream_t stream) {
|
|
|
+ const float pixel_offset, bool antialias, cudaStream_t stream) {
|
|
|
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
|
|
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
|
|
|
|
|
- upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
|
|
+ if (antialias) {
|
|
|
+ upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
|
|
+ } else {
|
|
|
+ upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
|
|
|
@@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
if (mode == GGML_SCALE_MODE_NEAREST) {
|
|
|
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
|
|
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
|
+ const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
|
|
|
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
|
|
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
|
|
- sf0, sf1, sf2, sf3, pixel_offset, stream);
|
|
|
+ sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
|
|
|
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
|
|
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
|
|
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|