Skip to content

Commit

Permalink
Merge pull request #1048 from zachpainter77/master
Browse files Browse the repository at this point in the history
Add Support For Generic Handlers With Multiple Generic Type Parameters
  • Loading branch information
jbogard authored Jul 16, 2024
2 parents 3b8bf44 + 811ce54 commit cac76be
Show file tree
Hide file tree
Showing 7 changed files with 1,129 additions and 488 deletions.
935 changes: 480 additions & 455 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}

ServiceRegistrar.AddMediatRClasses(services, configuration);
ServiceRegistrar.SetGenericRequestHandlerRegistrationLimitations(configuration);

ServiceRegistrar.AddMediatRClassesWithTimeout(services, configuration);

ServiceRegistrar.AddRequiredServices(services, configuration);

Expand Down
132 changes: 111 additions & 21 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,50 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using MediatR.Pipeline;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;

namespace MediatR.Registration;

public static class ServiceRegistrar
{
public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration)
{
{
private static int MaxGenericTypeParameters;
private static int MaxTypesClosing;
private static int MaxGenericTypeRegistrations;
private static int RegistrationTimeout;

public static void SetGenericRequestHandlerRegistrationLimitations(MediatRServiceConfiguration configuration)
{
MaxGenericTypeParameters = configuration.MaxGenericTypeParameters;
MaxTypesClosing = configuration.MaxTypesClosing;
MaxGenericTypeRegistrations = configuration.MaxGenericTypeRegistrations;
RegistrationTimeout = configuration.RegistrationTimeout;
}

public static void AddMediatRClassesWithTimeout(IServiceCollection services, MediatRServiceConfiguration configuration)
{
using(var cts = new CancellationTokenSource(RegistrationTimeout))
{
try
{
AddMediatRClasses(services, configuration, cts.Token);
}
catch (OperationCanceledException)
{
throw new TimeoutException("The generic handler registration process timed out.");
}
}
}

public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration, CancellationToken cancellationToken = default)
{

var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();

ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration, cancellationToken);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration, cancellationToken);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration);
Expand Down Expand Up @@ -63,7 +93,8 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
IServiceCollection services,
IEnumerable<Assembly> assembliesToScan,
bool addIfAlreadyExists,
MediatRServiceConfiguration configuration)
MediatRServiceConfiguration configuration,
CancellationToken cancellationToken = default)
{
var concretions = new List<Type>();
var interfaces = new List<Type>();
Expand All @@ -72,9 +103,10 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa

var types = assembliesToScan
.SelectMany(a => a.DefinedTypes)
.Where(t => !t.ContainsGenericParameters || configuration.RegisterGenericHandlers)
.Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any())
.Where(configuration.TypeEvaluator)
.ToList();
.ToList();

foreach (var type in types)
{
Expand Down Expand Up @@ -131,7 +163,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
foreach (var @interface in genericInterfaces)
{
var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList();
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan);
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan, cancellationToken);
}
}

Expand Down Expand Up @@ -174,7 +206,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>

private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation)
{
var closingType = concreteGenericTRequest.GetGenericArguments().First();
var closingTypes = concreteGenericTRequest.GetGenericArguments();

var concreteTResponse = concreteGenericTRequest.GetInterfaces()
.FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>))
Expand All @@ -187,33 +219,90 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(
typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) :
typeDefinition.MakeGenericType(concreteGenericTRequest);

return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingType));
return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingTypes));
}

private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan)
private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan, CancellationToken cancellationToken)
{
var constraints = openRequestHandlerImplementation.GetGenericArguments().First().GetGenericParameterConstraints();

var typesThatCanClose = assembliesToScan
.SelectMany(assembly => assembly.GetTypes())
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type)))
.ToList();
//request generic type constraints
var constraintsForEachParameter = openRequestHandlerImplementation
.GetGenericArguments()
.Select(x => x.GetGenericParameterConstraints())
.ToList();

if (constraintsForEachParameter.Count > 2 && constraintsForEachParameter.Any(constraints => !constraints.Where(x => x.IsInterface || x.IsClass).Any()))
throw new ArgumentException($"Error registering the generic handler type: {openRequestHandlerImplementation.FullName}. When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class.");

var typesThatCanCloseForEachParameter = constraintsForEachParameter
.Select(constraints => assembliesToScan
.SelectMany(assembly => assembly.GetTypes())
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))).ToList()
).ToList();

var requestType = openRequestHandlerInterface.GenericTypeArguments.First();

if (requestType.IsGenericParameter)
return null;

var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition();

var combinations = GenerateCombinations(requestType, typesThatCanCloseForEachParameter, 0, cancellationToken);

return combinations.Select(types => requestGenericTypeDefinition.MakeGenericType(types.ToArray())).ToList();
}

// Method to generate combinations recursively
public static List<List<Type>> GenerateCombinations(Type requestType, List<List<Type>> lists, int depth = 0, CancellationToken cancellationToken = default)
{
if (depth == 0)
{
// Initial checks
if (MaxGenericTypeParameters > 0 && lists.Count > MaxGenericTypeParameters)
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The number of generic type parameters exceeds the maximum allowed ({MaxGenericTypeParameters}).");

foreach (var list in lists)
{
if (MaxTypesClosing > 0 && list.Count > MaxTypesClosing)
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({MaxTypesClosing}).");
}

// Calculate the total number of combinations
long totalCombinations = 1;
foreach (var list in lists)
{
totalCombinations *= list.Count;
if (MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations)
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The total number of generic type registrations exceeds the maximum allowed ({MaxGenericTypeRegistrations}).");
}
}

if (depth >= lists.Count)
return new List<List<Type>> { new List<Type>() };

cancellationToken.ThrowIfCancellationRequested();

return typesThatCanClose.Select(type => requestGenericTypeDefinition.MakeGenericType(type)).ToList();
var currentList = lists[depth];
var childCombinations = GenerateCombinations(requestType, lists, depth + 1, cancellationToken);
var combinations = new List<List<Type>>();

foreach (var item in currentList)
{
foreach (var childCombination in childCombinations)
{
var currentCombination = new List<Type> { item };
currentCombination.AddRange(childCombination);
combinations.Add(currentCombination);
}
}

return combinations;
}

private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan)
private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan, CancellationToken cancellationToken)
{
foreach (var concretion in concretions)
{
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan);
{
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan, cancellationToken);

if (concreteRequests is null)
continue;
Expand All @@ -223,6 +312,7 @@ private static void AddAllConcretionsThatClose(Type openRequestInterface, List<T

foreach (var (Service, Implementation) in registrationTypes)
{
cancellationToken.ThrowIfCancellationRequested();
services.AddTransient(Service, Implementation);
}
}
Expand Down
Loading

0 comments on commit cac76be

Please sign in to comment.