Commit: 'Refactored by Sourcery'

SourceryAI committed Oct 30, 2023
1 parent 8ddbfe9 commit 7511de6
Showing 59 changed files with 595 additions and 579 deletions.
22 changes: 10 additions & 12 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -165,13 +165,12 @@ class DataTrainingArguments:
    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+        if self.train_file is not None:
+            extension = self.train_file.split(".")[-1]
+            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+        if self.validation_file is not None:
+            extension = self.validation_file.split(".")[-1]
+            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


class TrainState(train_state.TrainState):
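
The `__post_init__` change above (repeated for most files in this commit) drops an `else:` that was redundant because the `if` branch ends in `raise`. A standalone sketch, not part of the commit, using a made-up `check` helper to show the control flow is unchanged:

def check(dataset_name=None, train_file=None):
    if dataset_name is None and train_file is None:
        raise ValueError("Need either a dataset name or a training file.")
    # Only reachable when the raise above did not fire, so no `else:` is needed.
    if train_file is not None:
        assert train_file.split(".")[-1] in ["csv", "json", "txt"]

check(train_file="data.json")  # passes; check() with no arguments raises ValueError
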
@@ -198,9 +197,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf

    for idx in batch_idx:
        batch = dataset[idx]
-        batch = {k: np.array(v) for k, v in batch.items()}
-
-        yield batch
+        yield {k: np.array(v) for k, v in batch.items()}


def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -228,8 +225,9 @@ def create_learning_rate_fn(
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-    return schedule_fn
+    return optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
+    )


def main():
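
The `create_learning_rate_fn` hunk above now returns the joined schedule directly. A minimal sketch of what that schedule looks like, with made-up values (peak learning rate 3e-4, 100 warmup steps, 900 decay steps), not taken from the commit:

import optax

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=3e-4, transition_steps=100)
decay_fn = optax.linear_schedule(init_value=3e-4, end_value=0.0, transition_steps=900)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[100])

# Linear warmup to the peak at step 100, then linear decay back to 0 by step 1000.
print(schedule(0), schedule(100), schedule(1000))
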
25 changes: 12 additions & 13 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -163,13 +163,12 @@ class DataTrainingArguments:
    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+        if self.train_file is not None:
+            extension = self.train_file.split(".")[-1]
+            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+        if self.validation_file is not None:
+            extension = self.validation_file.split(".")[-1]
+            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


@flax.struct.dataclass
@@ -251,8 +250,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
-    batch_idx = np.split(samples_idx, sections_split)
-    return batch_idx
+    return np.split(samples_idx, sections_split)


def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -459,12 +457,13 @@ def group_texts(examples):
        # customize this part to your needs.
        if total_length >= max_seq_length:
            total_length = (total_length // max_seq_length) * max_seq_length
-        # Split by chunks of max_len.
-        result = {
-            k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
+        return {
+            k: [
+                t[i : i + max_seq_length]
+                for i in range(0, total_length, max_seq_length)
+            ]
            for k, t in concatenated_examples.items()
        }
-        return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
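
The `group_texts` hunk above only changes how the dict comprehension is laid out; the chunking itself is unchanged. A self-contained sketch with toy values (10 token ids, blocks of 4), not from the commit, of what that comprehension produces:

max_seq_length = 4
concatenated_examples = {"input_ids": list(range(10))}
total_length = (len(concatenated_examples["input_ids"]) // max_seq_length) * max_seq_length
result = {
    k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
    for k, t in concatenated_examples.items()
}
print(result)  # {'input_ids': [[0, 1, 2, 3], [4, 5, 6, 7]]}; the trailing 8, 9 are dropped
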
32 changes: 18 additions & 14 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -154,13 +154,12 @@ class DataTrainingArguments:
    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+        if self.train_file is not None:
+            extension = self.train_file.split(".")[-1]
+            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+        if self.validation_file is not None:
+            extension = self.validation_file.split(".")[-1]
+            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
@@ -254,7 +253,12 @@ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarra
        input_ids = batch["input_ids"]
        batch_size, expandend_input_length = input_ids.shape

-        mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
+        mask_indices = np.asarray(
+            [
+                self.random_spans_noise_mask(expandend_input_length)
+                for _ in range(batch_size)
+            ]
+        )
        labels_mask = ~mask_indices

        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
@@ -381,8 +385,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
-    batch_idx = np.split(samples_idx, sections_split)
-    return batch_idx
+    return np.split(samples_idx, sections_split)


def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -569,12 +572,13 @@ def group_texts(examples):
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
-        # Split by chunks of max_len.
-        result = {
-            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
+        return {
+            k: [
+                t[i : i + expanded_inputs_length]
+                for i in range(0, total_length, expanded_inputs_length)
+            ]
            for k, t in concatenated_examples.items()
        }
-        return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
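
`generate_batch_splits` now returns the `np.split` result directly. A sketch with toy sizes (10 samples, batches of 4), not from the commit, of the trim-then-split behaviour it keeps:

import numpy as np

samples_idx = np.arange(10)
batch_size = 4
num_samples = len(samples_idx)
samples_to_remove = num_samples % batch_size
if samples_to_remove != 0:
    samples_idx = samples_idx[:-samples_to_remove]
sections_split = num_samples // batch_size
print(np.split(samples_idx, sections_split))  # [array([0, 1, 2, 3]), array([4, 5, 6, 7])]
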
32 changes: 14 additions & 18 deletions examples/flax/question-answering/run_qa.py
@@ -205,16 +205,15 @@ def __post_init__(self):
            and self.test_file is None
        ):
            raise ValueError("Need either a dataset name or a training/validation file/test_file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
-            if self.test_file is not None:
-                extension = self.test_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
+        if self.train_file is not None:
+            extension = self.train_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+        if self.validation_file is not None:
+            extension = self.validation_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        if self.test_file is not None:
+            extension = self.test_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."


# endregion
@@ -292,8 +291,9 @@ def create_learning_rate_fn(
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-    return schedule_fn
+    return optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
+    )


# endregion
@@ -309,9 +309,7 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    for perm in perms:
        batch = dataset[perm]
        batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


# endregion
@@ -322,9 +320,7 @@ def eval_data_collator(dataset: Dataset, batch_size: int):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


# endregion
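
Both collators above now yield `shard(batch)` directly. `shard` (from flax.training.common_utils) reshapes every array in the batch so the leading axis becomes a device axis for `pmap`. A sketch with hypothetical shapes, not from the commit; it assumes the batch size is divisible by `jax.local_device_count()`:

import jax
import numpy as np
from flax.training.common_utils import shard

batch = {"input_ids": np.zeros((16, 128), dtype=np.int32)}
sharded = shard(batch)
# On a host with e.g. 8 local devices, (16, 128) becomes (8, 2, 128); with 1 device, (1, 16, 128).
print(jax.local_device_count(), sharded["input_ids"].shape)
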
4 changes: 3 additions & 1 deletion examples/flax/question-answering/utils_qa.py
@@ -164,7 +164,9 @@ def postprocess_qa_predictions(
        predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

        # Add back the minimum null prediction if it was removed because of its low score.
-        if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+        if version_2_with_negative and all(
+            p["offsets"] != (0, 0) for p in predictions
+        ):
            predictions.append(min_null_prediction)

# Use the offsets to gather the answer text in the original context.
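
The rewritten condition in utils_qa.py is a De Morgan transformation: not any(p["offsets"] == (0, 0) for p in predictions) holds exactly when all(p["offsets"] != (0, 0) for p in predictions) does (both are true for an empty list). A quick check on toy data, not from the repository:

predictions = [{"offsets": (3, 7)}, {"offsets": (8, 12)}]
old_form = not any(p["offsets"] == (0, 0) for p in predictions)
new_form = all(p["offsets"] != (0, 0) for p in predictions)
assert old_form == new_form  # holds for this data and for any list of predictions
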
22 changes: 10 additions & 12 deletions examples/flax/summarization/run_summarization_flax.py
@@ -208,13 +208,12 @@ class DataTrainingArguments:
    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        if self.train_file is not None:
+            extension = self.train_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+        if self.validation_file is not None:
+            extension = self.validation_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length

@@ -260,9 +259,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
@@ -288,8 +285,9 @@ def create_learning_rate_fn(
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-    return schedule_fn
+    return optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
+    )


def main():
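
The `data_loader` above takes an explicit `rng: jax.random.PRNGKey` argument; shuffling in these Flax examples is driven by that key rather than by global state, so a fixed key gives a reproducible batch order. A small sketch of that property (toy size, not from the commit):

import jax

rng = jax.random.PRNGKey(0)
print(jax.random.permutation(rng, 10))  # the same permutation of 0..9 on every run
print(jax.random.permutation(rng, 10))  # identical again, because the key is the same
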
27 changes: 11 additions & 16 deletions examples/flax/text-classification/run_flax_glue.py
@@ -137,16 +137,14 @@ def parse_args():
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    args = parser.parse_args()

    # Sanity checks
    if args.task_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a task name or a training/validation file.")
-    else:
-        if args.train_file is not None:
-            extension = args.train_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-        if args.validation_file is not None:
-            extension = args.validation_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+    if args.train_file is not None:
+        extension = args.train_file.split(".")[-1]
+        assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+    if args.validation_file is not None:
+        extension = args.validation_file.split(".")[-1]
+        assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."

    if args.push_to_hub:
        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
@@ -230,8 +228,9 @@ def create_learning_rate_fn(
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-    return schedule_fn
+    return optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
+    )


def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
@@ -244,19 +243,15 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


def glue_eval_data_collator(dataset: Dataset, batch_size: int):
    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


def main():
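
`glue_train_data_collator` above yields one sharded batch per permutation row. A sketch of the batching arithmetic it relies on, with toy numbers; the permutation step is paraphrased from the function's preamble, which is not shown in this hunk:

import jax

dataset_len, batch_size = 10, 4
steps_per_epoch = dataset_len // batch_size
rng = jax.random.PRNGKey(0)
perms = jax.random.permutation(rng, dataset_len)
perms = perms[: steps_per_epoch * batch_size]    # drop the remainder
perms = perms.reshape((steps_per_epoch, batch_size))
print(perms.shape)  # (2, 4): each row holds the dataset indices for one training batch
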
33 changes: 14 additions & 19 deletions examples/flax/token-classification/run_flax_ner.py
@@ -174,13 +174,12 @@ class DataTrainingArguments:
    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        if self.train_file is not None:
+            extension = self.train_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+        if self.validation_file is not None:
+            extension = self.validation_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        self.task_name = self.task_name.lower()


@@ -250,8 +249,9 @@ def create_learning_rate_fn(
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-    return schedule_fn
+    return optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
+    )


def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
@@ -264,19 +264,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    for perm in perms:
        batch = dataset[perm]
        batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


def eval_data_collator(dataset: Dataset, batch_size: int):
    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-
-        yield batch
+        yield shard(batch)


def main():
@@ -370,8 +366,7 @@ def get_label_list(labels):
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
-        label_list = list(unique_labels)
-        label_list.sort()
+        label_list = sorted(unique_labels)
        return label_list

if isinstance(features[label_column_name].feature, ClassLabel):
@@ -625,8 +620,8 @@ def compute_metrics():
                ):
                    labels = batch.pop("labels")
                    predictions = p_eval_step(state, batch)
-                    predictions = np.array([pred for pred in chain(*predictions)])
-                    labels = np.array([label for label in chain(*labels)])
+                    predictions = np.array(list(chain(*predictions)))
+                    labels = np.array(list(chain(*labels)))
                    labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
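
The compute_metrics hunk above replaces two list comprehensions over itertools.chain with list(chain(...)). A sketch with toy per-device arrays (not from the commit) of the flattening both versions perform:

import numpy as np
from itertools import chain

per_device_preds = [np.array([1, 2]), np.array([3, 4])]  # e.g. one array per device
flat_old = np.array([pred for pred in chain(*per_device_preds)])
flat_new = np.array(list(chain(*per_device_preds)))
print(flat_new)                                # [1 2 3 4]
assert (flat_old == flat_new).all()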