diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner.py b/datasets/flwr_datasets/partitioner/iid_partitioner.py index c8dbf8294fe..37b97468cad 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner.py @@ -48,5 +48,5 @@ def load_partition(self, idx: int) -> datasets.Dataset: single dataset partition """ return self.dataset.shard( - num_shards=self._num_partitions, index=idx, contiguous=True + num_shards=self._num_partitions, index=idx, contiguous=False ) diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner_test.py b/datasets/flwr_datasets/partitioner/iid_partitioner_test.py index d89eefeba9f..5f851807f4b 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner_test.py @@ -18,6 +18,7 @@ import unittest from typing import Tuple +import numpy as np from parameterized import parameterized from datasets import Dataset @@ -100,11 +101,16 @@ def test_load_partition_correct_data( self, num_partitions: int, num_rows: int ) -> None: """Test if the data in partition is equal to the expected.""" - _, partitioner = _dummy_setup(num_partitions, num_rows) - partition_size = num_rows // num_partitions + dataset, partitioner = _dummy_setup(num_partitions, num_rows) partition_index = 2 partition = partitioner.load_partition(partition_index) - self.assertEqual(partition["features"][0], partition_index * partition_size) + row_id = 0 + self.assertEqual( + partition["features"][row_id], + dataset[np.arange(partition_index, len(dataset), num_partitions)][ + "features" + ][row_id], + ) @parameterized.expand( # type: ignore [