Skip to content

Commit

Permalink
Enable CUDA 12.4 builds
Browse files Browse the repository at this point in the history
GHA results show this is needed to fix errors in pytorch/pytorch#121684
Reference: pytorch#1374
  • Loading branch information
nWEIdia committed Apr 11, 2024
1 parent 87cdc8c commit bf63b1a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
5 changes: 4 additions & 1 deletion conda/build_pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ else
. ./switch_cuda_version.sh "$desired_cuda"
# TODO, simplify after anaconda fixes their cudatoolkit versioning inconsistency.
# see: https:/conda-forge/conda-forge.github.io/issues/687#issuecomment-460086164
if [[ "$desired_cuda" == "12.1" ]]; then
if [[ "$desired_cuda" == "12.4" ]]; then
export CONDA_CUDATOOLKIT_CONSTRAINT=" - pytorch-cuda >=12.4,<12.5 # [not osx]"
export MAGMA_PACKAGE=" - magma-cuda124 # [not osx and not win]"
elif [[ "$desired_cuda" == "12.1" ]]; then
export CONDA_CUDATOOLKIT_CONSTRAINT=" - pytorch-cuda >=12.1,<12.2 # [not osx]"
export MAGMA_PACKAGE=" - magma-cuda121 # [not osx and not win]"
elif [[ "$desired_cuda" == "11.8" ]]; then
Expand Down
3 changes: 2 additions & 1 deletion conda/pytorch-nightly/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ if [[ -n "$build_with_cuda" ]]; then
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;3.7+PTX;9.0"
#for cuda 11.8 include all dynamic loading libraries
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8 /usr/local/cuda/lib64/libcusparseLt.so.0)
elif [[ $CUDA_VERSION == 12.1* ]]; then
elif [[ $CUDA_VERSION == 12.1* || $CUDA_VERSION == 12.4* ]]; then
# cuda 12 does not support sm_3x
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;9.0"
# for cuda 12.1 we use cudnn 8.8 and include all dynamic loading libraries
# for cuda 12.4 we use cudnn 8.9
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-12.1/extras/CUPTI/lib64/libcupti.so.12 /usr/local/cuda/lib64/libcusparseLt.so.0)
fi
if [[ -n "$OVERRIDE_TORCH_CUDA_ARCH_LIST" ]]; then
Expand Down
6 changes: 5 additions & 1 deletion manywheel/build_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')

TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6"
case ${CUDA_VERSION} in
12.4)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.1)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
Expand Down Expand Up @@ -131,7 +135,7 @@ if [[ $USE_CUSPARSELT == "1" ]]; then
)
fi

if [[ $CUDA_VERSION == "12.1" ]]; then
if [[ $CUDA_VERSION == "12.1" || $CUDA_VERSION == "12.4" ]]; then
export USE_STATIC_CUDNN=0
# Try parallelizing nvcc as well
export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2"
Expand Down

0 comments on commit bf63b1a

Please sign in to comment.