Skip to content

Commit

Permalink
Fix suspend lambdas not being supported
Browse files Browse the repository at this point in the history
Fixes #745.

This use-case primarily popped up with assisted parameter and factory generation. We would get an error like:
e: Class 'AssistedServiceFactory_Impl' is not abstract and does not implement abstract member public abstract fun create(stringFactory: suspend (Int) -> String): AssistedService defined in com.squareup.test.AssistedServiceFactory
e: AssistedServiceFactory_Impl.kt:25:3 'create' overrides nothing

This is because we manually override the generated code for lambdas to address differences between Psi and Descriptors. There is a possibility that we could continue overriding the default generated code for suspend lambdas with modifications to what we currently do, but I found doing so to introduce problematic casting requirements that would spread beyond this isolated area of handling the lambdas themselves. Instead, we can first change to generating the default expected suspend lambda signature to fix this primary use-case, and wait to see if any similar Psi vs Descriptor use-cases pop up later. Thus far I haven't been able to find any reproable test cases where this solution doesn't work.
  • Loading branch information
JoelWilcox committed Sep 6, 2023
1 parent a1be7c6 commit 63145c1
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.jetbrains.kotlin.psi.KtTypeReference
import org.jetbrains.kotlin.psi.KtUserType
import org.jetbrains.kotlin.psi.psiUtil.containingClass
import org.jetbrains.kotlin.psi.psiUtil.getChildOfType
import org.jetbrains.kotlin.psi.psiUtil.hasSuspendModifier
import org.jetbrains.kotlin.psi.psiUtil.parents
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameOrNull
import org.jetbrains.kotlin.types.DefinitelyNotNullType
Expand Down Expand Up @@ -304,6 +305,8 @@ public sealed class TypeReference {
?: emptyList(),
returnType = (returnTypeReference ?: fail())
.requireTypeName()
).copy(
suspending = type.modifierList?.hasSuspendModifier() ?: false
)
is KtNullableType -> {
(innerType ?: fail()).requireTypeName().copy(nullable = true)
Expand Down Expand Up @@ -644,6 +647,15 @@ private fun TypeName.lambdaFix(): TypeName {
// must be converted or their signatures won't match.
// see https://github.com/square/anvil/issues/400
val lambdaTypeName = this as? LambdaTypeName ?: return this
if (lambdaTypeName.isSuspending) {
// We short-circuit for suspend lambdas because otherwise we would need to represent them as
// something like `Function1<Continuation<String?>, *>`. This works fine when generating Java
// code like in Dagger, but proves problematic when generating Kotlin code and ends up requiring
// continuous casting of the arg to make the compiler happy again. It's possible this could end
// up failing for a Psi vs Descriptor use-case like the original ticket that required this
// lambda fix method but I haven't been able to find a failing repro for that yet.
return this
}

val allTypes = listOfNotNull(lambdaTypeName.receiver) +
lambdaTypeName.parameters.map { it.type } +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,48 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory
}
}

@Test fun `the factory function may require a suspend lambda type`() {
compile(
"""
package com.squareup.test
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
data class AssistedService @AssistedInject constructor(
val int: Int,
@Assisted val stringFactory: suspend (Int) -> String
)
@AssistedFactory
interface AssistedServiceFactory {
fun create(stringFactory: suspend (Int) -> String): AssistedService
}
"""
) {
val factoryImplClass = assistedServiceFactory.implClass()
val generatedFactoryInstance = assistedService.factoryClass().createInstance(Provider { 5 })
val factoryImplInstance = factoryImplClass.createInstance(generatedFactoryInstance)

val staticMethods = factoryImplClass.declaredMethods.filter { it.isStatic }
assertThat(staticMethods).hasSize(1)

val factoryProvider = staticMethods.single { it.name == "create" }
.invoke(null, generatedFactoryInstance) as Provider<*>
assertThat(factoryProvider.get()::class.java).isEqualTo(factoryImplClass)

val lambdaArg: suspend (Int) -> String = { num: Int -> num.toString() }

val assistedServiceInstance = factoryImplClass.declaredMethods
.filterNot { it.isStatic }
.last { it.name == "create" }
.invoke(factoryImplInstance, lambdaArg)

assertThat(assistedServiceInstance).isEqualTo(assistedService.createInstance(5, lambdaArg))
}
}

@Test fun `the factory function may require a Function type`() {
compile(
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,38 @@ public final class AssistedService_Factory {
}
}

@Test fun `a factory class is generated with a suspend lambda assisted parameter`() {
compile(
"""
package com.squareup.test
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
data class AssistedService @AssistedInject constructor(
@Assisted val action: suspend () -> String?
)
""",
) {
val factoryClass = assistedService.factoryClass()

val constructor = factoryClass.declaredConstructors.single()
assertThat(constructor.parameterTypes.toList()).isEmpty()

val staticMethods = factoryClass.declaredMethods.filter { it.isStatic }
assertThat(staticMethods).hasSize(2)

val factoryInstance = staticMethods.single { it.name == "create" }
.invoke(null)
assertThat(factoryInstance::class.java).isEqualTo(factoryClass)

val action: suspend () -> String? = { "Hello " }
val newInstance = staticMethods.single { it.name == "newInstance" }
.invoke(null, action)
assertThat(factoryInstance.invokeGet(action)).isEqualTo(newInstance)
}
}

@Test fun `a factory class is generated without any parameter`() {
compile(
"""
Expand Down Expand Up @@ -642,7 +674,7 @@ public final class AssistedService_Factory {

private fun compile(
@Language("kotlin") vararg sources: String,
block: JvmCompilationResult.() -> Unit = { }
block: JvmCompilationResult.() -> Unit = { },
): JvmCompilationResult = compileAnvil(
sources = sources,
enableDaggerAnnotationProcessor = useDagger,
Expand Down

0 comments on commit 63145c1

Please sign in to comment.