diff options
author | Erik Scholz <Green-Sky@users.noreply.github.com> | 2023-09-25 13:48:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-25 13:48:30 +0200 |
commit | a98b1633d5a94d0aa84c7c16e1f8df5ac21fc850 (patch) | |
tree | 9b6f047c98687e79c03ca6c494e67e6d0af25f04 | |
parent | c091cdfb24621710c617ea85c92fcd347d0bf340 (diff) |
nix : add cuda, use a symlinked toolkit for cmake (#3202)
-rw-r--r-- | flake.nix | 21 |
1 files changed, 21 insertions, 0 deletions
@@ -35,6 +35,20 @@ ); pkgs = import nixpkgs { inherit system; }; nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ]; + cudatoolkit_joined = with pkgs; symlinkJoin { + # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit + # see https://github.com/NixOS/nixpkgs/issues/224291 + # copied from jaxlib + name = "${cudaPackages.cudatoolkit.name}-merged"; + paths = [ + cudaPackages.cudatoolkit.lib + cudaPackages.cudatoolkit.out + ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [ + # for some reason some of the required libs are in the targets/x86_64-linux + # directory; not sure why but this works around it + "${cudaPackages.cudatoolkit}/targets/${system}" + ]; + }; llama-python = pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); postPatch = '' @@ -70,6 +84,13 @@ "-DLLAMA_CLBLAST=ON" ]; }; + packages.cuda = pkgs.stdenv.mkDerivation { + inherit name src meta postPatch nativeBuildInputs postInstall; + buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ]; + cmakeFlags = cmakeFlags ++ [ + "-DLLAMA_CUBLAS=ON" + ]; + }; packages.rocm = pkgs.stdenv.mkDerivation { inherit name src meta postPatch nativeBuildInputs postInstall; buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ]; |