diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py
index aab02b8ea..64083652d 100644
--- a/sentence_transformers/models/Pooling.py
+++ b/sentence_transformers/models/Pooling.py
@@ -139,7 +139,9 @@ def forward(self, features: Dict[str, Tensor]):
             # attention_mask shape: (bs, seq_len)
             # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
             # argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1
-            gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1  # Shape [bs]
+            # For sequences with no padding (min == 1), argmin returns 0, so fall back to the full seq_len instead
+            values, indices = torch.min(attention_mask, 1, keepdim=False)
+            gather_indices = torch.where(values == 0, indices, seq_len) - 1  # Shape [bs]
 
             # There are empty sequences, where the index would become -1 which will crash
             gather_indices = torch.clamp(gather_indices, min=0)
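
For context, a minimal standalone sketch of why the `argmin`-only version mis-handles fully-attended sequences and how the `min`/`where` combination fixes it. The mask values below are illustrative, not taken from the PR:

```python
import torch

# Three attention masks, seq_len = 4:
#   row 0: padded after 2 tokens; row 1: empty sequence; row 2: no padding at all
attention_mask = torch.tensor([
    [1, 1, 0, 0],
    [0, 0, 0, 0],
    [1, 1, 1, 1],
])
seq_len = attention_mask.shape[1]

# Old approach: argmin returns 0 when the mask contains no 0 (row 2),
# so subtracting 1 yields -1 and the wrong token is gathered.
old_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1
print(old_indices)  # tensor([ 1, -1, -1])  <- row 2 should be 3

# Fixed approach: when min == 1 (no padding), fall back to seq_len.
values, indices = torch.min(attention_mask, 1, keepdim=False)
gather_indices = torch.where(values == 0, indices, seq_len) - 1
gather_indices = torch.clamp(gather_indices, min=0)  # guard truly empty sequences
print(gather_indices)  # tensor([1, 0, 3])
```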