Преглед на файлове

[SYCL] Support gpt-oss by OPs add-id, mul_mat for mxfp4, swiglu_oai (#17826)

* support gpt-oss GPU by OP add-id, mul_mat for mxfp4, swiglu_oai, fix warning

* fix fault ut case, update ops.md

* rebase, fix format issue
Neo Zhang Jianyu преди 1 месец
родител
ревизия
4aced7a631

+ 9 - 9
docs/ops.md

@@ -18,12 +18,12 @@ Legend:
 |                              ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                             ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
-|                           ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |  | ✅ | ❌ | ❌ | ❌ |
+|                           ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |  | ✅ | ❌ | ❌ | ❌ |
 |                           ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                           ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
-|                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 |  | ✅ | ❌ | ❌ | ❌ |
+|                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
 |                             CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
-|                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
+|                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 |  | 🟡 | ❌ | ❌ | ❌ |
 |                           CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                             CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
 |                          CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -31,7 +31,7 @@ Legend:
 |                          CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-|                              COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
+|                              COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | 🟡 | ❌ | ❌ | ❌ |
 |                      COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
 |               CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -64,7 +64,7 @@ Legend:
 |                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 |                          L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
-|                              LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
+|                              LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | ✅ | ❌ | ❌ | ❌ |
 |                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                          MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@@ -98,14 +98,14 @@ Legend:
 |                          SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                             SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                        SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-|                              SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
+|                              SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ |  | 🟡 | ❌ | ❌ | ❌ |
 |                          SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                         SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
 |                         SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                    SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
 |                        SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
-|                              SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
-|                             SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
+|                              SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ |  | 🟡 | ❌ | ❌ | ❌ |
+|                             SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ |  | 🟡 | ❌ | ❌ | ❌ |
 |                         SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                         SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
 |                             STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -113,7 +113,7 @@ Legend:
 |                              SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
 |                         SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
 |                           SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
-|                       SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |  | 🟡 | ✅ | ❌ | ❌ |
+|                       SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |  | 🟡 | ✅ | ❌ | ❌ |
 |                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                            TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |

Файловите разлики са ограничени, защото са твърде много
+ 392 - 358
docs/ops/SYCL.csv


+ 77 - 0
ggml/src/ggml-sycl/add-id.cpp

@@ -0,0 +1,77 @@
+#include <sycl/sycl.hpp>
+#include "common.hpp"
+#include "add-id.hpp"
+
+static void add_id_kernel(
+    const float* src0,
+    const float* src1,
+    const int32_t* src2,
+    float* dst,
+    int64_t ne0,
+    int64_t ne1,
+    size_t nb01,
+    size_t nb02,
+    size_t nb11,
+    size_t nb21,
+    sycl::nd_item<3> item_ct1) {
+  const int64_t i1 = item_ct1.get_group(2);
+  const int64_t i2 = item_ct1.get_group(1);
+
+  const int i11 =
+      *(const int32_t*)((const char*)src2 + i1 * sizeof(int32_t) + i2 * nb21);
+
+  const size_t nb1 = ne0 * sizeof(float);
+  const size_t nb2 = ne1 * nb1;
+
+  float* dst_row = (float*)((char*)dst + i1 * nb1 + i2 * nb2);
+  const float* src0_row =
+      (const float*)((const char*)src0 + i1 * nb01 + i2 * nb02);
+  const float* src1_row = (const float*)((const char*)src1 + i11 * nb11);
+
+  for (int64_t i0 = item_ct1.get_local_id(2); i0 < ne0;
+       i0 += item_ct1.get_local_range(2)) {
+    dst_row[i0] = src0_row[i0] + src1_row[i0];
+  }
+}
+
+void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+  const ggml_tensor* src0 = dst->src[0];
+  const ggml_tensor* src1 = dst->src[1];
+  const ggml_tensor* src2 = dst->src[2];
+
+  GGML_TENSOR_TERNARY_OP_LOCALS
+
+  GGML_ASSERT(dst->type == GGML_TYPE_F32);
+  GGML_ASSERT(src0->type == GGML_TYPE_F32);
+  GGML_ASSERT(src1->type == GGML_TYPE_F32);
+  GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+  GGML_ASSERT(nb00 == sizeof(float));
+  GGML_ASSERT(nb10 == sizeof(float));
+  GGML_ASSERT(nb20 == sizeof(int32_t));
+
+  const float* src0_d = (const float*)src0->data;
+  const float* src1_d = (const float*)src1->data;
+  const int32_t* src2_d = (const int32_t*)src2->data;
+  float* dst_d = (float*)dst->data;
+
+  int threads = std::min((int)ne00, 768);  // cols
+  ctx.stream()->parallel_for(
+      sycl::nd_range<3>(
+          sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
+          sycl::range<3>(1, 1, threads)),
+      [=](sycl::nd_item<3> item_ct1) {
+        add_id_kernel(
+            src0_d,
+            src1_d,
+            src2_d,
+            dst_d,
+            ne0,
+            ne1,
+            nb01,
+            nb02,
+            nb11,
+            nb21,
+            item_ct1);
+      });
+}

+ 8 - 0
ggml/src/ggml-sycl/add-id.hpp

@@ -0,0 +1,8 @@
+#ifndef GGML_SYCL_ADD_ID_HPP
+#define GGML_SYCL_ADD_ID_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_add_id(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_ADD_ID_HPP

+ 17 - 0
ggml/src/ggml-sycl/common.hpp

@@ -642,5 +642,22 @@ static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3
     return sycl::uint2(div_val, mod_val);
 }
 
+static __dpct_inline__ int ggml_sycl_dp4a(const int a, const int b, int c) {
+    return dpct::dp4a(a, b, c);
+}
+
+static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
 
 #endif // GGML_SYCL_COMMON_HPP

+ 15 - 0
ggml/src/ggml-sycl/convert.cpp

@@ -472,6 +472,16 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
       }
 }
 
+template <typename dst_t>
+static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+        [=](sycl::nd_item<3> item_ct1) {
+            dequantize_block_mxfp4(vx, y, item_ct1);
+        });
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
                           const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
@@ -518,6 +528,7 @@ static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct
     convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
 }
 
+
 to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -571,6 +582,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
             return dequantize_row_iq4_xs_sycl;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_sycl;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_sycl;
         case GGML_TYPE_F32:
             return convert_unary_sycl<float>;
 #ifdef GGML_SYCL_HAS_BF16
@@ -636,6 +649,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return dequantize_row_iq4_xs_sycl;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_sycl;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_sycl;
         case GGML_TYPE_F16:
             return convert_unary_sycl<sycl::half>;
 #ifdef GGML_SYCL_HAS_BF16

+ 18 - 0
ggml/src/ggml-sycl/dequantize.hpp

@@ -819,5 +819,23 @@ dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
     }
 }
 
+template<typename dst_t>
+static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                   const sycl::nd_item<3> &item_ct1) {
+    // auto                item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int64_t       i        = item_ct1.get_group(2);
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t    tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = ggml_sycl_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;
+    }
+}
 
 #endif // GGML_SYCL_DEQUANTIZE_HPP

+ 56 - 3
ggml/src/ggml-sycl/dpct/helper.hpp

@@ -1860,10 +1860,31 @@ namespace dpct
                                            : id);
     }
 
+    template <typename T1, typename T2>
+    using dot_product_acc_t = std::conditional_t<
+        std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
+        uint32_t,
+        int32_t>;
+
+    template <typename T>
+    sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
+      return sycl::vec<T, 1>(val)
+          .template as<sycl::vec<
+              std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
+              4>>()
+          .template convert<T>();
+    }
+
     template <typename T1, typename T2, typename T3>
-    inline auto dp4a(T1 a, T2 b, T3 c)
-    {
-        return syclcompat::dp4a(a, b, c);
+    inline auto dp4a(T1 a, T2 b, T3 c) {
+      dot_product_acc_t<T1, T2> res = c;
+      auto va = extract_and_sign_or_zero_extend4(a);
+      auto vb = extract_and_sign_or_zero_extend4(b);
+      res += va[0] * vb[0];
+      res += va[1] * vb[1];
+      res += va[2] * vb[2];
+      res += va[3] * vb[3];
+      return res;
     }
 
     struct sub_sat
@@ -2972,6 +2993,38 @@ namespace dpct
     atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
     }
 
+    inline unsigned int byte_level_permute(
+        unsigned int a, unsigned int b, unsigned int s) {
+      unsigned int ret;
+      ret = ((((std::uint64_t)b << 32 | a) >> (s & 0x7) * 8) & 0xff) |
+            (((((std::uint64_t)b << 32 | a) >> ((s >> 4) & 0x7) * 8) & 0xff)
+             << 8) |
+            (((((std::uint64_t)b << 32 | a) >> ((s >> 8) & 0x7) * 8) & 0xff)
+             << 16) |
+            (((((std::uint64_t)b << 32 | a) >> ((s >> 12) & 0x7) * 8) & 0xff)
+             << 24);
+      return ret;
+    }
+
+    inline uint32_t byte_level_permute_custom(
+        uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
+      constexpr uint16_t lookup[6][4] = {
+          {0x3210, 0x4321, 0x5432, 0x6543},  // Forward 4-byte extract
+          {0x5670, 0x6701, 0x7012, 0x0123},  // Backward 4-byte extract
+          {0x0000, 0x1111, 0x2222, 0x3333},  // Replicate 8-bit values
+          {0x3210, 0x3211, 0x3222, 0x3333},  // Edge clamp left
+          {0x0000, 0x1110, 0x2210, 0x3210},  // Edge clamp right
+          {0x1010, 0x3232, 0x1010, 0x3232}   // Replicate 16-bit values
+      };
+
+      if (mode >= 1 && mode <= 6) {
+        return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);
+      } else if (!mode) {
+        return byte_level_permute(low32, high32, sel);
+      }
+      return 0;
+    }
+
 } // COPY from DPCT head files
 
 #endif // GGML_SYCL_DPCT_HELPER_HPP

+ 97 - 0
ggml/src/ggml-sycl/element_wise.cpp

@@ -911,6 +911,98 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
         });
 }
 
+__dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
+    x = sycl::fmin(x, limit);
+    g = sycl::fmax(sycl::fmin(g, limit), -limit);
+
+    float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));
+    out_glu = out_glu * (1.0f + g);
+    return out_glu;
+}
+
+
+template <typename T>
+static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
+                              const int64_t n, const int64_t o0, const int64_t o1,
+                              float alpha, float limit, sycl::nd_item<3> item_ct1) {
+    const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    float xi = x[j0];
+    float gi = g[j1];
+
+    dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);
+}
+
+template <typename T>
+static void swiglu_oai_sycl(const T *       x,
+                            const T *       g,
+                            T *             dst,
+                            const int64_t   k,
+                            const int64_t   n,
+                            const int64_t   o0,
+                            const int64_t   o1,
+                            const float     alpha,
+                            const float     limit,
+                            dpct::queue_ptr stream) {
+    const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
+                                           sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
+                         });
+}
+
+void ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    dpct::queue_ptr     stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    float * src0_p = (float *) src0_d;
+    float * src1_p = (float *) src1_d;
+
+    if (!src1) {
+        src0_p += swapped ? nc : 0;
+        src1_p += swapped ? 0 : nc;
+    }
+
+    swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
 static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
         [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
@@ -1070,6 +1162,11 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_swiglu(ctx, dst);
 }
 
+void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu_oai(ctx, dst);
+}
+
 void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_geglu_erf(ctx, dst);

+ 4 - 0
ggml/src/ggml-sycl/element_wise.hpp

@@ -5,6 +5,8 @@
 #include "ggml.h"
 #include <limits> // For std::numeric_limits
 
+#define SYCL_GLU_BLOCK_SIZE 256
+
 template <typename T>
 T neg_infinity() {
     return -std::numeric_limits<T>::infinity();
@@ -41,6 +43,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
+void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
 void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

+ 17 - 6
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -39,6 +39,7 @@
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
 
+#include "ggml-sycl/add-id.hpp"
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml-sycl/element_wise.hpp"
@@ -3313,6 +3314,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
+
     // mmvq and mmq need the __dp4a instruction which is available for gen12+
     // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
     use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
@@ -3320,7 +3322,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
-
     // mmvq path is faster in the CUDA backend.
     if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
         // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
@@ -3711,6 +3712,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ADD1: // TODO: more efficient implementation
             ggml_sycl_add(ctx, dst);
             break;
+        case GGML_OP_ADD_ID:
+            ggml_sycl_add_id(ctx, dst);
+            break;
         case GGML_OP_SUB:
             ggml_sycl_sub(ctx, dst);
             break;
@@ -3803,6 +3807,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_sycl_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    ggml_sycl_swiglu_oai(ctx, dst);
+                    break;
                 case GGML_GLU_OP_GEGLU_ERF:
                     ggml_sycl_geglu_erf(ctx, dst);
                     break;
@@ -4397,6 +4404,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
@@ -4424,15 +4432,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     }
                 }
                 ggml_type src0_type = op->src[0]->type;
-                if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
-                    // TODO: support MXFP4
+                if (src0_type == GGML_TYPE_BF16 ) {
+                    // TODO: support GGML_TYPE_BF16
                     // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                     return false;
                 }
+
                 // TODO: The configuration below needs more work to be supported with oneDNN
-                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
-                    return false;
+                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
+                    a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
+                  return false;
                 }
+
                 // TODO: This specific configuration can fail with oneDNN and needs more debugging
                 if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
                     a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
@@ -4553,9 +4564,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-            return true;
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
+        case GGML_OP_ADD_ID:
         case GGML_OP_SUB:
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_MUL:

+ 22 - 0
ggml/src/ggml-sycl/mmvq.cpp

@@ -595,6 +595,25 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_MXFP4 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+
+    {
+        stream->submit([&](sycl::handler & cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                 mul_mat_vec_q<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
+                                     vx, vy, dst, ncols, nrows, item_ct1);
+                             });
+        });
+    }
+}
+
+
 static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
@@ -1123,6 +1142,9 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
             case GGML_TYPE_IQ4_XS:
                 mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                 break;
+            case GGML_TYPE_MXFP4:
+                mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                break;
             default:
                 GGML_ABORT("fatal error");
         }

+ 5 - 5
ggml/src/ggml-sycl/pad.cpp

@@ -14,10 +14,10 @@
 #include "pad.hpp"
 
 static void pad_f32(const float * src, float * dst,
-                               const int lp0, const int rp0, const int lp1, const int rp1,
-                               const int lp2, const int rp2, const int lp3, const int rp3,
-                               const int ne0, const int ne1, const int ne2, const int ne3) {
-    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+                    const int lp0, const int rp0, const int lp1, const int rp1,
+                    const int lp2, const int rp2, const int lp3, const int rp3,
+                    const int ne0, const int ne1, const int ne2, const int ne3,
+                    sycl::nd_item<3> item_ct1) {
     int i0 = item_ct1.get_local_id(2) +
              item_ct1.get_group(2) * item_ct1.get_local_range(2);
     int i1 = item_ct1.get_group(1);
@@ -63,7 +63,7 @@ static void pad_f32_sycl(const float *src, float *dst, const int lp0,
                           sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
         [=](sycl::nd_item<3> item_ct1) {
             pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
-                    ne2, ne3);
+                    ne2, ne3, item_ct1);
         });
 }
 

+ 1 - 1
ggml/src/ggml-sycl/ssm_conv.cpp

@@ -88,7 +88,7 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
 
-    GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
     const int src_stride_inner = ncs;
     const int src_stride_seq   = ncs * d_inner;

+ 58 - 0
ggml/src/ggml-sycl/vecdotq.hpp

@@ -20,6 +20,18 @@
 typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
                                   const int & iqs);
 
+static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
+    const uint8_t * x8 = (const uint8_t *) x;
+
+    int x32  = x8[4*i32 + 0] <<  0;
+    x32     |= x8[4*i32 + 1] <<  8;
+    x32     |= x8[4*i32 + 2] << 16;
+    x32     |= x8[4*i32 + 3] << 24;
+
+    return x32;
+}
+
+
 static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
   const uint16_t* x16 =
       (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte
@@ -75,6 +87,28 @@ static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
     val2 = v1 | (v2 << 16);
 }
 
+static __dpct_inline__ sycl::int2 get_int_from_table_16(
+    const int& q4, const int8_t* table) {
+  const uint32_t* table32 = (const uint32_t*)table;
+  uint32_t tmp[2];
+  const uint32_t low_high_selection_indices =
+      (0x32103210 | ((q4 & 0x88888888) >> 1));
+#pragma unroll
+  for (uint32_t i = 0; i < 2; ++i) {
+    const uint32_t shift = 16 * i;
+
+    const uint32_t low =
+        dpct::byte_level_permute(table32[0], table32[1], q4 >> shift);
+    const uint32_t high =
+        dpct::byte_level_permute(table32[2], table32[3], q4 >> shift);
+    tmp[i] = dpct::byte_level_permute(
+        low, high, low_high_selection_indices >> shift);
+  }
+  return sycl::int2(
+      dpct::byte_level_permute(tmp[0], tmp[1], 0x6420),
+      dpct::byte_level_permute(tmp[0], tmp[1], 0x7531));
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 
 // contiguous v/x values
@@ -685,6 +719,30 @@ vec_dot_q4_1_q8_1(const void *__restrict__ vbq,
     return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
 }
 
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ  4
+
+static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
+                                                const block_q8_1 * __restrict__ bq8_1,
+                                                const int & iqs) {
+    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq;
+
+    const int * q8 = (const int *) bq8_1->qs + iqs;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+        const sycl::int2 v      = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        sumi = ggml_sycl_dp4a(v.x(), q8[l + 0], sumi);
+        sumi = ggml_sycl_dp4a(v.y(), q8[l + 4], sumi);
+    }
+
+    const float d = ggml_sycl_e8m0_to_fp32(bq4->e) * 0.5f * (bq8_1->ds)[0];
+    return d * sumi;
+}
+
+
 static __dpct_inline__ float
 vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {

Някои файлове не бяха показани, защото твърде много файлове са промени