Skip to content

Commit

Permalink
Support arbitrary enumerables in NpgsqlArrayConverter (#3290)
Browse files Browse the repository at this point in the history
Closes #3286
  • Loading branch information
roji authored Sep 21, 2024
1 parent 30cebf0 commit 8a0fca5
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,10 @@ public override DbParameter CreateParameter(
// In queries which compose non-server-correlated LINQ operators over an array parameter (e.g. Where(b => ids.Skip(1)...) we
// get an enumerable parameter value that isn't an array/list - but those aren't supported at the Npgsql ADO level.
// Detect this here and evaluate the enumerable to get a fully materialized List.
// Note that when we have a value converter (e.g. for HashSet), we don't want to convert it to a List, since the value converter
// expects the original type.
// TODO: Make Npgsql support IList<> instead of only arrays and List<>
if (value is not null && !value.GetType().IsArrayOrGenericList())
if (value is not null && Converter is null && !value.GetType().IsArrayOrGenericList())
{
switch (value)
{
Expand Down
214 changes: 153 additions & 61 deletions src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,93 +105,185 @@ private static Expression<Func<TInput, TOutput>> ArrayConversionExpression<TInpu
p);
}

var input = Parameter(typeof(TInput), "value");
var input = Parameter(typeof(TInput), "input");
var convertedInput = input;
var output = Parameter(typeof(TConcreteOutput), "result");
var loopVariable = Parameter(typeof(int), "i");
var lengthVariable = Variable(typeof(int), "length");

var expressions = new List<Expression>();
var variables = new List<ParameterExpression>(4)
{
output,
lengthVariable,
};
var variables = new List<ParameterExpression> { output, lengthVariable };

Expression getInputLength;
Func<Expression, Expression> indexer;
Func<Expression, Expression>? indexer;

if (typeof(TInput).IsArray)
{
getInputLength = ArrayLength(input);
indexer = i => ArrayAccess(input, i);
}
else if (typeof(TInput).IsGenericType
&& typeof(TInput).GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IList<>)))
{
getInputLength = Property(
input,
typeof(TInput).GetProperty("Count")
// If TInput is an interface (IList<T>), its Count property needs to be found on ICollection<T>
?? typeof(ICollection<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]).GetProperty("Count")!);
indexer = i => Property(input, input.Type.FindIndexerProperty()!, i);
}
else
// The conversion is going to depend on what kind of input we have: array, list, collection, or arbitrary IEnumerable.
// For array/list we can get the length and index inside, so we can do an efficient for loop.
// For other ICollections (e.g. HashSet) we can get the length (and so pre-allocate the output), but we can't index; so we
// get an enumerator and use that.
// For arbitrary IEnumerable, we can't get the length so we can't preallocate output arrays; so we to call ToList() on it and then
// process that (note that we could avoid that when the output is a List rather than an array).
var inputInterfaces = input.Type.GetInterfaces();
switch (input.Type)
{
// Input collection isn't typed as an ICollection<T>; it can be *typed* as an IEnumerable<T>, but we only support concrete
// instances being ICollection<T>. Emit code that casts the type at runtime.
var iListType = typeof(IList<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]);
// Input is typed as an array - we can get its length and index into it
case { IsArray: true }:
getInputLength = ArrayLength(input);
indexer = i => ArrayAccess(input, i);
break;

// Input is typed as an IList - we can get its length and index into it
case { IsGenericType: true } when inputInterfaces.Append(input.Type)
.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IList<>)):
{
getInputLength = Property(
input,
input.Type.GetProperty("Count")
// If TInput is an interface (IList<T>), its Count property needs to be found on ICollection<T>
?? typeof(ICollection<>).MakeGenericType(input.Type.GetGenericArguments()[0]).GetProperty("Count")!);
indexer = i => Property(input, input.Type.FindIndexerProperty()!, i);
break;
}

var convertedInput = Variable(iListType, "convertedInput");
variables.Add(convertedInput);
// Input is typed as an ICollection - we can get its length, but we can't index into it
case { IsGenericType: true } when inputInterfaces.Append(input.Type)
.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(ICollection<>)):
{
getInputLength = Property(
input, typeof(ICollection<>).MakeGenericType(input.Type.GetGenericArguments()[0]).GetProperty("Count")!);
indexer = null;
break;
}

expressions.Add(Assign(convertedInput, Convert(input, convertedInput.Type)));
// Input is typed as an IEnumerable - we can't get its length, and we can't index into it.
// All we can do is call ToList() on it and then process that.
case { IsGenericType: true } when inputInterfaces.Append(input.Type)
.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)):
{
// TODO: In theory, we could add runtime checks for array/list/collection, downcast for those cases and include
// the logic from the other switch cases here.
convertedInput = Variable(typeof(List<>).MakeGenericType(inputElementType), "convertedInput");
variables.Add(convertedInput);
expressions.Add(
Assign(
convertedInput,
Call(typeof(Enumerable).GetMethod(nameof(Enumerable.ToList))!.MakeGenericMethod(inputElementType), input)));
getInputLength = Property(convertedInput, convertedInput.Type.GetProperty("Count")!);
indexer = i => Property(convertedInput, convertedInput.Type.FindIndexerProperty()!, i);
break;
}

// TODO: Check and properly throw for non-IList<T>, e.g. set
getInputLength = Property(
convertedInput, typeof(ICollection<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]).GetProperty("Count")!);
indexer = i => Property(convertedInput, iListType.FindIndexerProperty()!, i);
default:
throw new NotSupportedException($"Array value converter input type must be an IEnumerable, but is {typeof(TInput)}");
}

expressions.AddRange(
[
// Get the length of the input array or list
// var length = input.Length;
Assign(lengthVariable, getInputLength),

// Allocate an output array or list
// var result = new int[length];
Assign(
output, typeof(TConcreteOutput).IsArray
? NewArrayBounds(outputElementType, lengthVariable)
: typeof(TConcreteOutput).GetConstructor([typeof(int)]) is ConstructorInfo ctorWithLength
? New(ctorWithLength, lengthVariable)
: New(typeof(TConcreteOutput).GetConstructor([])!)),

// Loop over the elements, applying the element converter on them one by one
// for (var i = 0; i < length; i++)
// {
// result[i] = input[i];
// }
// var length = input.Length;
Assign(lengthVariable, getInputLength),

// Allocate an output array or list
// var result = new int[length];
Assign(
output, typeof(TConcreteOutput).IsArray
? NewArrayBounds(outputElementType, lengthVariable)
: typeof(TConcreteOutput).GetConstructor([typeof(int)]) is ConstructorInfo ctorWithLength
? New(ctorWithLength, lengthVariable)
: New(typeof(TConcreteOutput).GetConstructor([])!))
]);

if (indexer is not null)
{
// Good case: the input is an array or list, so we can index into it. Generate code for an efficient for loop, which applies
// the element converter on each element.
// for (var i = 0; i < length; i++)
// {
// result[i] = input[i];
// }
var counter = Parameter(typeof(int), "i");

expressions.Add(
ForLoop(
loopVar: loopVariable,
loopVar: counter,
initValue: Constant(0),
condition: LessThan(loopVariable, lengthVariable),
increment: AddAssign(loopVariable, Constant(1)),
condition: LessThan(counter, lengthVariable),
increment: AddAssign(counter, Constant(1)),
loopContent:
typeof(TConcreteOutput).IsArray
? Assign(
ArrayAccess(output, loopVariable),
ArrayAccess(output, counter),
elementConversionExpression is null
? indexer(loopVariable)
: Invoke(elementConversionExpression, indexer(loopVariable)))
? indexer(counter)
: Invoke(elementConversionExpression, indexer(counter)))
: Call(
output,
typeof(TConcreteOutput).GetMethod("Add", [outputElementType])!,
elementConversionExpression is null
? indexer(loopVariable)
: Invoke(elementConversionExpression, indexer(loopVariable)))),
output
]);
? indexer(counter)
: Invoke(elementConversionExpression, indexer(counter)))));
}
else
{
// Bad case: the input is not an array or list, but is a collection (e.g. HashSet), so we can't index into it.
// Generate code for a less efficient enumerator-based iteration.
// enumerator = input.GetEnumerator();
// counter = 0;
// while (enumerator.MoveNext())
// {
// output[counter] = enumerator.Current;
// counter++;
// }
var enumerableType = typeof(IEnumerable<>).MakeGenericType(inputElementType);
var enumeratorType = typeof(IEnumerator<>).MakeGenericType(inputElementType);

var enumeratorVariable = Variable(enumeratorType, "enumerator");
var counterVariable = Variable(typeof(int), "variable");
variables.AddRange([enumeratorVariable, counterVariable]);

expressions.AddRange(
[
// enumerator = input.GetEnumerator();
Assign(enumeratorVariable, Call(input, enumerableType.GetMethod(nameof(IEnumerable<object>.GetEnumerator))!)),

// counter = 0;
Assign(counterVariable, Constant(0))
]);

var breakLabel = Label("LoopBreak");

var loop =
Loop(
IfThenElse(
Equal(Call(enumeratorVariable, typeof(IEnumerator).GetMethod(nameof(IEnumerator.MoveNext))!), Constant(true)),
Block(
typeof(TConcreteOutput).IsArray
// output[counter] = enumerator.Current;
? Assign(
ArrayAccess(output, counterVariable),
elementConversionExpression is null
? Property(enumeratorVariable, "Current")
: Invoke(elementConversionExpression, Property(enumeratorVariable, "Current")))
// output.Add(enumerator.Current);
: Call(
output,
typeof(TConcreteOutput).GetMethod("Add", [outputElementType])!,
elementConversionExpression is null
? Property(enumeratorVariable, "Current")
: Invoke(elementConversionExpression, Property(enumeratorVariable, "Current"))),

// counter++;
AddAssign(counterVariable, Constant(1))),
Break(breakLabel)),
breakLabel);

expressions.Add(
TryFinally(
loop,
Call(enumeratorVariable, typeof(IDisposable).GetMethod(nameof(IDisposable.Dispose))!)));
}

// return output;
expressions.Add(output);

return Lambda<Func<TInput, TOutput>>(
// First, check if the given array value is null and return null immediately if so
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,26 @@ WHERE NOT (p."Int" = ANY (@__ints_0) AND p."Int" = ANY (@__ints_0) IS NOT NULL)
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_HashSet_with_value_converter_Contains(bool async)
{
HashSet<MyEnum> enums = [MyEnum.Value1, MyEnum.Value4];

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => enums.Contains(c.Enum)));

AssertSql(
"""
@__enums_0={ '0', '3' } (DbType = Object)

SELECT p."Id", p."Bool", p."Bools", p."DateTime", p."DateTimes", p."Enum", p."Enums", p."Int", p."Ints", p."NullableInt", p."NullableInts", p."NullableString", p."NullableStrings", p."String", p."Strings"
FROM "PrimitiveCollectionsEntity" AS p
WHERE p."Enum" = ANY (@__enums_0)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down

0 comments on commit 8a0fca5

Please sign in to comment.