Skip to content

Commit

Permalink
Use new AsyncShardFetchBatch class for creating cache for batch trans…
Browse files Browse the repository at this point in the history
…port actions

Signed-off-by: Aman Khare <[email protected]>
Signed-off-by: Shivansh Arora <[email protected]>
  • Loading branch information
Aman Khare authored and shiv0408 committed Apr 25, 2024
1 parent 5dedfc9 commit 8a50ff2
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import org.opensearch.common.util.set.Sets;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.store.Store;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataHelper;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -46,6 +48,10 @@
import java.util.Set;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

Expand All @@ -57,7 +63,6 @@
public class ShardsBatchGatewayAllocator implements ExistingShardsAllocator {

public static final String ALLOCATOR_NAME = "shards_batch_gateway_allocator";

private static final Logger logger = LogManager.getLogger(ShardsBatchGatewayAllocator.class);
private final long maxBatchSize;
private static final short DEFAULT_SHARD_BATCH_SIZE = 2000;
Expand All @@ -74,25 +79,22 @@ public class ShardsBatchGatewayAllocator implements ExistingShardsAllocator {
);

private final RerouteService rerouteService;

private PrimaryShardBatchAllocator primaryShardBatchAllocator;
private ReplicaShardBatchAllocator replicaShardBatchAllocator;

private final PrimaryShardBatchAllocator primaryShardBatchAllocator;
private final ReplicaShardBatchAllocator replicaShardBatchAllocator;
private Set<String> lastSeenEphemeralIds = Collections.emptySet();

// visble for testing
// visible for testing
protected final ConcurrentMap<String, ShardsBatch> batchIdToStartedShardBatch = ConcurrentCollections.newConcurrentMap();

// visible for testing
protected final ConcurrentMap<String, ShardsBatch> batchIdToStoreShardBatch = ConcurrentCollections.newConcurrentMap();

private final TransportNodesListGatewayStartedBatchShards batchStartedAction;
private final TransportNodesListGatewayStartedShardsBatch batchStartedAction;
private final TransportNodesListShardStoreMetadataBatch batchStoreAction;

@Inject
public ShardsBatchGatewayAllocator(
RerouteService rerouteService,
TransportNodesListGatewayStartedBatchShards batchStartedAction,
TransportNodesListGatewayStartedShardsBatch batchStartedAction,
TransportNodesListShardStoreMetadataBatch batchStoreAction,
Settings settings
) {
Expand Down Expand Up @@ -165,7 +167,7 @@ public void beforeAllocation(final RoutingAllocation allocation) {
@Override
public void afterPrimariesBeforeReplicas(RoutingAllocation allocation) {
assert replicaShardBatchAllocator != null;
List<Set<ShardRouting>> storedShardBatches = batchIdToStoreShardBatch.values()
List<List<ShardRouting>> storedShardBatches = batchIdToStoreShardBatch.values()
.stream()
.map(ShardsBatch::getBatchedShardRoutings)
.collect(Collectors.toList());
Expand Down Expand Up @@ -254,7 +256,6 @@ else if (shardRouting.primary() == primary) {
if (batchSize > 0) {
ShardEntry shardEntry = new ShardEntry(
new ShardAttributes(
currentShard.shardId(),
IndexMetadata.INDEX_DATA_PATH_SETTING.get(allocation.metadata().index(currentShard.index()).getSettings())
),
currentShard
Expand Down Expand Up @@ -424,15 +425,31 @@ private boolean hasNewNodes(DiscoveryNodes nodes) {
return false;
}

class InternalBatchAsyncFetch<T extends BaseNodeResponse> extends AsyncShardFetch<T> {
class InternalBatchAsyncFetch<T extends BaseNodeResponse, V extends BaseShardResponse> extends AsyncShardBatchFetch<T, V> {
InternalBatchAsyncFetch(
Logger logger,
String type,
Map<ShardId, ShardAttributes> map,
AsyncShardFetch.Lister<? extends BaseNodesResponse<T>, T> action,
String batchUUId
String batchUUId,
Class<V> clazz,
BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseBuilder,
Function<T, Map<ShardId, V>> shardsBatchDataGetter,
Supplier<V> emptyResponseBuilder,
Consumer<ShardId> handleFailedShard
) {
super(logger, type, map, action, batchUUId);
super(
logger,
type,
map,
action,
batchUUId,
clazz,
responseBuilder,
shardsBatchDataGetter,
emptyResponseBuilder,
handleFailedShard
);
}

@Override
Expand All @@ -454,12 +471,12 @@ class InternalPrimaryBatchShardAllocator extends PrimaryShardBatchAllocator {

@Override
@SuppressWarnings("unchecked")
protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedBatchShards.NodeGatewayStartedShardsBatch> fetchData(
Set<ShardRouting> shardsEligibleForFetch,
Set<ShardRouting> inEligibleShards,
protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch> fetchData(
List<ShardRouting> eligibleShards,
List<ShardRouting> inEligibleShards,
RoutingAllocation allocation
) {
ShardRouting shardRouting = shardsEligibleForFetch.iterator().hasNext() ? shardsEligibleForFetch.iterator().next() : null;
ShardRouting shardRouting = eligibleShards.iterator().hasNext() ? eligibleShards.iterator().next() : null;
shardRouting = shardRouting == null && inEligibleShards.iterator().hasNext()
? inEligibleShards.iterator().next()
: shardRouting;
Expand All @@ -481,7 +498,7 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedBatchShard
// remove in eligible shards which allocator is not responsible for
inEligibleShards.forEach(ShardsBatchGatewayAllocator.this::safelyRemoveShardFromBatch);

if (shardsBatch.getBatchedShards().isEmpty() && shardsEligibleForFetch.isEmpty()) {
if (shardsBatch.getBatchedShards().isEmpty() && eligibleShards.isEmpty()) {
logger.debug("Batch {} is empty", batchId);
return new AsyncShardFetch.FetchResult<>(null, Collections.emptyMap());
}
Expand All @@ -491,7 +508,7 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedBatchShard
for (ShardId shardId : shardsBatch.asyncBatch.shardAttributesMap.keySet()) {
shardToIgnoreNodes.put(shardId, allocation.getIgnoreNodes(shardId));
}
AsyncShardFetch<? extends BaseNodeResponse> asyncFetcher = shardsBatch.getAsyncFetcher();
AsyncShardBatchFetch<? extends BaseNodeResponse, ? extends BaseShardResponse> asyncFetcher = shardsBatch.getAsyncFetcher();
AsyncShardFetch.FetchResult<? extends BaseNodeResponse> shardBatchState = asyncFetcher.fetchData(
allocation.nodes(),
shardToIgnoreNodes
Expand All @@ -500,7 +517,7 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedBatchShard
if (shardBatchState.hasData()) {
shardBatchState.processAllocation(allocation);
}
return (AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedBatchShards.NodeGatewayStartedShardsBatch>) shardBatchState;
return (AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch>) shardBatchState;
}

}
Expand All @@ -509,12 +526,12 @@ class InternalReplicaBatchShardAllocator extends ReplicaShardBatchAllocator {
@Override
@SuppressWarnings("unchecked")
protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch> fetchData(
Set<ShardRouting> shardsEligibleForFetch,
Set<ShardRouting> inEligibleShards,
List<ShardRouting> eligibleShards,
List<ShardRouting> inEligibleShards,
RoutingAllocation allocation
) {
// get batch id for anyone given shard. We are assuming all shards will have same batch Id
ShardRouting shardRouting = shardsEligibleForFetch.iterator().hasNext() ? shardsEligibleForFetch.iterator().next() : null;
// get batch id for anyone given shard. We are assuming all shards will have same batchId
ShardRouting shardRouting = eligibleShards.iterator().hasNext() ? eligibleShards.iterator().next() : null;
shardRouting = shardRouting == null && inEligibleShards.iterator().hasNext()
? inEligibleShards.iterator().next()
: shardRouting;
Expand All @@ -536,15 +553,15 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadataBatch.
// remove in eligible shards which allocator is not responsible for
inEligibleShards.forEach(ShardsBatchGatewayAllocator.this::safelyRemoveShardFromBatch);

if (shardsBatch.getBatchedShards().isEmpty() && shardsEligibleForFetch.isEmpty()) {
if (shardsBatch.getBatchedShards().isEmpty() && eligibleShards.isEmpty()) {
logger.debug("Batch {} is empty", batchId);
return new AsyncShardFetch.FetchResult<>(null, Collections.emptyMap());
}
Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();
for (ShardId shardId : shardsBatch.asyncBatch.shardAttributesMap.keySet()) {
shardToIgnoreNodes.put(shardId, allocation.getIgnoreNodes(shardId));
}
AsyncShardFetch<? extends BaseNodeResponse> asyncFetcher = shardsBatch.getAsyncFetcher();
AsyncShardBatchFetch<? extends BaseNodeResponse, ? extends BaseShardResponse> asyncFetcher = shardsBatch.getAsyncFetcher();
AsyncShardFetch.FetchResult<? extends BaseNodeResponse> shardBatchStores = asyncFetcher.fetchData(
allocation.nodes(),
shardToIgnoreNodes
Expand All @@ -565,14 +582,14 @@ protected boolean hasInitiatedFetching(ShardRouting shard) {
/**
* Holds information about a batch of shards to be allocated.
* Async fetcher is used to fetch the data for the batch.
*
* <p>
* Visible for testing
*/
public class ShardsBatch {
private final String batchId;
private final boolean primary;

private final AsyncShardFetch<? extends BaseNodeResponse> asyncBatch;
private final InternalBatchAsyncFetch<? extends BaseNodeResponse, ? extends BaseShardResponse> asyncBatch;

private final Map<ShardId, ShardEntry> batchInfo;

Expand All @@ -584,23 +601,59 @@ public ShardsBatch(String batchId, Map<ShardId, ShardEntry> shardsWithInfo, bool
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getShardAttributes()));
this.primary = primary;
if (primary) {
asyncBatch = new InternalBatchAsyncFetch<>(logger, "batch_shards_started", shardIdsMap, batchStartedAction, batchId);
if (this.primary) {
asyncBatch = new InternalBatchAsyncFetch<>(
logger,
"batch_shards_started",
shardIdsMap,
batchStartedAction,
batchId,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard.class,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::new,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::getNodeGatewayStartedShardsBatch,
() -> new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard(null, false, null, null),
this::removeShard
);
} else {
asyncBatch = new InternalBatchAsyncFetch<>(logger, "batch_shards_started", shardIdsMap, batchStoreAction, batchId);

asyncBatch = new InternalBatchAsyncFetch<>(
logger,
"batch_shards_store",
shardIdsMap,
batchStoreAction,
batchId,
TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata.class,
TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch::new,
TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch::getNodeStoreFilesMetadataBatch,
this::buildEmptyReplicaShardResponse,
this::removeShard
);
}
}

protected void removeShard(ShardId shardId) {
this.batchInfo.remove(shardId);
}

private TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata buildEmptyReplicaShardResponse() {
return new TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata(
new TransportNodesListShardStoreMetadataHelper.StoreFilesMetadata(
null,
Store.MetadataSnapshot.EMPTY,
Collections.emptyList()
),
null
);
}

private void removeFromBatch(ShardRouting shard) {
batchInfo.remove(shard.shardId());
asyncBatch.shardAttributesMap.remove(shard.shardId());
removeShard(shard.shardId());
asyncBatch.clearShard(shard.shardId());
// assert that fetcher and shards are the same as batched shards
assert batchInfo.size() == asyncBatch.shardAttributesMap.size() : "Shards size is not equal to fetcher size";
}

public Set<ShardRouting> getBatchedShardRoutings() {
return batchInfo.values().stream().map(ShardEntry::getShardRouting).collect(Collectors.toSet());
public List<ShardRouting> getBatchedShardRoutings() {
return batchInfo.values().stream().map(ShardEntry::getShardRouting).collect(Collectors.toList());
}

public Set<ShardId> getBatchedShards() {
Expand All @@ -611,7 +664,7 @@ public String getBatchId() {
return batchId;
}

public AsyncShardFetch<? extends BaseNodeResponse> getAsyncFetcher() {
public AsyncShardBatchFetch<? extends BaseNodeResponse, ? extends BaseShardResponse> getAsyncFetcher() {
return asyncBatch;
}

Expand All @@ -624,7 +677,7 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || o instanceof ShardsBatch == false) {
if (o instanceof ShardsBatch == false) {
return false;
}
ShardsBatch shardsBatch = (ShardsBatch) o;
Expand All @@ -646,15 +699,10 @@ public String toString() {
/**
* Holds information about a shard to be allocated in a batch.
*/
private class ShardEntry {
static class ShardEntry {

private final ShardAttributes shardAttributes;

public ShardEntry setShardRouting(ShardRouting shardRouting) {
this.shardRouting = shardRouting;
return this;
}

private ShardRouting shardRouting;

public ShardEntry(ShardAttributes shardAttributes, ShardRouting shardRouting) {
Expand All @@ -669,6 +717,11 @@ public ShardRouting getShardRouting() {
public ShardAttributes getShardAttributes() {
return shardAttributes;
}

public ShardEntry setShardRouting(ShardRouting shardRouting) {
this.shardRouting = shardRouting;
return this;
}
}

public int getNumberOfStartedShardBatches() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand All @@ -36,8 +37,8 @@ public class TestShardBatchGatewayAllocator extends ShardsBatchGatewayAllocator
PrimaryShardBatchAllocator primaryBatchShardAllocator = new PrimaryShardBatchAllocator() {
@Override
protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch> fetchData(
Set<ShardRouting> shardsEligibleForFetch,
Set<ShardRouting> inEligibleShards,
List<ShardRouting> eligibleShards,
List<ShardRouting> inEligibleShards,
RoutingAllocation allocation
) {
Map<DiscoveryNode, TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch> foundShards = new HashMap<>();
Expand All @@ -47,7 +48,7 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatc
Map<ShardId, ShardRouting> shardsOnNode = entry.getValue();
HashMap<ShardId, TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard> adaptedResponse = new HashMap<>();

for (ShardRouting shardRouting : shardsEligibleForFetch) {
for (ShardRouting shardRouting : eligibleShards) {
ShardId shardId = shardRouting.shardId();
Set<String> ignoreNodes = allocation.getIgnoreNodes(shardId);

Expand Down Expand Up @@ -78,8 +79,8 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatc

@Override
protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch> fetchData(
Set<ShardRouting> shardsEligibleForFetch,
Set<ShardRouting> inEligibleShards,
List<ShardRouting> eligibleShards,
List<ShardRouting> inEligibleShards,
RoutingAllocation allocation
) {
return new AsyncShardFetch.FetchResult<>(Collections.emptyMap(), Collections.emptyMap());
Expand Down

0 comments on commit 8a50ff2

Please sign in to comment.