Fix duplicate & unnecessary flash attention warnings (#28557)
* fix duplicate & unnecessary flash warnings

* trigger ci

* warning_once

* if/else order

---------

Co-authored-by: Your Name <[email protected]>
fxmarty and Your Name authored Jan 26, 2024
1 parent 142ce68 commit 8eb74c1
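
Note on the fix: `logger.warning_once` comes from `transformers.utils.logging`, which caches each message so that repeated calls emit the warning only once per process. A minimal sketch of that behaviour, assuming only that `transformers` is installed (the logger name below is arbitrary):

# Sketch: how transformers' warning_once deduplicates repeated warnings.
# The logger name "flash_attn_warning_demo" is illustrative, not from the library.
from transformers.utils import logging

logger = logging.get_logger("flash_attn_warning_demo")

for _ in range(3):
    # Plain `warning` would print three times; `warning_once` caches the
    # message and prints it a single time per process.
    logger.warning_once(
        "You are attempting to use Flash Attention 2.0 without specifying a torch dtype."
    )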
Showing 1 changed file with 12 additions and 7 deletions.
src/transformers/modeling_utils.py
@@ -1321,7 +1321,10 @@ def _from_config(cls, config, **kwargs):
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
         config._attn_implementation = kwargs.pop("attn_implementation", None)
         config = cls._autoset_attn_implementation(
-            config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
+            config,
+            use_flash_attention_2=use_flash_attention_2,
+            check_device_map=False,
+            torch_dtype=torch_dtype,
         )

         if is_deepspeed_zero3_enabled():
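
Forwarding `torch_dtype=torch_dtype` here lets the dtype check in `_check_and_enable_flash_attn_2` see the real dtype when a model is built from a config, so no spurious "no dtype" warning fires. A hedged sketch of the public call path this hunk affects; it assumes a CUDA machine with the flash-attn package installed, and "openai/whisper-tiny" is only an example checkpoint:

# Sketch: building a model from a config with an explicit dtype.
# With this commit, the dtype reaches the Flash Attention 2 check, so no
# "without specifying a torch dtype" warning is emitted.
import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("openai/whisper-tiny")
model = AutoModel.from_config(
    config,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)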
@@ -1396,7 +1399,8 @@ def _autoset_attn_implementation(
         elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
-                config, hard_check_only=False if requested_attn_implementation is None else True
+                config,
+                hard_check_only=False if requested_attn_implementation is None else True,
             )
         else:
             config._attn_implementation = "eager"
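
For readers skimming the ternary in this hunk: `hard_check_only` is True exactly when the user explicitly asked for "sdpa", so an unsupported architecture fails loudly, while the automatic default (no requested implementation) falls back quietly. A restatement of that condition, illustrative only and not the library's code:

# Equivalent restatement of the ternary above.
requested_attn_implementation = None  # None when nothing was requested, "sdpa" when asked for explicitly

# `False if requested_attn_implementation is None else True` simplifies to:
hard_check_only = requested_attn_implementation is not None
# True  -> raise if the model class does not support SDPA
# False -> quietly fall back to the default (eager) implementation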
@@ -1503,20 +1507,21 @@ def _check_and_enable_flash_attn_2(
             )

         if torch_dtype is None:
-            logger.warning(
+            logger.warning_once(
                 "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
             )
         elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
-            logger.warning(
-                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. "
-                "No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator."
+            logger.warning_once(
+                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
             )

         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
         # or the model may be initialized under the context manager `with torch.device("cuda"):`.
         if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
             if torch.cuda.is_available():
-                logger.warning(
+                logger.warning_once(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                     " after initializing it on CPU with `model.to('cuda')`."
                 )
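
The reworded warning now spells out both remedies. A short sketch of the two options it mentions, assuming a CUDA machine with the flash-attn package installed; "openai/whisper-tiny" is the checkpoint the warning string itself uses as an example:

# Sketch of the two remedies suggested by the new warning text.
import torch
from transformers import AutoModel

# Option 1: load directly in a supported dtype, as the warning's example shows.
model = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")

# Option 2: keep the default dtype (this still triggers the warning, but only
# once now) and wrap training or inference in autocast; the warning's
# `device_type='torch_device'` is a placeholder for e.g. "cuda".
model_amp = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="flash_attention_2",
).to("cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    ...  # run the forward pass here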
