outprod.cpp 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #include <sycl/sycl.hpp>
  2. #include <oneapi/mkl.hpp>
  3. #include "outprod.hpp"
  4. void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
  5. const ggml_tensor *src0 = dst->src[0];
  6. const ggml_tensor *src1 = dst->src[1];
  7. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  8. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  9. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  10. GGML_ASSERT(ggml_is_contiguous(src0));
  11. GGML_ASSERT(ggml_is_contiguous(dst));
  12. GGML_TENSOR_BINARY_OP_LOCALS
  13. // Get SYCL queue
  14. dpct::queue_ptr stream = ctx.stream();
  15. // Dimension checks
  16. GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
  17. GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
  18. GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
  19. // Get data pointers
  20. const float* src0_d = (const float*)src0->data;
  21. const float* src1_d = (const float*)src1->data;
  22. float* dst_d = (float*)dst->data;
  23. // GEMM parameters
  24. const float alpha = 1.0f;
  25. const float beta = 0.0f;
  26. // Handle transposition of src1
  27. const bool src1_T = ggml_is_transposed(src1);
  28. const oneapi::mkl::transpose src1_op =
  29. src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
  30. const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
  31. try {
  32. // Perform matrix multiplication using oneMKL GEMM
  33. #ifdef GGML_SYCL_NVIDIA
  34. oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
  35. oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
  36. ne00, src1_d, ldb, beta, dst_d, ne0);
  37. #else
  38. oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
  39. src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
  40. #endif
  41. }
  42. catch (sycl::exception const& exc) {
  43. std::cerr << exc.what() << std::endl;
  44. GGML_ASSERT(false);
  45. }
  46. }