Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit the use of PreTrainedModel.device #16935

Merged
merged 2 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def _prepare_attention_mask_for_generation(
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()
else:
return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
Expand Down Expand Up @@ -532,13 +532,16 @@ def _prepare_decoder_input_ids_for_generation(
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
device: torch.device = None,
) -> torch.LongTensor:

if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
else:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
if device is None:
device = self.device
Comment on lines +542 to +543
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default to self.device here for a 100% backward compatible change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great thanks!

return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
Expand Down Expand Up @@ -1177,6 +1180,7 @@ def generate(
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
Expand Down Expand Up @@ -1327,7 +1331,7 @@ def generate(
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
Expand Down Expand Up @@ -1367,7 +1371,7 @@ def generate(
beam_scorer = BeamSearchScorer(
batch_size=batch_size * num_return_sequences,
num_beams=num_beams,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
)
Expand Down Expand Up @@ -1410,7 +1414,7 @@ def generate(
batch_size=batch_size,
num_beams=num_beams,
max_length=stopping_criteria.max_length,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
Expand Down Expand Up @@ -1492,7 +1496,7 @@ def typeerror():
constraints=final_constraints,
batch_size=batch_size,
num_beams=num_beams,
device=self.device,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,7 @@ def _get_resized_embeddings(

# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def _get_resized_lm_head(
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
has_new_lm_head_bias = old_lm_head.bias is not None
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)

# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
Expand Down