diff --git a/nlpaug/model/lang_models/bert.py b/nlpaug/model/lang_models/bert.py index b7911cb..0346201 100755 --- a/nlpaug/model/lang_models/bert.py +++ b/nlpaug/model/lang_models/bert.py @@ -106,7 +106,7 @@ def predict(self, text, target_word=None, n=1): # Prediction with torch.no_grad(): - outputs = self.model(token_inputs, segment_inputs, mask_inputs) + outputs = self.model(input_ids=token_inputs, token_type_ids=segment_inputs, attention_mask=mask_inputs) target_token_logits = outputs[0][0][target_pos] # Selection