diff --git a/conda/build_pytorch.sh b/conda/build_pytorch.sh
index a0bb03d8a..56bb7654c 100755
--- a/conda/build_pytorch.sh
+++ b/conda/build_pytorch.sh
@@ -268,7 +268,10 @@ else
     . ./switch_cuda_version.sh "$desired_cuda"
     # TODO, simplify after anaconda fixes their cudatoolkit versioning inconsistency.
     # see: https://github.com/conda-forge/conda-forge.github.io/issues/687#issuecomment-460086164
-    if [[ "$desired_cuda" == "12.1" ]]; then
+    if [[ "$desired_cuda" == "12.4" ]]; then
+        export CONDA_CUDATOOLKIT_CONSTRAINT="    - pytorch-cuda >=12.4,<12.5 # [not osx]"
+        export MAGMA_PACKAGE="    - magma-cuda124 # [not osx and not win]"
+    elif [[ "$desired_cuda" == "12.1" ]]; then
         export CONDA_CUDATOOLKIT_CONSTRAINT="    - pytorch-cuda >=12.1,<12.2 # [not osx]"
         export MAGMA_PACKAGE="    - magma-cuda121 # [not osx and not win]"
     elif [[ "$desired_cuda" == "11.8" ]]; then
diff --git a/conda/pytorch-nightly/build.sh b/conda/pytorch-nightly/build.sh
index db2b7b246..97fa682aa 100755
--- a/conda/pytorch-nightly/build.sh
+++ b/conda/pytorch-nightly/build.sh
@@ -60,10 +60,11 @@ if [[ -n "$build_with_cuda" ]]; then
         TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;3.7+PTX;9.0"
         #for cuda 11.8 include all dynamic loading libraries
         DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8 /usr/local/cuda/lib64/libcusparseLt.so.0)
-    elif [[ $CUDA_VERSION == 12.1* ]]; then
+    elif [[ $CUDA_VERSION == 12.1* || $CUDA_VERSION == 12.4* ]]; then
         # cuda 12 does not support sm_3x
         TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;9.0"
         # for cuda 12.1 we use cudnn 8.8 and include all dynamic loading libraries
+        # for cuda 12.4 we use cudnn 8.9
         DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-12.1/extras/CUPTI/lib64/libcupti.so.12 /usr/local/cuda/lib64/libcusparseLt.so.0)
     fi
     if [[ -n "$OVERRIDE_TORCH_CUDA_ARCH_LIST" ]]; then
diff --git a/manywheel/build_cuda.sh b/manywheel/build_cuda.sh
index 5356a2ffa..4fc1ed278 100644
--- a/manywheel/build_cuda.sh
+++ b/manywheel/build_cuda.sh
@@ -59,6 +59,10 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
 
 TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6"
 case ${CUDA_VERSION} in
+    12.4)
+        TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
+        EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
+        ;;
     12.1)
         TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
         EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
@@ -131,7 +135,7 @@ if [[ $USE_CUSPARSELT == "1" ]]; then
     )
 fi
 
-if [[ $CUDA_VERSION == "12.1" ]]; then
+if [[ $CUDA_VERSION == "12.1" || $CUDA_VERSION == "12.4" ]]; then
     export USE_STATIC_CUDNN=0
     # Try parallelizing nvcc as well
     export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2"