# flake.nix
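# A sketch of typical invocations for this flake (assuming a flakes-enabled
# Nix and the repository root as the working directory):
#
#   nix build                 # default package (Metal on Apple Silicon, BLAS otherwise)
#   nix build .#cuda          # backend variants: .#opencl, .#cuda, .#rocm
#   nix run .#llama-server    # run one of the exposed apps
#   nix develop               # dev shell with the build toolchain
#   nix develop .#extra       # dev shell with the heavier Python runtime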

{
  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
    flake-utils.url = "github:numtide/flake-utils";
  };
  outputs = { self, nixpkgs, flake-utils }:
    flake-utils.lib.eachDefaultSystem (system:
      let
        name = "llama.cpp";
        src = ./.;
        meta.mainProgram = "llama";
        pkgs = import nixpkgs { inherit system; };
        inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin;
        buildInputs = with pkgs; [ openmpi ];
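        # Platform-specific backend dependencies: Metal/Accelerate frameworks
        # on Apple Silicon, the older Apple SDK frameworks elsewhere on Darwin,
        # and OpenBLAS on all other systems.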
        osSpecific = with pkgs; buildInputs ++ (
          if isAarch64 && isDarwin then
            with pkgs.darwin.apple_sdk_11_0.frameworks; [
              Accelerate
              MetalKit
            ]
          else if isAarch32 && isDarwin then
            with pkgs.darwin.apple_sdk.frameworks; [
              Accelerate
              CoreGraphics
              CoreVideo
            ]
          else if isDarwin then
            with pkgs.darwin.apple_sdk.frameworks; [
              Accelerate
              CoreGraphics
              CoreVideo
            ]
          else
            with pkgs; [ openblas ]
        );
        nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
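        # Joins the CUDA toolkit's split outputs into a single tree so CMake
        # can find headers and libraries in one place (see the HACK note below).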
        cudatoolkit_joined = with pkgs; symlinkJoin {
          # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit;
          # see https://github.com/NixOS/nixpkgs/issues/224291
          # copied from jaxlib
          name = "${cudaPackages.cudatoolkit.name}-merged";
          paths = [
            cudaPackages.cudatoolkit.lib
            cudaPackages.cudatoolkit.out
          ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
            # For some reason, some of the required libs end up in the
            # targets/x86_64-linux directory; not sure why, but this works around it.
            "${cudaPackages.cudatoolkit}/targets/${system}"
          ];
        };
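        # Minimal Python environment for the repo's conversion scripts;
        # postPatch below pins their shebangs to it.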
        llama-python =
          pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
        # TODO(Green-Sky): find a better way to opt into the heavy ML Python runtime
        llama-python-extra =
          pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece torchWithoutCuda transformers ]);
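        # Source patches applied before the build: point the Metal shader
        # lookup at the installed ggml-metal.metal, and pin the Python script
        # shebangs to the environment defined above.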
        postPatch = ''
          substituteInPlace ./ggml-metal.m \
            --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
          substituteInPlace ./*.py --replace '/usr/bin/env python' '${llama-python}/bin/python'
        '';
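        # Give the generically named binaries friendlier names and install
        # the public header.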
        postInstall = ''
          mv $out/bin/main $out/bin/llama
          mv $out/bin/server $out/bin/llama-server
          mkdir -p $out/include
          cp ${src}/llama.h $out/include/
        '';
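        # Flags shared by every package variant; LLAMA_NATIVE is off so the
        # binary is not tuned to the build machine's CPU, which keeps the
        # outputs portable across machines.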
        cmakeFlags = [ "-DLLAMA_NATIVE=OFF" "-DLLAMA_BUILD_SERVER=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ];
      in
      {
        packages.default = pkgs.stdenv.mkDerivation {
          inherit name src meta postPatch nativeBuildInputs postInstall;
          buildInputs = osSpecific;
          cmakeFlags = cmakeFlags
            ++ (if isAarch64 && isDarwin then [
              "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
              "-DLLAMA_METAL=ON"
            ] else [
              "-DLLAMA_BLAS=ON"
              "-DLLAMA_BLAS_VENDOR=OpenBLAS"
            ]);
        };
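        # Backend-specific variants of the same derivation; build with,
        # e.g., `nix build .#cuda`.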
        packages.opencl = pkgs.stdenv.mkDerivation {
          inherit name src meta postPatch nativeBuildInputs postInstall;
          buildInputs = with pkgs; buildInputs ++ [ clblast ];
          cmakeFlags = cmakeFlags ++ [
            "-DLLAMA_CLBLAST=ON"
          ];
        };
        packages.cuda = pkgs.stdenv.mkDerivation {
          inherit name src meta postPatch nativeBuildInputs postInstall;
          buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
          cmakeFlags = cmakeFlags ++ [
            "-DLLAMA_CUBLAS=ON"
          ];
        };
        packages.rocm = pkgs.stdenv.mkDerivation {
          inherit name src meta postPatch nativeBuildInputs postInstall;
          buildInputs = with pkgs.rocmPackages; buildInputs ++ [ clr hipblas rocblas ];
          cmakeFlags = cmakeFlags ++ [
            "-DLLAMA_HIPBLAS=1"
            "-DCMAKE_C_COMPILER=hipcc"
            "-DCMAKE_CXX_COMPILER=hipcc"
            # Build all targets supported by rocBLAS. When updating, search for
            # TARGET_LIST_ROCM in
            # github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/CMakeLists.txt
            # and select the line that matches the current nixpkgs version of rocBLAS.
            "-DAMDGPU_TARGETS=gfx803;gfx900;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102"
          ];
        };
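        # Expose the installed binaries as flake apps, runnable with
        # `nix run .#<name>`.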
        apps.llama-server = {
          type = "app";
          program = "${self.packages.${system}.default}/bin/llama-server";
        };
        apps.llama-embedding = {
          type = "app";
          program = "${self.packages.${system}.default}/bin/embedding";
        };
        apps.llama = {
          type = "app";
          program = "${self.packages.${system}.default}/bin/llama";
        };
        apps.quantize = {
          type = "app";
          program = "${self.packages.${system}.default}/bin/quantize";
        };
        apps.train-text-from-scratch = {
          type = "app";
          program = "${self.packages.${system}.default}/bin/train-text-from-scratch";
        };
        apps.default = self.apps.${system}.llama;
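        # Development shells: `default` pairs the build toolchain with the
        # minimal Python environment; `extra` swaps in the heavier one with
        # torch and transformers.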
        devShells.default = pkgs.mkShell {
          buildInputs = [ llama-python ];
          packages = nativeBuildInputs ++ osSpecific;
        };
        devShells.extra = pkgs.mkShell {
          buildInputs = [ llama-python-extra ];
          packages = nativeBuildInputs ++ osSpecific;
        };
      });
}