Skip to content

Commit

Permalink
Torchao version check changes/BC import of TensorCoreTiledLayout (#1812)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers authored Oct 12, 2024
1 parent 50d3ef1 commit 7744608
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 60 deletions.
57 changes: 0 additions & 57 deletions torchtune/modules/low_precision/_utils.py

This file was deleted.

10 changes: 8 additions & 2 deletions torchtune/training/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

from typing import Callable, Optional

from torchao.dtypes import TensorCoreTiledLayoutType
from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API

if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API:
from torchao.dtypes import TensorCoreTiledLayout
else:
from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout

from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int4_weight,
Expand Down Expand Up @@ -88,7 +94,7 @@ def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
self.inner_k_tiles = inner_k_tiles

def quantize(self, model):
layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
layout_type = TensorCoreTiledLayout(self.inner_k_tiles)
quantize_fn = int4_weight_only(self.groupsize, layout_type)
quantize_(model, quantize_fn)
return model
Expand Down
10 changes: 9 additions & 1 deletion torchtune/utils/_import_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtune.utils._version import torch_version_ge
import torchao
from torchtune.utils._version import _is_fbcode, _nightly_version_ge, torch_version_ge

# We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
_SUPPORTS_FLEX_ATTENTION = (
torch_version_ge("2.5.0")
and torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (7, 5)
)

torchao_version = torchao.__version__

_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
("dev" not in torchao_version and torchao_version >= "0.6.0")
or ("dev" in torchao_version and _nightly_version_ge(torchao_version, "2024-10-10"))
)
21 changes: 21 additions & 0 deletions torchtune/utils/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from datetime import datetime

import torch


Expand All @@ -23,3 +26,21 @@ def torch_version_ge(version: str) -> bool:
True
"""
return version in torch.__version__ or torch.__version__ >= version


def _is_fbcode():
return not hasattr(torch.version, "git_version")


def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
"""
Compare a torchao nightly version to a date of the form
%Y-%m-%d.
Returns True if the nightly version is greater than or equal to
the date, False otherwise
"""
ao_datetime = datetime.strptime(
ao_version_str.split("+")[0].split("dev")[1], "%Y%m%d"
)
return ao_datetime >= datetime.strptime(date, "%Y-%m-%d")

0 comments on commit 7744608

Please sign in to comment.