Skip to content

Commit

Permalink
Merge pull request #746 from square/joel.suspend-lambdas
Browse files Browse the repository at this point in the history
Fix suspend lambdas not being supported
  • Loading branch information
JoelWilcox authored Sep 6, 2023
2 parents a1be7c6 + 63145c1 commit e5e01bf
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.jetbrains.kotlin.psi.KtTypeReference
import org.jetbrains.kotlin.psi.KtUserType
import org.jetbrains.kotlin.psi.psiUtil.containingClass
import org.jetbrains.kotlin.psi.psiUtil.getChildOfType
import org.jetbrains.kotlin.psi.psiUtil.hasSuspendModifier
import org.jetbrains.kotlin.psi.psiUtil.parents
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameOrNull
import org.jetbrains.kotlin.types.DefinitelyNotNullType
Expand Down Expand Up @@ -304,6 +305,8 @@ public sealed class TypeReference {
?: emptyList(),
returnType = (returnTypeReference ?: fail())
.requireTypeName()
).copy(
suspending = type.modifierList?.hasSuspendModifier() ?: false
)
is KtNullableType -> {
(innerType ?: fail()).requireTypeName().copy(nullable = true)
Expand Down Expand Up @@ -644,6 +647,15 @@ private fun TypeName.lambdaFix(): TypeName {
// must be converted or their signatures won't match.
// see https://github.com/square/anvil/issues/400
val lambdaTypeName = this as? LambdaTypeName ?: return this
if (lambdaTypeName.isSuspending) {
// We short-circuit for suspend lambdas because otherwise we would need to represent them as
// something like `Function1<Continuation<String?>, *>`. This works fine when generating Java
// code like in Dagger, but proves problematic when generating Kotlin code and ends up requiring
// continuous casting of the arg to make the compiler happy again. It's possible this could end
// up failing for a Psi vs Descriptor use-case like the original ticket that required this
// lambda fix method but I haven't been able to find a failing repro for that yet.
return this
}

val allTypes = listOfNotNull(lambdaTypeName.receiver) +
lambdaTypeName.parameters.map { it.type } +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,48 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory
}
}

@Test fun `the factory function may require a suspend lambda type`() {
compile(
"""
package com.squareup.test
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
data class AssistedService @AssistedInject constructor(
val int: Int,
@Assisted val stringFactory: suspend (Int) -> String
)
@AssistedFactory
interface AssistedServiceFactory {
fun create(stringFactory: suspend (Int) -> String): AssistedService
}
"""
) {
val factoryImplClass = assistedServiceFactory.implClass()
val generatedFactoryInstance = assistedService.factoryClass().createInstance(Provider { 5 })
val factoryImplInstance = factoryImplClass.createInstance(generatedFactoryInstance)

val staticMethods = factoryImplClass.declaredMethods.filter { it.isStatic }
assertThat(staticMethods).hasSize(1)

val factoryProvider = staticMethods.single { it.name == "create" }
.invoke(null, generatedFactoryInstance) as Provider<*>
assertThat(factoryProvider.get()::class.java).isEqualTo(factoryImplClass)

val lambdaArg: suspend (Int) -> String = { num: Int -> num.toString() }

val assistedServiceInstance = factoryImplClass.declaredMethods
.filterNot { it.isStatic }
.last { it.name == "create" }
.invoke(factoryImplInstance, lambdaArg)

assertThat(assistedServiceInstance).isEqualTo(assistedService.createInstance(5, lambdaArg))
}
}

@Test fun `the factory function may require a Function type`() {
compile(
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,38 @@ public final class AssistedService_Factory {
}
}

@Test fun `a factory class is generated with a suspend lambda assisted parameter`() {
compile(
"""
package com.squareup.test
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
data class AssistedService @AssistedInject constructor(
@Assisted val action: suspend () -> String?
)
""",
) {
val factoryClass = assistedService.factoryClass()

val constructor = factoryClass.declaredConstructors.single()
assertThat(constructor.parameterTypes.toList()).isEmpty()

val staticMethods = factoryClass.declaredMethods.filter { it.isStatic }
assertThat(staticMethods).hasSize(2)

val factoryInstance = staticMethods.single { it.name == "create" }
.invoke(null)
assertThat(factoryInstance::class.java).isEqualTo(factoryClass)

val action: suspend () -> String? = { "Hello " }
val newInstance = staticMethods.single { it.name == "newInstance" }
.invoke(null, action)
assertThat(factoryInstance.invokeGet(action)).isEqualTo(newInstance)
}
}

@Test fun `a factory class is generated without any parameter`() {
compile(
"""
Expand Down Expand Up @@ -642,7 +674,7 @@ public final class AssistedService_Factory {

private fun compile(
@Language("kotlin") vararg sources: String,
block: JvmCompilationResult.() -> Unit = { }
block: JvmCompilationResult.() -> Unit = { },
): JvmCompilationResult = compileAnvil(
sources = sources,
enableDaggerAnnotationProcessor = useDagger,
Expand Down

0 comments on commit e5e01bf

Please sign in to comment.