summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorErik Scholz <Green-Sky@users.noreply.github.com>2023-09-25 13:48:30 +0200
committerGitHub <noreply@github.com>2023-09-25 13:48:30 +0200
commita98b1633d5a94d0aa84c7c16e1f8df5ac21fc850 (patch)
tree9b6f047c98687e79c03ca6c494e67e6d0af25f04
parentc091cdfb24621710c617ea85c92fcd347d0bf340 (diff)
nix : add cuda, use a symlinked toolkit for cmake (#3202)
-rw-r--r--flake.nix21
1 files changed, 21 insertions, 0 deletions
diff --git a/flake.nix b/flake.nix
index 7723357a..433d3d94 100644
--- a/flake.nix
+++ b/flake.nix
@@ -35,6 +35,20 @@
);
pkgs = import nixpkgs { inherit system; };
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
+ cudatoolkit_joined = with pkgs; symlinkJoin {
+ # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
+ # see https://github.com/NixOS/nixpkgs/issues/224291
+ # copied from jaxlib
+ name = "${cudaPackages.cudatoolkit.name}-merged";
+ paths = [
+ cudaPackages.cudatoolkit.lib
+ cudaPackages.cudatoolkit.out
+ ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
+ # for some reason some of the required libs are in the targets/x86_64-linux
+ # directory; not sure why but this works around it
+ "${cudaPackages.cudatoolkit}/targets/${system}"
+ ];
+ };
llama-python =
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
postPatch = ''
@@ -70,6 +84,13 @@
"-DLLAMA_CLBLAST=ON"
];
};
+ packages.cuda = pkgs.stdenv.mkDerivation {
+ inherit name src meta postPatch nativeBuildInputs postInstall;
+ buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
+ cmakeFlags = cmakeFlags ++ [
+ "-DLLAMA_CUBLAS=ON"
+ ];
+ };
packages.rocm = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];