From dd7aa9912a70fe5c445b65e8627bd421b73f8e5d Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 21 Sep 2023 15:35:49 +0200 Subject: [PATCH 1/2] Fix default contiguous value in IidPartitioner --- datasets/flwr_datasets/partitioner/iid_partitioner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner.py b/datasets/flwr_datasets/partitioner/iid_partitioner.py index c8dbf8294fe..37b97468cad 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner.py @@ -48,5 +48,5 @@ def load_partition(self, idx: int) -> datasets.Dataset: single dataset partition """ return self.dataset.shard( - num_shards=self._num_partitions, index=idx, contiguous=True + num_shards=self._num_partitions, index=idx, contiguous=False ) From 5398784020c6081ec1e33f40366f5caad720e540 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Fri, 22 Sep 2023 10:02:52 +0200 Subject: [PATCH 2/2] Fix tests --- .../partitioner/iid_partitioner_test.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner_test.py b/datasets/flwr_datasets/partitioner/iid_partitioner_test.py index d89eefeba9f..5f851807f4b 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner_test.py @@ -18,6 +18,7 @@ import unittest from typing import Tuple +import numpy as np from parameterized import parameterized from datasets import Dataset @@ -100,11 +101,16 @@ def test_load_partition_correct_data( self, num_partitions: int, num_rows: int ) -> None: """Test if the data in partition is equal to the expected.""" - _, partitioner = _dummy_setup(num_partitions, num_rows) - partition_size = num_rows // num_partitions + dataset, partitioner = _dummy_setup(num_partitions, num_rows) partition_index = 2 partition = partitioner.load_partition(partition_index) - self.assertEqual(partition["features"][0], partition_index * partition_size) + row_id = 0 + self.assertEqual( + partition["features"][row_id], + dataset[np.arange(partition_index, len(dataset), num_partitions)][ + "features" + ][row_id], + ) @parameterized.expand( # type: ignore [