|
@@ -35,6 +35,20 @@
|
|
|
);
|
|
);
|
|
|
pkgs = import nixpkgs { inherit system; };
|
|
pkgs = import nixpkgs { inherit system; };
|
|
|
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
|
|
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
|
|
|
|
|
+ cudatoolkit_joined = with pkgs; symlinkJoin {
|
|
|
|
|
+ # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
|
|
|
|
|
+ # see https://github.com/NixOS/nixpkgs/issues/224291
|
|
|
|
|
+ # copied from jaxlib
|
|
|
|
|
+ name = "${cudaPackages.cudatoolkit.name}-merged";
|
|
|
|
|
+ paths = [
|
|
|
|
|
+ cudaPackages.cudatoolkit.lib
|
|
|
|
|
+ cudaPackages.cudatoolkit.out
|
|
|
|
|
+ ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
|
|
|
|
|
+ # for some reason some of the required libs are in the targets/x86_64-linux
|
|
|
|
|
+ # directory; not sure why but this works around it
|
|
|
|
|
+ "${cudaPackages.cudatoolkit}/targets/${system}"
|
|
|
|
|
+ ];
|
|
|
|
|
+ };
|
|
|
llama-python =
|
|
llama-python =
|
|
|
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
|
|
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
|
|
|
postPatch = ''
|
|
postPatch = ''
|
|
@@ -70,6 +84,13 @@
|
|
|
"-DLLAMA_CLBLAST=ON"
|
|
"-DLLAMA_CLBLAST=ON"
|
|
|
];
|
|
];
|
|
|
};
|
|
};
|
|
|
|
|
+ packages.cuda = pkgs.stdenv.mkDerivation {
|
|
|
|
|
+ inherit name src meta postPatch nativeBuildInputs postInstall;
|
|
|
|
|
+ buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
|
|
|
|
|
+ cmakeFlags = cmakeFlags ++ [
|
|
|
|
|
+ "-DLLAMA_CUBLAS=ON"
|
|
|
|
|
+ ];
|
|
|
|
|
+ };
|
|
|
packages.rocm = pkgs.stdenv.mkDerivation {
|
|
packages.rocm = pkgs.stdenv.mkDerivation {
|
|
|
inherit name src meta postPatch nativeBuildInputs postInstall;
|
|
inherit name src meta postPatch nativeBuildInputs postInstall;
|
|
|
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
|
|
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
|