mudnn.cu

#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

#include <mudnn.h>

#include "mudnn.cuh"

namespace mudnn = musa::dnn;

// Returns a human-readable error string for mudnn::Status
const char* mudnnGetErrorString(mudnn::Status err) {
    switch (err) {
        case mudnn::Status::SUCCESS:
            return "Success";
        case mudnn::Status::INVALID_PARAMETER:
            return "Invalid parameter";
        case mudnn::Status::NOT_INITIALIZED:
            return "Not initialized";
        case mudnn::Status::ALLOC_FAILED:
            return "Allocation failed";
        case mudnn::Status::NOT_SUPPORTED:
            return "Not supported";
        case mudnn::Status::INTERNAL_ERROR:
            return "Internal error";
        case mudnn::Status::ARCH_MISMATCH:
            return "Architecture mismatch";
        case mudnn::Status::EXECUTION_FAILED:
            return "Execution failed";
        default:
            return "Unknown mudnn status";
    }
}

// Error checking macro for MUDNN calls
#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
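
// Example usage (illustrative): any mudnn call returning a Status can be
// wrapped so that a non-SUCCESS result is reported through mudnnGetErrorString:
//   MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));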

namespace {
    // Thread-safe cache for mudnn::Handle objects per device
    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
    std::mutex handle_cache_mutex;

    // Returns the handle for device_id, creating and caching it on first use
    mudnn::Handle* get_cached_handle(int device_id) {
        std::lock_guard<std::mutex> lock(handle_cache_mutex);

        auto it = handle_cache.find(device_id);
        if (it != handle_cache.end()) {
            return it->second.get();
        }

        auto handle = std::make_unique<mudnn::Handle>(device_id);
        mudnn::Handle* handle_ptr = handle.get();
        handle_cache[device_id] = std::move(handle);
        return handle_ptr;
    }
} // anonymous namespace
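
// Note: cached handles live for the lifetime of the process; the static map
// releases them (via unique_ptr) at program exit.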

// Extracts dimensions and strides from a ggml_tensor; ggml stores strides in
// bytes (nb), so they are converted to element counts here
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
                              std::vector<int64_t>& dims,
                              std::vector<int64_t>& strides) {
    const int    ndims        = ggml_n_dims(tensor);
    const size_t element_size = ggml_element_size(tensor);

    dims.resize(ndims);
    strides.resize(ndims);

    for (int i = 0; i < ndims; ++i) {
        dims[i]    = tensor->ne[i];
        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
    }

    return ndims;
}
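
// Example (illustrative): a contiguous 4x3 GGML_TYPE_F32 tensor has
// ne = {4, 3} and nb = {4, 16} (bytes), yielding dims = {4, 3} and
// strides = {1, 4} (elements).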

// Converts ggml_type to mudnn::Tensor::Type
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return mudnn::Tensor::Type::FLOAT;
        case GGML_TYPE_F16:
            return mudnn::Tensor::Type::HALF;

        // TODO: Add support for other types
        default:
            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
    }
    return mudnn::Tensor::Type::FLOAT; // Default fallback
}
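
// Note: any other ggml type (e.g. quantized formats) takes the default branch
// above, which reports NOT_SUPPORTED through MUDNN_CHECK; the trailing return
// is a fallback so the function still yields a value on every path.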

// Asynchronous memory copy using mudnn::Unary::IDENTITY
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
    mudnn::Tensor tensor_dst, tensor_src;

    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));

    // both tensors are described with src's dims/strides: the copy assumes
    // dst shares src's shape and layout
    std::vector<int64_t> dims, strides;
    const int ndims = get_ggml_dims_and_strides(src, dims, strides);

    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
    MUDNN_CHECK(tensor_src.SetAddr(src->data));

    // configure an identity unary op; running it copies src into dst
    mudnn::Unary op;
    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
    MUDNN_CHECK(op.SetAlpha(0.0f));
    MUDNN_CHECK(op.SetBeta(0.0f));

    // run on this context's stream, reusing the cached per-device handle
    mudnn::Handle* handle = get_cached_handle(ctx.device);
    MUDNN_CHECK(handle->SetStream(ctx.stream()));
    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));

    return musaSuccess;
}
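
// Hypothetical call site (a sketch; the real dispatch lives in the backend's
// copy path and may differ): route a same-type, same-shape device copy
// through muDNN.
//
//   if (dst->type == src->type && ggml_are_same_shape(dst, src)) {
//       CUDA_CHECK(mudnnMemcpyAsync(ctx, dst, src));
//   }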