Skip to content

Commit

Permalink
avoid particle_index type cast
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Sep 19, 2024
1 parent 43af362 commit 59b5d85
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
3 changes: 2 additions & 1 deletion yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def get_data(self, fields=None):

for f, v in read_particles.items():
self.field_data[f] = self.ds.arr(v, units=finfos[f].units)
self.field_data[f].convert_to_units(finfos[f].output_units)
if finfos[f].units != finfos[f].output_units:
self.field_data[f].convert_to_units(finfos[f].output_units)

fields_to_generate += gen_fluids + gen_particles
self._generate_fields(fields_to_generate)
Expand Down
29 changes: 29 additions & 0 deletions yt/frontends/stream/tests/test_stream_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,32 @@ def test_stream_non_cartesian_particles_amr():
assert_equal(dd["all", "particle_position_r"].v, particle_position_r)
assert_equal(dd["all", "particle_position_phi"].v, particle_position_phi)
assert_equal(dd["all", "particle_position_theta"].v, particle_position_theta)


def test_particle_dtypes():
num_particles = 100

data = {
("gas", "particle_position_x"): np.linspace(0.0, 1.0, num_particles),
("gas", "particle_position_y"): np.linspace(0.0, 1.0, num_particles),
("gas", "particle_position_z"): np.linspace(0.0, 1.0, num_particles),
("gas", "particle_mass"): np.ones(num_particles),
("gas", "density"): np.ones(num_particles),
("gas", "smoothing_length"): np.ones(num_particles) * 0.1,
("gas", "particle_index"): np.arange(0, num_particles),
}

ds = load_particles(data)

assert ds.all_data()["gas", "particle_index"].dtype == np.int64

le = ds.domain_center - ds.domain_width / 10.0
re = ds.domain_center + ds.domain_width / 10.0
reg = ds.region(ds.domain_center, le, re)
assert reg["gas", "particle_index"].dtype == np.int64

# make sure we can do operations that require cython
proj = ds.proj(("gas", "particle_index"), 2)
frb = proj.to_frb(ds.domain_width[0], (256, 256))
image = frb["gas", "particle_index"]
assert image.max() > 0
15 changes: 13 additions & 2 deletions yt/utilities/io_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ class BaseIOHandler:
_misses = 0
_hits = 0

@property
def _particle_dtypes(self) -> defaultdict[FieldKey, type]:
dtypes: defaultdict[FieldKey, type] = defaultdict(lambda: np.float64)
for ptype in self.ds.particle_types:
p_index = (ptype, "particle_index")
if p_index in self.ds.field_info.field_aliases:
p_index = self.ds.field_info.field_aliases[p_index]
dtypes[p_index] = np.int64
return dtypes

def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
if hasattr(cls, "_dataset_type"):
Expand Down Expand Up @@ -173,6 +183,7 @@ def _read_particle_selection(
# field_maps stores fields, accounting for field unions
ptf: defaultdict[str, list[str]] = defaultdict(list)
field_maps: defaultdict[FieldKey, list[FieldKey]] = defaultdict(list)
p_dtypes = self._particle_dtypes

# We first need a set of masks for each particle type
chunks = list(chunks)
Expand Down Expand Up @@ -206,14 +217,14 @@ def _read_particle_selection(
vals = data.pop(field_f)
# note: numpy.concatenate has a dtype argument that would avoid
# a copy using .astype(...), available in numpy>=1.20
rv[field_f] = np.concatenate(vals, axis=0).astype("float64")
rv[field_f] = np.concatenate(vals, axis=0).astype(p_dtypes[field_f])
else:
shape = [0]
if field_f[1] in self._vector_fields:
shape.append(self._vector_fields[field_f[1]])
elif field_f[1] in self._array_fields:
shape.append(self._array_fields[field_f[1]])
rv[field_f] = np.empty(shape, dtype="float64")
rv[field_f] = np.empty(shape, dtype=p_dtypes[field_f])
return rv

def _read_particle_fields(self, chunks, ptf, selector):
Expand Down

0 comments on commit 59b5d85

Please sign in to comment.