# flake.nix
  1. {
  2. inputs = {
  3. nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  4. flake-utils.url = "github:numtide/flake-utils";
  5. };
  6. outputs = { self, nixpkgs, flake-utils }:
  7. flake-utils.lib.eachDefaultSystem (system:
  8. let
  9. name = "llama.cpp";
  10. src = ./.;
  11. meta.mainProgram = "llama";
  12. inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin;
  13. buildInputs = with pkgs; [ openmpi ];
  14. osSpecific = with pkgs; buildInputs ++
  15. (
  16. if isAarch64 && isDarwin then
  17. with pkgs.darwin.apple_sdk_11_0.frameworks; [
  18. Accelerate
  19. MetalKit
  20. ]
  21. else if isAarch32 && isDarwin then
  22. with pkgs.darwin.apple_sdk.frameworks; [
  23. Accelerate
  24. CoreGraphics
  25. CoreVideo
  26. ]
  27. else if isDarwin then
  28. with pkgs.darwin.apple_sdk.frameworks; [
  29. Accelerate
  30. CoreGraphics
  31. CoreVideo
  32. ]
  33. else
  34. with pkgs; [ openblas ]
  35. );
  36. pkgs = import nixpkgs { inherit system; };
  37. nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
  38. cudatoolkit_joined = with pkgs; symlinkJoin {
  39. # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
  40. # see https://github.com/NixOS/nixpkgs/issues/224291
  41. # copied from jaxlib
  42. name = "${cudaPackages.cudatoolkit.name}-merged";
  43. paths = [
  44. cudaPackages.cudatoolkit.lib
  45. cudaPackages.cudatoolkit.out
  46. ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
  47. # for some reason some of the required libs are in the targets/x86_64-linux
  48. # directory; not sure why but this works around it
  49. "${cudaPackages.cudatoolkit}/targets/${system}"
  50. ];
  51. };
  52. llama-python =
  53. pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
  54. postPatch = ''
  55. substituteInPlace ./ggml-metal.m \
  56. --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
  57. substituteInPlace ./*.py --replace '/usr/bin/env python' '${llama-python}/bin/python'
  58. '';
  59. postInstall = ''
  60. mv $out/bin/main $out/bin/llama
  61. mv $out/bin/server $out/bin/llama-server
  62. mkdir -p $out/include
  63. cp ${src}/llama.h $out/include/
  64. '';
  65. cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ];
  66. in
  67. {
  68. packages.default = pkgs.stdenv.mkDerivation {
  69. inherit name src meta postPatch nativeBuildInputs postInstall;
  70. buildInputs = osSpecific;
  71. cmakeFlags = cmakeFlags
  72. ++ (if isAarch64 && isDarwin then [
  73. "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
  74. "-DLLAMA_METAL=ON"
  75. ] else [
  76. "-DLLAMA_BLAS=ON"
  77. "-DLLAMA_BLAS_VENDOR=OpenBLAS"
  78. ]);
  79. };
  80. packages.opencl = pkgs.stdenv.mkDerivation {
  81. inherit name src meta postPatch nativeBuildInputs postInstall;
  82. buildInputs = with pkgs; buildInputs ++ [ clblast ];
  83. cmakeFlags = cmakeFlags ++ [
  84. "-DLLAMA_CLBLAST=ON"
  85. ];
  86. };
  87. packages.cuda = pkgs.stdenv.mkDerivation {
  88. inherit name src meta postPatch nativeBuildInputs postInstall;
  89. buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
  90. cmakeFlags = cmakeFlags ++ [
  91. "-DLLAMA_CUBLAS=ON"
  92. ];
  93. };
  94. packages.rocm = pkgs.stdenv.mkDerivation {
  95. inherit name src meta postPatch nativeBuildInputs postInstall;
  96. buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
  97. cmakeFlags = cmakeFlags ++ [
  98. "-DLLAMA_HIPBLAS=1"
  99. "-DCMAKE_C_COMPILER=hipcc"
  100. "-DCMAKE_CXX_COMPILER=hipcc"
  101. "-DCMAKE_POSITION_INDEPENDENT_CODE=ON"
  102. ];
  103. };
  104. apps.llama-server = {
  105. type = "app";
  106. program = "${self.packages.${system}.default}/bin/llama-server";
  107. };
  108. apps.llama-embedding = {
  109. type = "app";
  110. program = "${self.packages.${system}.default}/bin/embedding";
  111. };
  112. apps.llama = {
  113. type = "app";
  114. program = "${self.packages.${system}.default}/bin/llama";
  115. };
  116. apps.quantize = {
  117. type = "app";
  118. program = "${self.packages.${system}.default}/bin/quantize";
  119. };
  120. apps.train-text-from-scratch = {
  121. type = "app";
  122. program = "${self.packages.${system}.default}/bin/train-text-from-scratch";
  123. };
  124. apps.default = self.apps.${system}.llama;
  125. devShells.default = pkgs.mkShell {
  126. buildInputs = [ llama-python ];
  127. packages = nativeBuildInputs ++ osSpecific;
  128. };
  129. });
  130. }