diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TypeReference.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TypeReference.kt index de759bb12..f7a76511a 100644 --- a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TypeReference.kt +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TypeReference.kt @@ -35,6 +35,7 @@ import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.psi.KtUserType import org.jetbrains.kotlin.psi.psiUtil.containingClass import org.jetbrains.kotlin.psi.psiUtil.getChildOfType +import org.jetbrains.kotlin.psi.psiUtil.hasSuspendModifier import org.jetbrains.kotlin.psi.psiUtil.parents import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameOrNull import org.jetbrains.kotlin.types.DefinitelyNotNullType @@ -304,6 +305,8 @@ public sealed class TypeReference { ?: emptyList(), returnType = (returnTypeReference ?: fail()) .requireTypeName() + ).copy( + suspending = type.modifierList?.hasSuspendModifier() ?: false ) is KtNullableType -> { (innerType ?: fail()).requireTypeName().copy(nullable = true) @@ -644,6 +647,15 @@ private fun TypeName.lambdaFix(): TypeName { // must be converted or their signatures won't match. // see https://github.com/square/anvil/issues/400 val lambdaTypeName = this as? LambdaTypeName ?: return this + if (lambdaTypeName.isSuspending) { + // We short-circuit for suspend lambdas because otherwise we would need to represent them as + // something like `Function1, *>`. This works fine when generating Java + // code like in Dagger, but proves problematic when generating Kotlin code and ends up requiring + // continuous casting of the arg to make the compiler happy again. It's possible this could end + // up failing for a Psi vs Descriptor use-case like the original ticket that required this + // lambda fix method but I haven't been able to find a failing repro for that yet. + return this + } val allTypes = listOfNotNull(lambdaTypeName.receiver) + lambdaTypeName.parameters.map { it.type } + diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt index 044da9ae1..8690c6ebf 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt @@ -326,6 +326,48 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory } } + @Test fun `the factory function may require a suspend lambda type`() { + compile( + """ + package com.squareup.test + + import dagger.assisted.Assisted + import dagger.assisted.AssistedFactory + import dagger.assisted.AssistedInject + + data class AssistedService @AssistedInject constructor( + val int: Int, + @Assisted val stringFactory: suspend (Int) -> String + ) + + @AssistedFactory + interface AssistedServiceFactory { + fun create(stringFactory: suspend (Int) -> String): AssistedService + } + """ + ) { + val factoryImplClass = assistedServiceFactory.implClass() + val generatedFactoryInstance = assistedService.factoryClass().createInstance(Provider { 5 }) + val factoryImplInstance = factoryImplClass.createInstance(generatedFactoryInstance) + + val staticMethods = factoryImplClass.declaredMethods.filter { it.isStatic } + assertThat(staticMethods).hasSize(1) + + val factoryProvider = staticMethods.single { it.name == "create" } + .invoke(null, generatedFactoryInstance) as Provider<*> + assertThat(factoryProvider.get()::class.java).isEqualTo(factoryImplClass) + + val lambdaArg: suspend (Int) -> String = { num: Int -> num.toString() } + + val assistedServiceInstance = factoryImplClass.declaredMethods + .filterNot { it.isStatic } + .last { it.name == "create" } + .invoke(factoryImplInstance, lambdaArg) + + assertThat(assistedServiceInstance).isEqualTo(assistedService.createInstance(5, lambdaArg)) + } + } + @Test fun `the factory function may require a Function type`() { compile( """ diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt index f13bfd0f2..ad5f114d0 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt @@ -96,6 +96,38 @@ public final class AssistedService_Factory { } } + @Test fun `a factory class is generated with a suspend lambda assisted parameter`() { + compile( + """ + package com.squareup.test + + import dagger.assisted.Assisted + import dagger.assisted.AssistedInject + + data class AssistedService @AssistedInject constructor( + @Assisted val action: suspend () -> String? + ) + """, + ) { + val factoryClass = assistedService.factoryClass() + + val constructor = factoryClass.declaredConstructors.single() + assertThat(constructor.parameterTypes.toList()).isEmpty() + + val staticMethods = factoryClass.declaredMethods.filter { it.isStatic } + assertThat(staticMethods).hasSize(2) + + val factoryInstance = staticMethods.single { it.name == "create" } + .invoke(null) + assertThat(factoryInstance::class.java).isEqualTo(factoryClass) + + val action: suspend () -> String? = { "Hello " } + val newInstance = staticMethods.single { it.name == "newInstance" } + .invoke(null, action) + assertThat(factoryInstance.invokeGet(action)).isEqualTo(newInstance) + } + } + @Test fun `a factory class is generated without any parameter`() { compile( """ @@ -642,7 +674,7 @@ public final class AssistedService_Factory { private fun compile( @Language("kotlin") vararg sources: String, - block: JvmCompilationResult.() -> Unit = { } + block: JvmCompilationResult.() -> Unit = { }, ): JvmCompilationResult = compileAnvil( sources = sources, enableDaggerAnnotationProcessor = useDagger,