/* ggml-cuda.h */
  /* Standard headers: size_t, fprintf/stderr and exit are used below. */
  #include <stddef.h>
  #include <stdio.h>
  #include <stdlib.h>

  #include <cublas_v2.h>
  #include <cuda_runtime.h>
  /* Give the declarations below C linkage when this header is compiled as C++. */
  #ifdef __cplusplus
  extern "C" {
  #endif
  /*
   * CUDA_CHECK(err) - evaluate a cudaError_t expression once and abort on failure.
   *
   * err: any expression yielding a cudaError_t (typically a CUDA runtime call).
   *
   * When the result differs from cudaSuccess, the numeric code, the call site
   * (__FILE__ / __LINE__) and the text returned by cudaGetErrorString() are
   * written to stderr and the process terminates via exit(1).
   *
   * The do { ... } while (0) wrapper makes the macro a single statement, so it
   * is safe in an unbraced if/else. fprintf, stderr and exit must be declared
   * at the point of use.
   */
  #define CUDA_CHECK(err) \
      do { \
          cudaError_t err_ = (err); \
          if (err_ != cudaSuccess) { \
              /* %d expects an int: cast the enum explicitly. */ \
              fprintf(stderr, "CUDA error %d at %s:%d: %s\n", (int) err_, \
                      __FILE__, __LINE__, cudaGetErrorString(err_)); \
              exit(1); \
          } \
      } while (0)
  /*
   * CUBLAS_CHECK(err) - evaluate a cublasStatus_t expression once and abort on
   * failure.
   *
   * err: any expression yielding a cublasStatus_t (typically a cuBLAS call).
   *
   * When the result differs from CUBLAS_STATUS_SUCCESS, the numeric status and
   * the call site (__FILE__ / __LINE__) are written to stderr and the process
   * terminates via exit(1). Only the numeric code is printed, not a status
   * string.
   *
   * The do { ... } while (0) wrapper makes the macro a single statement, so it
   * is safe in an unbraced if/else. fprintf, stderr and exit must be declared
   * at the point of use.
   */
  #define CUBLAS_CHECK(err) \
      do { \
          cublasStatus_t err_ = (err); \
          if (err_ != CUBLAS_STATUS_SUCCESS) { \
              /* %d expects an int: cast the enum explicitly. */ \
              fprintf(stderr, "cuBLAS error %d at %s:%d\n", (int) err_, \
                      __FILE__, __LINE__); \
              exit(1); \
          } \
      } while (0)
  23. extern cublasHandle_t g_cublasH;
  24. extern cudaStream_t g_cudaStream;
  /*
   * One-time backend setup.
   * NOTE(review): presumably creates g_cublasH / g_cudaStream - confirm in the
   * implementation.
   */
  void ggml_init_cublas(void);

  /*
   * Buffer pool interface.
   *
   * ggml_cuda_pool_malloc: takes the requested byte count in `size`, returns a
   * pointer, and has an out-parameter `actual_size`.
   * NOTE(review): presumably `actual_size` receives the capacity of the buffer
   * handed out (which may exceed `size`) and the pointer is device memory -
   * confirm.
   *
   * ggml_cuda_pool_free: takes a pointer and a byte count.
   * NOTE(review): presumably `size` must be the `actual_size` value reported by
   * the matching ggml_cuda_pool_malloc call - confirm.
   */
  void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
  void   ggml_cuda_pool_free(void * ptr, size_t size);
  28. void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  29. void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  30. void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  31. void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  32. void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  33. void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  /* Close the extern "C" linkage block opened near the top of this header. */
  #ifdef __cplusplus
  }
  #endif