Skip to content

Commit

Permalink
Fix binding method without body inside of interface wrongly assumed t…
Browse files Browse the repository at this point in the history
…o not be abstract in BindsMethodValidator
  • Loading branch information
IlyaGulya committed Dec 31, 2023
1 parent 89d49d7 commit e153081
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.Modality.ABSTRACT
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.descriptors.PropertyDescriptor
import org.jetbrains.kotlin.descriptors.findClassAcrossModuleDependencies
import org.jetbrains.kotlin.lexer.KtTokens.ABSTRACT_KEYWORD
import org.jetbrains.kotlin.lexer.KtTokens.INTERNAL_KEYWORD
import org.jetbrains.kotlin.lexer.KtTokens.PRIVATE_KEYWORD
Expand All @@ -40,6 +41,9 @@ import org.jetbrains.kotlin.resolve.DescriptorUtils
import org.jetbrains.kotlin.resolve.descriptorUtil.parents
import org.jetbrains.kotlin.resolve.scopes.DescriptorKindFilter
import org.jetbrains.kotlin.resolve.scopes.getDescriptorsFiltered
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.TypeUtils
import org.jetbrains.kotlin.types.checker.KotlinTypeChecker
import kotlin.LazyThreadSafetyMode.NONE

/**
Expand Down Expand Up @@ -85,6 +89,8 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
*/
public abstract fun enclosingClassesWithSelf(): List<ClassReference>

internal abstract fun getDescriptor(): ClassDescriptor

public fun enclosingClass(): ClassReference? {
val classes = enclosingClassesWithSelf()
val index = classes.indexOf(this)
Expand Down Expand Up @@ -220,6 +226,14 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere

override fun enclosingClassesWithSelf(): List<Psi> = enclosingClassesWithSelf

override fun getDescriptor(): ClassDescriptor {
return module.findClassAcrossModuleDependencies(this.classId)
?: throw AnvilCompilationException(
element = clazz,
message = "Couldn't find descriptor for class $fqName.",
)
}

@Suppress("UNCHECKED_CAST")
override fun innerClasses(): List<Psi> =
super.innerClasses() as List<Psi>
Expand Down Expand Up @@ -323,6 +337,10 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere

override fun enclosingClassesWithSelf(): List<Descriptor> = enclosingClassesWithSelf

override fun getDescriptor(): ClassDescriptor {
return clazz
}

@Suppress("UNCHECKED_CAST")
override fun innerClasses(): List<Descriptor> =
super.innerClasses() as List<Descriptor>
Expand Down Expand Up @@ -433,9 +451,28 @@ public fun AnvilCompilationExceptionClassReference(
message = message,
cause = cause,
)

is Descriptor -> AnvilCompilationException(
classDescriptor = classReference.clazz,
message = message,
cause = cause,
)
}

fun TypeReference.isAssignableTo(
other: TypeReference
): Boolean {
return KotlinTypeChecker.DEFAULT.isSubtypeOf(
getFullKotlinType(),
other.getFullKotlinType(),
)
}

fun TypeReference.getFullKotlinType(): KotlinType {
val descriptor = asClassReference().getDescriptor()
val typeArguments = this.unwrappedTypes.map { it.getFullKotlinType() }
return TypeUtils.substituteParameters(
descriptor,
typeArguments,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSTypeReference
import com.google.devtools.ksp.symbol.Modifier
import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker
import com.squareup.anvil.compiler.api.AnvilContext
import com.squareup.anvil.compiler.api.CodeGenerator
Expand Down Expand Up @@ -59,7 +58,7 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {
internal fun bindsParameterMustBeAssignable(
paramSuperTypeNames: List<String>,
returnTypeName: String,
parameterName: String?,
parameterType: String?,
): String {
val superTypesMessage = if (paramSuperTypeNames.isEmpty()) {
"has no supertypes."
Expand All @@ -69,7 +68,7 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {

return "@Binds methods' parameter type must be assignable to the return type. " +
"Expected binding of type $returnTypeName but impl parameter of type " +
"$parameterName $superTypesMessage"
"$parameterType $superTypesMessage"
}
}

Expand Down Expand Up @@ -102,7 +101,7 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {
private val resolver: Resolver
) {
internal fun validateBindsFunction(function: KSFunctionDeclaration) {
if (!function.modifiers.contains(Modifier.ABSTRACT)) {
if (!function.isAbstract) {
throw KspAnvilException(
message = Errors.BINDS_MUST_BE_ABSTRACT,
node = function,
Expand All @@ -129,31 +128,32 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {
node = function,
)

val parameterSuperTypes = function.parameterSuperTypes()
val receiverSuperTypes = function.receiverSuperTypes()
when (returnType) {
in function.parameterSuperTypes() -> return
in function.receiverSuperTypes() -> return
in parameterSuperTypes -> return
in receiverSuperTypes -> return
}

val returnTypeRef = returnType.declaration.simpleName.asString()
val paramSuperTypes =
function
.parameterSuperTypes()
.ifEmpty { function.receiverSuperTypes() }
val actualParameterSuperTypes =
parameterSuperTypes
.ifEmpty { receiverSuperTypes }

val lastParameterType = function.parameters.lastOrNull()?.type
val parameterRef = lastParameterType ?: function.extensionReceiver
val param = parameterRef?.resolve()

val paramSuperTypeNames =
paramSuperTypes
val actualParameterSuperTypeNames =
actualParameterSuperTypes
.map { it.declaration.simpleName.asString() }
.toList()

throw KspAnvilException(
message = Errors.bindsParameterMustBeAssignable(
paramSuperTypeNames = paramSuperTypeNames,
paramSuperTypeNames = actualParameterSuperTypeNames,
returnTypeName = returnTypeRef,
parameterName = param?.declaration?.simpleName?.getShortName(),
parameterType = param?.declaration?.simpleName?.getShortName(),
),
node = function,
)
Expand Down Expand Up @@ -232,7 +232,7 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {

if (!function.parameterMatchesReturnType() && !function.receiverMatchesReturnType()) {
val returnType = function.returnType().asClassReference().shortName
val paramSuperTypes =
val actualParameterSuperTypes =
function
.parameterSuperTypes()
.ifEmpty { function.receiverSuperTypes() }
Expand All @@ -241,9 +241,9 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {

throw AnvilCompilationExceptionFunctionReference(
message = Errors.bindsParameterMustBeAssignable(
paramSuperTypeNames = paramSuperTypes.drop(1),
paramSuperTypeNames = actualParameterSuperTypes.drop(1),
returnTypeName = returnType,
parameterName = paramSuperTypes.first(),
parameterType = actualParameterSuperTypes.first(),
),
functionReference = function,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,30 @@ class BindsMethodValidatorTest(
}
}

@Test
fun `binding inside interface is valid`() {
compile(
"""
package com.squareup.test
import dagger.Binds
import dagger.Module
import javax.inject.Inject
interface Foo : Bar
interface Bar
@Module
interface BarModule {
@Binds
fun bindsBar(foo: Foo): Bar
}
""",
) {
assertThat(exitCode).isEqualTo(OK)
}
}

private fun compile(
@Language("kotlin") vararg sources: String,
previousCompilationResult: JvmCompilationResult? = null,
Expand Down

0 comments on commit e153081

Please sign in to comment.