Fix LED documentation #17181

Merged · 4 commits · May 11, 2022

Changes from 3 commits
17 changes: 8 additions & 9 deletions src/transformers/models/led/configuration_led.py
@@ -86,18 +86,17 @@ class LEDConfig(PretrainedConfig):
Example:

```python
>>> from transformers import LEDModel, LEDConfig

-    ```
-    >>> # Initializing a LED allenai/led-base-16384 style configuration >>> configuration = LEDConfig()
-
-    >>> # Initializing a model from the allenai/led-base-16384 style configuration >>> model =
-    LEDModel(configuration)
-
-    >>> # Accessing the model configuration >>> configuration = model.config
-    """
+    >>> # Initializing a LED allenai/led-base-16384 style configuration
+    >>> configuration = LEDConfig()
+
+    >>> # Initializing a model from the allenai/led-base-16384 style configuration
+    >>> model = LEDModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
model_type = "led"
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
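For reference, the example that the corrected docstring describes runs as plain Python once the stray fence and merged `>>>` prompts are gone. A minimal sketch of that usage (it builds a randomly initialized model from the default `allenai/led-base-16384` style configuration; nothing is downloaded):

```python
from transformers import LEDConfig, LEDModel

# Build a LED configuration with the default (allenai/led-base-16384 style) values
configuration = LEDConfig()

# Instantiate a randomly initialized LEDModel from that configuration
model = LEDModel(configuration)

# The configuration can be read back from the model
configuration = model.config
print(configuration.model_type)  # "led"
```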
12 changes: 6 additions & 6 deletions src/transformers/models/led/modeling_led.py
@@ -1007,7 +1007,7 @@ def forward(
"""
residual = hidden_states

-        # Self Attention
+        # Self-Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
@@ -1437,12 +1437,12 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):


LED_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    This model inherits from [`PreTrainedModel`]. See the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for general usage
and behavior.

Parameters:
@@ -1595,7 +1595,7 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):

class LEDEncoder(LEDPreTrainedModel):
"""
-    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
[`LEDEncoderLayer`].

Args:
@@ -1643,7 +1643,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None):
self.post_init()

def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
-        # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
+        # longformer self-attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
if attention_mask is not None:
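The comment touched above in `_merge_to_attention_mask` describes the Longformer-style mask convention (0 = no attention, 1 = local attention, 2 = global attention) and how the global attention mask is folded into the regular attention mask. A minimal toy sketch of that merge, using hypothetical tensors rather than code from this PR:

```python
import torch

# 1 = token is attended to, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0]])
# 1 = token receives global attention, 0 = local attention only
global_attention_mask = torch.tensor([[1, 0, 0, 0]])

# (global_attention_mask + 1) maps 0 -> 1 (local) and 1 -> 2 (global);
# multiplying by attention_mask keeps padded positions at 0.
merged_mask = attention_mask * (global_attention_mask + 1)
print(merged_mask)  # tensor([[2, 1, 1, 0]])
```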
4 changes: 2 additions & 2 deletions src/transformers/models/led/modeling_tf_led.py
@@ -1238,7 +1238,7 @@ def call(
"""
residual = hidden_states

-        # Self Attention
+        # Self-Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
@@ -1612,7 +1612,7 @@ class TFLEDSeq2SeqLMOutput(ModelOutput):
class TFLEDEncoder(tf.keras.layers.Layer):
config_class = LEDConfig
"""
-    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
[`TFLEDEncoderLayer`].

Args: