Commit
Merge pull request #4940 from neutrinoceros/sty/strict_zip
neutrinoceros authored Jul 27, 2024
2 parents d25b54c + 91de9ba commit bc63757
Showing 62 changed files with 391 additions and 129 deletions.
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
@@ -42,7 +42,12 @@ repos:
types_or: [ python, pyi, jupyter ]
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [--fix, "--show-fixes"]
args: [
--fix,
--show-fixes,
# the following line can be removed after support for Python 3.9 is dropped
--extend-select=B905, # zip-without-explicit-strict
]
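
For context: B905 (zip-without-explicit-strict) is the flake8-bugbear rule, implemented by ruff, that flags calls to the builtin zip() lacking an explicit strict keyword, which Python 3.10 added via PEP 618. A minimal sketch of the two behaviors this PR makes explicit (requires Python >= 3.10):

left = [1, 2, 3]
right = ["a", "b"]

# strict=False reproduces the historical behavior: iteration stops
# silently at the end of the shorter input.
print(list(zip(left, right, strict=False)))  # [(1, 'a'), (2, 'b')]

# strict=True turns a length mismatch into a loud error.
try:
    list(zip(left, right, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1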

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
1 change: 1 addition & 0 deletions pyproject.toml
@@ -304,6 +304,7 @@ ignore = [
"E501", # line too long
"E741", # Do not use variables named 'I', 'O', or 'l'
"B018", # Found useless expression. # disabled because ds.index is idiomatic
"UP038", # non-pep604-isinstance
]
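
UP038 (non-pep604-isinstance) is the pyupgrade-derived ruff rule that would rewrite isinstance/issubclass calls taking a tuple of types into PEP 604 unions; yt opts out here and keeps the tuple form. Roughly, the rewrite being declined (the union form needs Python >= 3.10):

value = 1.5

# Tuple form, which yt keeps by ignoring UP038:
print(isinstance(value, (int, float)))  # True

# PEP 604 union form that UP038 would suggest instead:
print(isinstance(value, int | float))   # True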

[tool.ruff.lint.per-file-ignores]
6 changes: 5 additions & 1 deletion yt/data_objects/analyzer_objects.py
@@ -1,7 +1,11 @@
import inspect
import sys

from yt.utilities.object_registries import analysis_task_registry

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


class AnalysisTask:
def __init_subclass__(cls, *args, **kwargs):
@@ -14,7 +18,7 @@ def __init__(self, *args, **kwargs):
# does not override
if len(args) + len(kwargs) != len(self._params):
raise RuntimeError
self.__dict__.update(zip(self._params, args))
self.__dict__.update(zip(self._params, args, strict=False))
self.__dict__.update(kwargs)

def __repr__(self):
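
The yt._maintenance.backports module itself is not part of this diff, so its contents are not shown above. For illustration only, a hypothetical zip backport compatible with the usage in this PR could look like the following sketch (an assumption, not necessarily yt's actual implementation):

import builtins
from itertools import zip_longest

_unfilled = object()

def zip(*iterables, strict=False):
    # Hypothetical Python < 3.10 shim for PEP 618's `strict` keyword.
    if not strict:
        yield from builtins.zip(*iterables)
        return
    for combo in zip_longest(*iterables, fillvalue=_unfilled):
        if any(item is _unfilled for item in combo):
            raise ValueError("zip() arguments have different lengths")
        yield combo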
19 changes: 14 additions & 5 deletions yt/data_objects/construction_data_containers.py
@@ -1,6 +1,7 @@
import fileinput
import io
import os
import sys
import warnings
import zipfile
from functools import partial, wraps
@@ -63,6 +64,9 @@
)
from yt.visualization.color_maps import get_colormap_lut

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


class YTStreamline(YTSelectionContainer1D):
"""
@@ -150,7 +154,9 @@ def _get_cut_mask(self, grid):
mask = np.zeros(points_in_grid.sum(), dtype="int64")
dts = np.zeros(points_in_grid.sum(), dtype="float64")
ts = np.zeros(points_in_grid.sum(), dtype="float64")
for mi, (i, pos) in enumerate(zip(pids, self.positions[points_in_grid])):
for mi, (i, pos) in enumerate(
zip(pids, self.positions[points_in_grid], strict=True)
):
if not points_in_grid[i]:
continue
ci = ((pos - grid.LeftEdge) / grid.dds).astype("int64")
@@ -784,7 +790,10 @@ def to_xarray(self, fields=None):
def icoords(self):
ic = np.indices(self.ActiveDimensions).astype("int64")
return np.column_stack(
[i.ravel() + gi for i, gi in zip(ic, self.get_global_startindex())]
[
i.ravel() + gi
for i, gi in zip(ic, self.get_global_startindex(), strict=True)
]
)

@property
@@ -1122,7 +1131,7 @@ def _fill_fields(self, fields):
if self.comm.size > 1:
for i in range(len(fields)):
output_fields[i] = self.comm.mpi_allreduce(output_fields[i], op="sum")
for field, v in zip(fields, output_fields):
for field, v in zip(fields, output_fields, strict=True):
fi = self.ds._get_field_info(field)
self[field] = self.ds.arr(v, fi.units)

@@ -1221,7 +1230,7 @@ def write_to_gdf(self, gdf_path, fields, nprocs=1, field_units=None, **kwargs):
data[field] = (self[field].in_units(units).v, units)
le = self.left_edge.v
re = self.right_edge.v
bbox = np.array([[l, r] for l, r in zip(le, re)])
bbox = np.array([[l, r] for l, r in zip(le, re, strict=True)])
ds = load_uniform_grid(
data,
self.ActiveDimensions,
@@ -1526,7 +1535,7 @@ def _fill_fields(self, fields):
stacklevel=1,
)
mylog.debug("Caught %d runtime errors.", runtime_errors_count)
for field, v in zip(fields, ls.fields):
for field, v in zip(fields, ls.fields, strict=True):
if self.level > 0:
v = v[1:-1, 1:-1, 1:-1]
fi = self.ds._get_field_info(field)
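
Most call sites in this file pair a list of field keys with a same-length list of data arrays, so strict=True costs nothing and acts as a built-in consistency check. A reduced, self-contained illustration (field names and array contents are invented):

import numpy as np

fields = [("gas", "density"), ("gas", "temperature")]
output_fields = [np.zeros(8), np.ones(8)]

# The lists are constructed in lockstep; strict=True would now surface
# any bug that let them drift out of sync.
for field, values in zip(fields, output_fields, strict=True):
    print(field, values.mean())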
8 changes: 6 additions & 2 deletions yt/data_objects/data_containers.py
@@ -1,4 +1,5 @@
import abc
import sys
import weakref
from collections import defaultdict
from contextlib import contextmanager
@@ -28,6 +29,9 @@
from yt.utilities.on_demand_imports import _firefly as firefly
from yt.utilities.parameter_file_storage import ParameterFileStore

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip

if TYPE_CHECKING:
from yt.data_objects.static_output import Dataset

@@ -797,7 +801,7 @@ def create_firefly_object(
# tuples containing some sort of special "any" ParticleGroup
unambiguous_fields_to_include = []
unambiguous_fields_units = []
for field, field_unit in zip(fields_to_include, fields_units):
for field, field_unit in zip(fields_to_include, fields_units, strict=True):
if isinstance(field, tuple):
# skip tuples, they'll be checked with _determine_fields
unambiguous_fields_to_include.append(field)
@@ -849,7 +853,7 @@ def create_firefly_object(
field_names = []

## explicitly go after the fields we want
for field, units in zip(fields_to_include, fields_units):
for field, units in zip(fields_to_include, fields_units, strict=True):
## Only interested in fields with the current particle type,
## whether that means general fields or field tuples
ftype, fname = field
11 changes: 8 additions & 3 deletions yt/data_objects/derived_quantities.py
@@ -1,3 +1,5 @@
import sys

import numpy as np

from yt.funcs import camelcase_to_underscore, iter_fields
@@ -11,6 +13,9 @@
from yt.utilities.physical_constants import gravitational_constant_cgs
from yt.utilities.physical_ratios import HUGE

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


def get_position_fields(field, data):
axis_names = [data.ds.coordinates.axis_name[num] for num in [0, 1, 2]]
@@ -415,7 +420,7 @@ def __call__(self, fields, weight):
fields = list(iter_fields(fields))
units = [self.data_source.ds._get_field_info(field).units for field in fields]
rv = super().__call__(fields, weight)
rv = [self.data_source.ds.arr(v, u) for v, u in zip(rv, units)]
rv = [self.data_source.ds.arr(v, u) for v, u in zip(rv, units, strict=True)]
if len(rv) == 1:
rv = rv[0]
return rv
@@ -431,7 +436,7 @@ def process_chunk(self, data, fields, weight):
my_var2s = [
(data[weight].d * (data[field].d - my_mean) ** 2).sum(dtype=np.float64)
/ my_weight
for field, my_mean in zip(fields, my_means)
for field, my_mean in zip(fields, my_means, strict=True)
]
return my_means + my_var2s + [my_weight]

@@ -623,7 +628,7 @@ def reduce_intermediate(self, values):
# The values get turned into arrays here.
return [
self.data_source.ds.arr([mis.min(), mas.max()])
for mis, mas in zip(values[::2], values[1::2])
for mis, mas in zip(values[::2], values[1::2], strict=True)
]


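The reduce_intermediate change is a natural strict=True fit: values is a flat [min0, max0, min1, max1, ...] sequence, and the two stride-2 slices have equal length exactly when that flat list has even length. A standalone sketch of the slicing, with scalars standing in for the arrays used in the real code:

values = [0.0, 9.0, -1.0, 4.0]  # flattened [min0, max0, min1, max1]
pairs = list(zip(values[::2], values[1::2], strict=True))
print(pairs)  # [(0.0, 9.0), (-1.0, 4.0)]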
8 changes: 6 additions & 2 deletions yt/data_objects/level_sets/tests/test_clump_finding.py
@@ -1,5 +1,6 @@
import os
import shutil
import sys
import tempfile

import numpy as np
@@ -12,6 +13,9 @@
from yt.testing import requires_file, requires_module
from yt.utilities.answer_testing.framework import data_dir_load

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


def test_clump_finding():
n_c = 8
@@ -132,7 +136,7 @@ def test_clump_tree_save():
it2 = np.argsort(mt2).astype("int64")
assert_array_equal(mt1[it1], mt2[it2])

for i1, i2 in zip(it1, it2):
for i1, i2 in zip(it1, it2, strict=True):
ct1 = t1[i1]
ct2 = t2[i2]
assert_array_equal(ct1["gas", "density"], ct2["grid", "density"])
@@ -191,5 +195,5 @@ def _also_density(field, data):
leaf_clumps_1 = master_clump_1.leaves
leaf_clumps_2 = master_clump_2.leaves

for c1, c2 in zip(leaf_clumps_1, leaf_clumps_2):
for c1, c2 in zip(leaf_clumps_1, leaf_clumps_2, strict=True):
assert_array_equal(c1["gas", "density"], c2["gas", "density"])
19 changes: 14 additions & 5 deletions yt/data_objects/profiles.py
@@ -1,3 +1,5 @@
import sys

import numpy as np
from more_itertools import collapse

@@ -23,6 +25,9 @@
parallel_objects,
)

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


def _sanitize_min_max_units(amin, amax, finfo, registry):
# returns a copy of amin and amax, converted to finfo's output units
@@ -217,7 +222,7 @@ def _filter(self, bin_fields):
# cut_points is set to be everything initially, but
# we also want to apply a filtering based on min/max
pfilter = np.ones(bin_fields[0].shape, dtype="bool")
for (mi, ma), data in zip(self.bounds, bin_fields):
for (mi, ma), data in zip(self.bounds, bin_fields, strict=True):
pfilter &= data > mi
pfilter &= data < ma
return pfilter, [data[pfilter] for data in bin_fields]
@@ -1313,7 +1318,11 @@ def create_profile(
data_source.ds.field_info[f].sampling_type == "local"
for f in bin_fields + fields
]
is_local_or_pfield = [pf or lf for (pf, lf) in zip(is_pfield, is_local)]
if wf is not None:
is_local.append(wf.sampling_type == "local")
is_local_or_pfield = [
pf or lf for (pf, lf) in zip(is_pfield, is_local, strict=True)
]
if not all(is_local_or_pfield):
raise YTIllDefinedProfile(
bin_fields, data_source._determine_fields(fields), wf, is_pfield
@@ -1370,7 +1379,7 @@ def create_profile(
if extrema is None or not any(collapse(extrema.values())):
ex = [
data_source.quantities["Extrema"](f, non_zero=l)
for f, l in zip(bin_fields, logs)
for f, l in zip(bin_fields, logs, strict=True)
]
# pad extrema by epsilon so cells at bin edges are not excluded
for i, (mi, ma) in enumerate(ex):
@@ -1456,15 +1465,15 @@ def create_profile(
o_bins.append(field_obin)

args = [data_source]
for f, n, (mi, ma), l in zip(bin_fields, n_bins, ex, logs):
for f, n, (mi, ma), l in zip(bin_fields, n_bins, ex, logs, strict=True):
if mi <= 0 and l:
raise YTIllDefinedBounds(mi, ma)
args += [f, n, mi, ma, l]
kwargs = {"weight_field": weight_field}
if cls is ParticleProfile:
kwargs["deposition"] = deposition
if override_bins is not None:
for o_bin, ax in zip(o_bins, ["x", "y", "z"]):
for o_bin, ax in zip(o_bins, ["x", "y", "z"], strict=False):
kwargs[f"override_bins_{ax}"] = o_bin
obj = cls(*args, **kwargs)
obj.accumulation = accumulation
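
Note the one deliberate strict=False in this file: o_bins holds one entry per binning field, so a 1D or 2D profile legitimately supplies fewer entries than the fixed ["x", "y", "z"] list, and the zip is meant to stop at the shorter input. A reduced illustration (array values are invented):

import numpy as np

o_bins = [np.linspace(0.0, 1.0, 17)]  # e.g. a 1D profile: one override only
kwargs = {}
for o_bin, ax in zip(o_bins, ["x", "y", "z"], strict=False):
    kwargs[f"override_bins_{ax}"] = o_bin
print(sorted(kwargs))  # ['override_bins_x']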
2 changes: 1 addition & 1 deletion yt/data_objects/region_expression.py
@@ -166,7 +166,7 @@ def _create_region(self, bounds_tuple):
if d is not None:
d = int(d)
dims[ax] = d
center = [(cl + cr) / 2.0 for cl, cr in zip(left_edge, right_edge)]
center = (left_edge + right_edge) / 2.0
if None not in dims:
return self.ds.arbitrary_grid(left_edge, right_edge, dims)
return self.ds.region(center, left_edge, right_edge)
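
This is the one spot where the zip disappears outright: left_edge and right_edge are array-like (unyt arrays in yt), so the midpoint can be computed element-wise, which is shorter and keeps units attached. A sketch with plain NumPy standing in for unyt:

import numpy as np

left_edge = np.array([0.0, 0.0, 0.0])
right_edge = np.array([1.0, 2.0, 4.0])

center = (left_edge + right_edge) / 2.0  # element-wise midpoint, no zip
print(center)  # [0.5 1.  2. ]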
8 changes: 7 additions & 1 deletion yt/data_objects/tests/test_covering_grid.py
@@ -1,3 +1,5 @@
import sys

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal, assert_equal

@@ -11,6 +13,10 @@
)
from yt.units import kpc

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


# cylindrical data for covering_grid test
cyl_2d = "WDMerger_hdf5_chk_1000/WDMerger_hdf5_chk_1000.hdf5"
cyl_3d = "MHD_Cyl3d_hdf5_plt_cnt_0100/MHD_Cyl3d_hdf5_plt_cnt_0100.hdf5"
@@ -354,7 +360,7 @@ def test_arbitrary_grid_edge():
[1.0, 1.0, 1.0] * kpc,
]

for le, re, le_ans, re_ans in zip(ledge, redge, ledge_ans, redge_ans):
for le, re, le_ans, re_ans in zip(ledge, redge, ledge_ans, redge_ans, strict=True):
ag = ds.arbitrary_grid(left_edge=le, right_edge=re, dims=dims)
assert np.array_equal(ag.left_edge, le_ans)
assert np.array_equal(ag.right_edge, re_ans)
13 changes: 11 additions & 2 deletions yt/data_objects/tests/test_derived_quantities.py
@@ -1,3 +1,5 @@
import sys

import numpy as np
from numpy.testing import assert_almost_equal, assert_equal

@@ -11,6 +13,9 @@
requires_file,
)

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


def setup_module():
from yt.config import ytcfg
@@ -183,10 +188,14 @@ def test_in_memory_sph_derived_quantities():
assert_equal(com, [1 / 7, (1 + 2) / 7, (1 + 2 + 3) / 7])

ex = ad.quantities.extrema([("io", "x"), ("io", "y"), ("io", "z")])
for fex, ans in zip(ex, [[0, 1], [0, 2], [0, 3]]):
for fex, ans in zip(ex, [[0, 1], [0, 2], [0, 3]], strict=True):
assert_equal(fex, ans)

for d, v, l in zip("xyz", [1, 2, 3], [[1, 0, 0], [0, 2, 0], [0, 0, 3]]):
for d, v, l in [
("x", 1, [1, 0, 0]),
("y", 2, [0, 2, 0]),
("z", 3, [0, 0, 3]),
]:
max_d, x, y, z = ad.quantities.max_location(("io", d))
assert_equal(max_d, v)
assert_equal([x, y, z], l)
2 changes: 1 addition & 1 deletion yt/data_objects/tests/test_sph_data_objects.py
@@ -92,7 +92,7 @@ def test_periodic_region():
for y in coords:
for z in coords:
center = np.array([x, y, z])
for n, w in zip((8, 27), (1.0, 2.0)):
for n, w in [(8, 1.0), (27, 2.0)]:
le = center - 0.5 * w
re = center + 0.5 * w
box = ds.box(le, re)
6 changes: 5 additions & 1 deletion yt/fields/derived_field.py
@@ -1,6 +1,7 @@
import contextlib
import inspect
import re
import sys
from collections.abc import Iterable
from typing import Optional, Union

@@ -25,6 +26,9 @@
NeedsProperty,
)

if sys.version_info < (3, 10):
from yt._maintenance.backports import zip


def TranslationFunc(field_name):
def _TranslationFunc(field, data):
@@ -255,7 +259,7 @@ def _get_needed_parameters(self, fd):
else:
params.extend(val.parameters)
values.extend([fd.get_field_parameter(fp) for fp in val.parameters])
return dict(zip(params, values)), permute_params
return dict(zip(params, values, strict=True)), permute_params

_unit_registry = None
