Parcourir la source

nix : add cuda, use a symlinked toolkit for cmake (#3202)

Erik Scholz il y a 2 ans
Parent
commit
a98b1633d5
1 fichiers modifiés avec 21 ajouts et 0 suppressions
  1. 21 0
      flake.nix

+ 21 - 0
flake.nix

@@ -35,6 +35,20 @@
         );
         pkgs = import nixpkgs { inherit system; };
         nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
+        cudatoolkit_joined = with pkgs; symlinkJoin {
+          # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
+          # see https://github.com/NixOS/nixpkgs/issues/224291
+          # copied from jaxlib
+          name = "${cudaPackages.cudatoolkit.name}-merged";
+          paths = [
+            cudaPackages.cudatoolkit.lib
+            cudaPackages.cudatoolkit.out
+          ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
+            # for some reason some of the required libs are in the targets/x86_64-linux
+            # directory; not sure why but this works around it
+            "${cudaPackages.cudatoolkit}/targets/${system}"
+          ];
+        };
         llama-python =
           pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
         postPatch = ''
@@ -70,6 +84,13 @@
             "-DLLAMA_CLBLAST=ON"
           ];
         };
+        packages.cuda = pkgs.stdenv.mkDerivation {
+          inherit name src meta postPatch nativeBuildInputs postInstall;
+          buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
+          cmakeFlags = cmakeFlags ++ [
+            "-DLLAMA_CUBLAS=ON"
+          ];
+        };
         packages.rocm = pkgs.stdenv.mkDerivation {
           inherit name src meta postPatch nativeBuildInputs postInstall;
           buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];