
flake.nix : add rocm support and cleanup (#2808)

Tungsten842, 2 years ago
commit 61d1a2895e
2 changed files with 29 additions and 26 deletions
  1. flake.lock  +6 -6
  2. flake.nix  +23 -20

flake.lock  +6 -6

@@ -5,11 +5,11 @@
         "systems": "systems"
       },
       "locked": {
-        "lastModified": 1685518550,
-        "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=",
+        "lastModified": 1692799911,
+        "narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=",
         "owner": "numtide",
         "repo": "flake-utils",
-        "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef",
+        "rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44",
         "type": "github"
       },
       "original": {
@@ -20,11 +20,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1685931219,
-        "narHash": "sha256-8EWeOZ6LKQfgAjB/USffUSELPRjw88A+xTcXnOUvO5M=",
+        "lastModified": 1692913444,
+        "narHash": "sha256-1SvMQm2DwofNxXVtNWWtIcTh7GctEVrS/Xel/mdc6iY=",
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "7409480d5c8584a1a83c422530419efe4afb0d19",
+        "rev": "18324978d632ffc55ef1d928e81630c620f4f447",
         "type": "github"
       },
       "original": {

flake.nix  +23 -20

@@ -6,6 +6,9 @@
   outputs = { self, nixpkgs, flake-utils }:
     flake-utils.lib.eachDefaultSystem (system:
       let
+        name = "llama.cpp";
+        src = ./.;
+        meta.mainProgram = "llama";
         inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin;
         buildInputs = with pkgs; [ openmpi ];
         osSpecific = with pkgs; buildInputs ++
@@ -31,7 +34,7 @@
             with pkgs; [ openblas ]
         );
         pkgs = import nixpkgs { inherit system; };
-        nativeBuildInputs = with pkgs; [ cmake pkgconfig ];
+        nativeBuildInputs = with pkgs; [ cmake ninja pkgconfig ];
         llama-python =
           pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
         postPatch = ''
@@ -44,35 +47,35 @@
           mv $out/bin/server $out/bin/llama-server
         '';
         cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ];
-      in {
+      in
+      {
         packages.default = pkgs.stdenv.mkDerivation {
-          name = "llama.cpp";
-          src = ./.;
-          postPatch = postPatch;
-          nativeBuildInputs = nativeBuildInputs;
-          buildInputs = osSpecific;
+          inherit name src meta postPatch nativeBuildInputs buildInputs postInstall;
           cmakeFlags = cmakeFlags
             ++ (if isAarch64 && isDarwin then [
-              "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
-              "-DLLAMA_METAL=ON"
-            ] else [
-              "-DLLAMA_BLAS=ON"
-              "-DLLAMA_BLAS_VENDOR=OpenBLAS"
+            "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
+            "-DLLAMA_METAL=ON"
+          ] else [
+            "-DLLAMA_BLAS=ON"
+            "-DLLAMA_BLAS_VENDOR=OpenBLAS"
           ]);
-          postInstall = postInstall;
-          meta.mainProgram = "llama";
         };
         packages.opencl = pkgs.stdenv.mkDerivation {
-          name = "llama.cpp";
-          src = ./.;
-          postPatch = postPatch;
-          nativeBuildInputs = nativeBuildInputs;
+          inherit name src meta postPatch nativeBuildInputs postInstall;
           buildInputs = with pkgs; buildInputs ++ [ clblast ];
           cmakeFlags = cmakeFlags ++ [
             "-DLLAMA_CLBLAST=ON"
           ];
-          postInstall = postInstall;
-          meta.mainProgram = "llama";
+        };
+        packages.rocm = pkgs.stdenv.mkDerivation {
+          inherit name src meta postPatch nativeBuildInputs postInstall;
+          buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
+          cmakeFlags = cmakeFlags ++ [
+            "-DLLAMA_HIPBLAS=1"
+            "-DCMAKE_C_COMPILER=hipcc"
+            "-DCMAKE_CXX_COMPILER=hipcc"
+            "-DCMAKE_POSITION_INDEPENDENT_CODE=ON"
+          ];
         };
         apps.llama-server = {
           type = "app";
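
After this change the shared derivation attributes (name, src, meta, postPatch, nativeBuildInputs, postInstall) are defined once in the let block and pulled into each package via inherit, and the new packages.rocm output builds the hipBLAS backend with hipcc. A minimal sketch of a downstream flake that consumes the new output; the input URL and the x86_64-linux system are assumptions for illustration, not part of this commit:

    {
      # hypothetical consumer flake; the input URL is assumed, adjust to your fork or pinned revision
      inputs.llama-cpp.url = "github:ggerganov/llama.cpp";

      outputs = { self, llama-cpp }: {
        # flake-utils.lib.eachDefaultSystem exposes outputs per system, so the
        # ROCm build added by this commit is reachable as packages.<system>.rocm
        packages.x86_64-linux.default = llama-cpp.packages.x86_64-linux.rocm;
      };
    }

From a checkout of the repository itself, the same output can be built directly with nix build .#rocm (flakes enabled), and nix run . starts the llama binary selected by meta.mainProgram.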