diff --git a/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java b/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java index 655154bb880eb..72d134602e88f 100644 --- a/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java +++ b/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java @@ -33,8 +33,10 @@ import org.opensearch.common.util.set.Sets; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.store.Store; import org.opensearch.indices.store.ShardAttributes; import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch; +import org.opensearch.indices.store.TransportNodesListShardStoreMetadataHelper; import java.util.Collections; import java.util.HashMap; @@ -46,6 +48,10 @@ import java.util.Set; import java.util.Spliterators; import java.util.concurrent.ConcurrentMap; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -57,7 +63,6 @@ public class ShardsBatchGatewayAllocator implements ExistingShardsAllocator { public static final String ALLOCATOR_NAME = "shards_batch_gateway_allocator"; - private static final Logger logger = LogManager.getLogger(ShardsBatchGatewayAllocator.class); private final long maxBatchSize; private static final short DEFAULT_SHARD_BATCH_SIZE = 2000; @@ -74,25 +79,22 @@ public class ShardsBatchGatewayAllocator implements ExistingShardsAllocator { ); private final RerouteService rerouteService; - - private PrimaryShardBatchAllocator primaryShardBatchAllocator; - private ReplicaShardBatchAllocator replicaShardBatchAllocator; - + private final PrimaryShardBatchAllocator primaryShardBatchAllocator; + private final ReplicaShardBatchAllocator replicaShardBatchAllocator; private Set lastSeenEphemeralIds = Collections.emptySet(); - // visble for testing + // visible for testing protected final ConcurrentMap batchIdToStartedShardBatch = ConcurrentCollections.newConcurrentMap(); // visible for testing protected final ConcurrentMap batchIdToStoreShardBatch = ConcurrentCollections.newConcurrentMap(); - - private final TransportNodesListGatewayStartedBatchShards batchStartedAction; + private final TransportNodesListGatewayStartedShardsBatch batchStartedAction; private final TransportNodesListShardStoreMetadataBatch batchStoreAction; @Inject public ShardsBatchGatewayAllocator( RerouteService rerouteService, - TransportNodesListGatewayStartedBatchShards batchStartedAction, + TransportNodesListGatewayStartedShardsBatch batchStartedAction, TransportNodesListShardStoreMetadataBatch batchStoreAction, Settings settings ) { @@ -165,7 +167,7 @@ public void beforeAllocation(final RoutingAllocation allocation) { @Override public void afterPrimariesBeforeReplicas(RoutingAllocation allocation) { assert replicaShardBatchAllocator != null; - List> storedShardBatches = batchIdToStoreShardBatch.values() + List> storedShardBatches = batchIdToStoreShardBatch.values() .stream() .map(ShardsBatch::getBatchedShardRoutings) .collect(Collectors.toList()); @@ -254,7 +256,6 @@ else if (shardRouting.primary() == primary) { if (batchSize > 0) { ShardEntry shardEntry = new ShardEntry( new ShardAttributes( - currentShard.shardId(), IndexMetadata.INDEX_DATA_PATH_SETTING.get(allocation.metadata().index(currentShard.index()).getSettings()) ), currentShard @@ -424,15 +425,31 @@ private boolean hasNewNodes(DiscoveryNodes nodes) { return false; } - class InternalBatchAsyncFetch extends AsyncShardFetch { + class InternalBatchAsyncFetch extends AsyncShardBatchFetch { InternalBatchAsyncFetch( Logger logger, String type, Map map, AsyncShardFetch.Lister, T> action, - String batchUUId + String batchUUId, + Class clazz, + BiFunction, T> responseBuilder, + Function> shardsBatchDataGetter, + Supplier emptyResponseBuilder, + Consumer handleFailedShard ) { - super(logger, type, map, action, batchUUId); + super( + logger, + type, + map, + action, + batchUUId, + clazz, + responseBuilder, + shardsBatchDataGetter, + emptyResponseBuilder, + handleFailedShard + ); } @Override @@ -454,12 +471,12 @@ class InternalPrimaryBatchShardAllocator extends PrimaryShardBatchAllocator { @Override @SuppressWarnings("unchecked") - protected AsyncShardFetch.FetchResult fetchData( - Set shardsEligibleForFetch, - Set inEligibleShards, + protected AsyncShardFetch.FetchResult fetchData( + List eligibleShards, + List inEligibleShards, RoutingAllocation allocation ) { - ShardRouting shardRouting = shardsEligibleForFetch.iterator().hasNext() ? shardsEligibleForFetch.iterator().next() : null; + ShardRouting shardRouting = eligibleShards.iterator().hasNext() ? eligibleShards.iterator().next() : null; shardRouting = shardRouting == null && inEligibleShards.iterator().hasNext() ? inEligibleShards.iterator().next() : shardRouting; @@ -481,7 +498,7 @@ protected AsyncShardFetch.FetchResult(null, Collections.emptyMap()); } @@ -491,7 +508,7 @@ protected AsyncShardFetch.FetchResult asyncFetcher = shardsBatch.getAsyncFetcher(); + AsyncShardBatchFetch asyncFetcher = shardsBatch.getAsyncFetcher(); AsyncShardFetch.FetchResult shardBatchState = asyncFetcher.fetchData( allocation.nodes(), shardToIgnoreNodes @@ -500,7 +517,7 @@ protected AsyncShardFetch.FetchResult) shardBatchState; + return (AsyncShardFetch.FetchResult) shardBatchState; } } @@ -509,12 +526,12 @@ class InternalReplicaBatchShardAllocator extends ReplicaShardBatchAllocator { @Override @SuppressWarnings("unchecked") protected AsyncShardFetch.FetchResult fetchData( - Set shardsEligibleForFetch, - Set inEligibleShards, + List eligibleShards, + List inEligibleShards, RoutingAllocation allocation ) { - // get batch id for anyone given shard. We are assuming all shards will have same batch Id - ShardRouting shardRouting = shardsEligibleForFetch.iterator().hasNext() ? shardsEligibleForFetch.iterator().next() : null; + // get batch id for anyone given shard. We are assuming all shards will have same batchId + ShardRouting shardRouting = eligibleShards.iterator().hasNext() ? eligibleShards.iterator().next() : null; shardRouting = shardRouting == null && inEligibleShards.iterator().hasNext() ? inEligibleShards.iterator().next() : shardRouting; @@ -536,7 +553,7 @@ protected AsyncShardFetch.FetchResult(null, Collections.emptyMap()); } @@ -544,7 +561,7 @@ protected AsyncShardFetch.FetchResult asyncFetcher = shardsBatch.getAsyncFetcher(); + AsyncShardBatchFetch asyncFetcher = shardsBatch.getAsyncFetcher(); AsyncShardFetch.FetchResult shardBatchStores = asyncFetcher.fetchData( allocation.nodes(), shardToIgnoreNodes @@ -565,14 +582,14 @@ protected boolean hasInitiatedFetching(ShardRouting shard) { /** * Holds information about a batch of shards to be allocated. * Async fetcher is used to fetch the data for the batch. - * + *

* Visible for testing */ public class ShardsBatch { private final String batchId; private final boolean primary; - private final AsyncShardFetch asyncBatch; + private final InternalBatchAsyncFetch asyncBatch; private final Map batchInfo; @@ -584,23 +601,59 @@ public ShardsBatch(String batchId, Map shardsWithInfo, bool .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getShardAttributes())); this.primary = primary; - if (primary) { - asyncBatch = new InternalBatchAsyncFetch<>(logger, "batch_shards_started", shardIdsMap, batchStartedAction, batchId); + if (this.primary) { + asyncBatch = new InternalBatchAsyncFetch<>( + logger, + "batch_shards_started", + shardIdsMap, + batchStartedAction, + batchId, + TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard.class, + TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::new, + TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::getNodeGatewayStartedShardsBatch, + () -> new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard(null, false, null, null), + this::removeShard + ); } else { - asyncBatch = new InternalBatchAsyncFetch<>(logger, "batch_shards_started", shardIdsMap, batchStoreAction, batchId); - + asyncBatch = new InternalBatchAsyncFetch<>( + logger, + "batch_shards_store", + shardIdsMap, + batchStoreAction, + batchId, + TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata.class, + TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch::new, + TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch::getNodeStoreFilesMetadataBatch, + this::buildEmptyReplicaShardResponse, + this::removeShard + ); } } + protected void removeShard(ShardId shardId) { + this.batchInfo.remove(shardId); + } + + private TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata buildEmptyReplicaShardResponse() { + return new TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata( + new TransportNodesListShardStoreMetadataHelper.StoreFilesMetadata( + null, + Store.MetadataSnapshot.EMPTY, + Collections.emptyList() + ), + null + ); + } + private void removeFromBatch(ShardRouting shard) { - batchInfo.remove(shard.shardId()); - asyncBatch.shardAttributesMap.remove(shard.shardId()); + removeShard(shard.shardId()); + asyncBatch.clearShard(shard.shardId()); // assert that fetcher and shards are the same as batched shards assert batchInfo.size() == asyncBatch.shardAttributesMap.size() : "Shards size is not equal to fetcher size"; } - public Set getBatchedShardRoutings() { - return batchInfo.values().stream().map(ShardEntry::getShardRouting).collect(Collectors.toSet()); + public List getBatchedShardRoutings() { + return batchInfo.values().stream().map(ShardEntry::getShardRouting).collect(Collectors.toList()); } public Set getBatchedShards() { @@ -611,7 +664,7 @@ public String getBatchId() { return batchId; } - public AsyncShardFetch getAsyncFetcher() { + public AsyncShardBatchFetch getAsyncFetcher() { return asyncBatch; } @@ -624,7 +677,7 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || o instanceof ShardsBatch == false) { + if (o instanceof ShardsBatch == false) { return false; } ShardsBatch shardsBatch = (ShardsBatch) o; @@ -646,15 +699,10 @@ public String toString() { /** * Holds information about a shard to be allocated in a batch. */ - private class ShardEntry { + static class ShardEntry { private final ShardAttributes shardAttributes; - public ShardEntry setShardRouting(ShardRouting shardRouting) { - this.shardRouting = shardRouting; - return this; - } - private ShardRouting shardRouting; public ShardEntry(ShardAttributes shardAttributes, ShardRouting shardRouting) { @@ -669,6 +717,11 @@ public ShardRouting getShardRouting() { public ShardAttributes getShardAttributes() { return shardAttributes; } + + public ShardEntry setShardRouting(ShardRouting shardRouting) { + this.shardRouting = shardRouting; + return this; + } } public int getNumberOfStartedShardBatches() { diff --git a/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java b/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java index e34c222c94205..c7fbfe40d82a7 100644 --- a/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java +++ b/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -36,8 +37,8 @@ public class TestShardBatchGatewayAllocator extends ShardsBatchGatewayAllocator PrimaryShardBatchAllocator primaryBatchShardAllocator = new PrimaryShardBatchAllocator() { @Override protected AsyncShardFetch.FetchResult fetchData( - Set shardsEligibleForFetch, - Set inEligibleShards, + List eligibleShards, + List inEligibleShards, RoutingAllocation allocation ) { Map foundShards = new HashMap<>(); @@ -47,7 +48,7 @@ protected AsyncShardFetch.FetchResult shardsOnNode = entry.getValue(); HashMap adaptedResponse = new HashMap<>(); - for (ShardRouting shardRouting : shardsEligibleForFetch) { + for (ShardRouting shardRouting : eligibleShards) { ShardId shardId = shardRouting.shardId(); Set ignoreNodes = allocation.getIgnoreNodes(shardId); @@ -78,8 +79,8 @@ protected AsyncShardFetch.FetchResult fetchData( - Set shardsEligibleForFetch, - Set inEligibleShards, + List eligibleShards, + List inEligibleShards, RoutingAllocation allocation ) { return new AsyncShardFetch.FetchResult<>(Collections.emptyMap(), Collections.emptyMap());