Skip to content

Commit

Permalink
Fix default contiguous value in IidPartitioner (#2406)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Taner Topal <[email protected]>
  • Loading branch information
adam-narozniak and tanertopal authored Sep 22, 2023
1 parent b569d2a commit b63b775
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion datasets/flwr_datasets/partitioner/iid_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def load_partition(self, idx: int) -> datasets.Dataset:
single dataset partition
"""
return self.dataset.shard(
num_shards=self._num_partitions, index=idx, contiguous=True
num_shards=self._num_partitions, index=idx, contiguous=False
)
12 changes: 9 additions & 3 deletions datasets/flwr_datasets/partitioner/iid_partitioner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import unittest
from typing import Tuple

import numpy as np
from parameterized import parameterized

from datasets import Dataset
Expand Down Expand Up @@ -100,11 +101,16 @@ def test_load_partition_correct_data(
self, num_partitions: int, num_rows: int
) -> None:
"""Test if the data in partition is equal to the expected."""
_, partitioner = _dummy_setup(num_partitions, num_rows)
partition_size = num_rows // num_partitions
dataset, partitioner = _dummy_setup(num_partitions, num_rows)
partition_index = 2
partition = partitioner.load_partition(partition_index)
self.assertEqual(partition["features"][0], partition_index * partition_size)
row_id = 0
self.assertEqual(
partition["features"][row_id],
dataset[np.arange(partition_index, len(dataset), num_partitions)][
"features"
][row_id],
)

@parameterized.expand( # type: ignore
[
Expand Down

0 comments on commit b63b775

Please sign in to comment.