ggml-cuda.h 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "ggml.h"
  4. #ifdef __cplusplus
  5. extern "C" {
  6. #endif
  7. #define CUDA_CHECK(err) \
  8. do { \
  9. cudaError_t err_ = (err); \
  10. if (err_ != cudaSuccess) { \
  11. fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
  12. cudaGetErrorString(err_)); \
  13. exit(1); \
  14. } \
  15. } while (0)
  16. #define CUBLAS_CHECK(err) \
  17. do { \
  18. cublasStatus_t err_ = (err); \
  19. if (err_ != CUBLAS_STATUS_SUCCESS) { \
  20. fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
  21. exit(1); \
  22. } \
  23. } while (0)
  24. extern cublasHandle_t g_cublasH;
  25. extern cudaStream_t g_cudaStream;
  26. extern cudaStream_t g_cudaStream2;
  27. extern cudaEvent_t g_cudaEvent;
  28. void ggml_init_cublas(void);
  29. void * ggml_cuda_host_malloc(size_t size);
  30. void ggml_cuda_host_free(void * ptr);
  31. void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
  32. void ggml_cuda_pool_free(void * ptr, size_t size);
  33. void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  34. void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  35. void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  36. void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  37. void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  38. void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
  39. cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
  40. typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
  41. dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
  42. #ifdef __cplusplus
  43. }
  44. #endif