From fd5e6ab55a63010b3a8305e13afcc6a3fa03d297 Mon Sep 17 00:00:00 2001 From: Ralf Wondratschek Date: Thu, 12 Nov 2020 10:22:59 -0800 Subject: [PATCH] Support property as providers in Dagger modules. Fixes #149 --- .../anvil/compiler/codegen/PsiUtils.kt | 14 +- .../dagger/ProvidesMethodFactoryGenerator.kt | 77 +++- .../ProvidesMethodFactoryGeneratorTest.kt | 352 ++++++++++++++++++ 3 files changed, 418 insertions(+), 25 deletions(-) diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/PsiUtils.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/PsiUtils.kt index 84d4ffa49..a4ddda9ae 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/PsiUtils.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/PsiUtils.kt @@ -28,6 +28,7 @@ import org.jetbrains.kotlin.psi.KtNameReferenceExpression import org.jetbrains.kotlin.psi.KtNamedDeclaration import org.jetbrains.kotlin.psi.KtNamedFunction import org.jetbrains.kotlin.psi.KtNullableType +import org.jetbrains.kotlin.psi.KtProperty import org.jetbrains.kotlin.psi.KtTypeArgumentList import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.psi.KtUserType @@ -355,15 +356,18 @@ internal fun ModuleDescriptor.findClassOrTypeAlias( internal fun KtClassOrObject.functions( includeCompanionObjects: Boolean -): List { +): List = classBodies(includeCompanionObjects).flatMap { it.functions } + +internal fun KtClassOrObject.properties( + includeCompanionObjects: Boolean +): List = classBodies(includeCompanionObjects).flatMap { it.properties } + +private fun KtClassOrObject.classBodies(includeCompanionObjects: Boolean): List { val elements = children.toMutableList() if (includeCompanionObjects) { elements += companionObjects.flatMap { it.children.toList() } } - - return elements - .filterIsInstance() - .flatMap { it.functions } + return elements.filterIsInstance() } fun KtTypeReference.isNullable(): Boolean = typeElement is KtNullableType diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt index b69bd39cf..5a91c36c2 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt @@ -7,10 +7,12 @@ import com.squareup.anvil.compiler.codegen.addGeneratedByComment import com.squareup.anvil.compiler.codegen.asArgumentList import com.squareup.anvil.compiler.codegen.asClassName import com.squareup.anvil.compiler.codegen.classesAndInnerClasses +import com.squareup.anvil.compiler.codegen.findAnnotation import com.squareup.anvil.compiler.codegen.functions import com.squareup.anvil.compiler.codegen.hasAnnotation import com.squareup.anvil.compiler.codegen.isNullable import com.squareup.anvil.compiler.codegen.mapToParameter +import com.squareup.anvil.compiler.codegen.properties import com.squareup.anvil.compiler.codegen.requireFqName import com.squareup.anvil.compiler.codegen.requireTypeName import com.squareup.anvil.compiler.codegen.requireTypeReference @@ -32,10 +34,11 @@ import com.squareup.kotlinpoet.jvm.jvmStatic import dagger.internal.Factory import dagger.internal.Preconditions import org.jetbrains.kotlin.descriptors.ModuleDescriptor +import org.jetbrains.kotlin.psi.KtCallableDeclaration import org.jetbrains.kotlin.psi.KtClassOrObject import org.jetbrains.kotlin.psi.KtFile -import org.jetbrains.kotlin.psi.KtNamedFunction import org.jetbrains.kotlin.psi.KtObjectDeclaration +import org.jetbrains.kotlin.psi.KtProperty import org.jetbrains.kotlin.psi.psiUtil.parents import java.io.File import java.util.Locale.US @@ -73,6 +76,17 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { .forEach { function -> generateFactoryClass(codeGenDir, module, clazz, function) } + + clazz + .properties(includeCompanionObjects = true) + .asSequence() + .filter { property -> + // Must be '@get:Provides'. + property.findAnnotation(daggerProvidesFqName)?.useSiteTarget?.text == "get" + } + .forEach { property -> + generateFactoryClass(codeGenDir, module, clazz, property) + } } } @@ -81,30 +95,48 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { codeGenDir: File, module: ModuleDescriptor, clazz: KtClassOrObject, - function: KtNamedFunction + declaration: KtCallableDeclaration ): GeneratedFile { - val isCompanionObject = function.parents + val isCompanionObject = declaration.parents .filterIsInstance() .firstOrNull() ?.isCompanion() ?: false val isObject = isCompanionObject || clazz is KtObjectDeclaration + val isProperty = declaration is KtProperty + val packageName = clazz.containingKtFile.packageFqName.asString() - val className = "${clazz.generateClassName()}_" + - (if (isCompanionObject) "Companion_" else "") + - "${function.requireFqName().shortName().asString().capitalize(US)}Factory" - val functionName = function.nameAsSafeName.asString() + val className = buildString { + append(clazz.generateClassName()) + append('_') + if (isCompanionObject) { + append("Companion_") + } + if (isProperty) { + append("Get") + } + append(declaration.requireFqName().shortName().asString().capitalize(US)) + append("Factory") + } - val parameters = function.valueParameters.mapToParameter(module) + val callableName = declaration.nameAsSafeName.asString() - val returnType = function.requireTypeReference().requireTypeName(module) - .withJvmSuppressWildcardsIfNeeded(function) - val returnTypeIsNullable = function.typeReference?.isNullable() ?: false + val parameters = declaration.valueParameters.mapToParameter(module) + + val returnType = declaration.requireTypeReference().requireTypeName(module) + .withJvmSuppressWildcardsIfNeeded(declaration) + val returnTypeIsNullable = declaration.typeReference?.isNullable() ?: false val factoryClass = ClassName(packageName, className) val moduleClass = clazz.asClassName() + val byteCodeFunctionName = if (isProperty) { + "get" + callableName.capitalize(US) + } else { + callableName + } + val content = FileSpec.builder(packageName, className) .apply { val canGenerateAnObject = isObject && parameters.isEmpty() @@ -158,7 +190,7 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { asProvider = true, includeModule = !isObject ) - addStatement("return $functionName($argumentList)") + addStatement("return $byteCodeFunctionName($argumentList)") } .build() ) @@ -191,7 +223,7 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { .build() ) .addFunction( - FunSpec.builder(functionName) + FunSpec.builder(byteCodeFunctionName) .jvmStatic() .apply { if (!isObject) { @@ -204,30 +236,35 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { type = parameter.originalTypeName ) } - val argumentsWithoutModule = parameters.joinToString { it.name } + + val argumentsWithoutModule = if (isProperty) { + "" + } else { + "(${parameters.joinToString { it.name }})" + } when { isObject && returnTypeIsNullable -> addStatement( - "return %T.$functionName($argumentsWithoutModule)", + "return %T.$callableName$argumentsWithoutModule", moduleClass ) isObject && !returnTypeIsNullable -> addStatement( - "return %T.checkNotNull(%T.$functionName" + - "($argumentsWithoutModule), %S)", + "return %T.checkNotNull(%T.$callableName" + + "$argumentsWithoutModule, %S)", Preconditions::class, moduleClass, "Cannot return null from a non-@Nullable @Provides method" ) !isObject && returnTypeIsNullable -> addStatement( - "return module.$functionName($argumentsWithoutModule)" + "return module.$callableName$argumentsWithoutModule" ) !isObject && !returnTypeIsNullable -> addStatement( - "return %T.checkNotNull(module.$functionName" + - "($argumentsWithoutModule), %S)", + "return %T.checkNotNull(module.$callableName" + + "$argumentsWithoutModule, %S)", Preconditions::class, "Cannot return null from a non-@Nullable @Provides method" ) diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/ProvidesMethodFactoryGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/ProvidesMethodFactoryGeneratorTest.kt index 6dd4d6db6..bd205f449 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/ProvidesMethodFactoryGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/ProvidesMethodFactoryGeneratorTest.kt @@ -2440,6 +2440,358 @@ public final class DaggerComponentInterface implements ComponentInterface { } } + @Test fun `a factory class is generated for provided properties`() { + /* +package com.squareup.test; + +import dagger.internal.Factory; +import dagger.internal.Preconditions; +import javax.annotation.processing.Generated; + +@Generated( + value = "dagger.internal.codegen.ComponentProcessor", + comments = "https://dagger.dev" +) +@SuppressWarnings({ + "unchecked", + "rawtypes" +}) +public final class DaggerModule1_GetStringFactory implements Factory { + private final DaggerModule1 module; + + public DaggerModule1_GetStringFactory(DaggerModule1 module) { + this.module = module; + } + + @Override + public String get() { + return getString(module); + } + + public static DaggerModule1_GetStringFactory create(DaggerModule1 module) { + return new DaggerModule1_GetStringFactory(module); + } + + public static String getString(DaggerModule1 instance) { + return Preconditions.checkNotNull(instance.getString(), "Cannot return null from a non-@Nullable @Provides method"); + } +} + */ + compile( + """ + package com.squareup.test + + import dagger.Module + import dagger.Provides + + @Module + class DaggerModule1 { + @get:Provides val string: String = "abc" + } + """ + ) { + val factoryClass = daggerModule1.moduleFactoryClass("getString") + + val constructor = factoryClass.declaredConstructors.single() + assertThat(constructor.parameterTypes.toList()).containsExactly(daggerModule1) + + val staticMethods = factoryClass.declaredMethods.filter { it.isStatic } + + val module = daggerModule1.newInstance() + + val factoryInstance = staticMethods.single { it.name == "create" } + .invoke(null, module) + assertThat(factoryInstance::class.java).isEqualTo(factoryClass) + + val providedString = staticMethods.single { it.name == "getString" } + .invoke(null, module) as String + + assertThat(providedString).isEqualTo("abc") + assertThat((factoryInstance as Factory).get()).isEqualTo("abc") + } + } + + @Test fun `a factory class is generated for provided properties in an object module`() { + /* +package com.squareup.test; + +import dagger.internal.Factory; +import dagger.internal.Preconditions; +import javax.annotation.processing.Generated; + +@Generated( + value = "dagger.internal.codegen.ComponentProcessor", + comments = "https://dagger.dev" +) +@SuppressWarnings({ + "unchecked", + "rawtypes" +}) +public final class DaggerModule1_GetStringFactory implements Factory { + @Override + public String get() { + return getString(); + } + + public static DaggerModule1_GetStringFactory create() { + return InstanceHolder.INSTANCE; + } + + public static String getString() { + return Preconditions.checkNotNull(DaggerModule1.INSTANCE.getString(), "Cannot return null from a non-@Nullable @Provides method"); + } + + private static final class InstanceHolder { + private static final DaggerModule1_GetStringFactory INSTANCE = new DaggerModule1_GetStringFactory(); + } +} + */ + compile( + """ + package com.squareup.test + + import dagger.Module + import dagger.Provides + + @Module + object DaggerModule1 { + @get:Provides val string: String = "abc" + } + """ + ) { + val factoryClass = daggerModule1.moduleFactoryClass("getString") + + val constructor = factoryClass.declaredConstructors.single() + assertThat(constructor.parameterTypes.toList()).isEmpty() + + val staticMethods = factoryClass.declaredMethods.filter { it.isStatic } + + val factoryInstance = staticMethods.single { it.name == "create" } + .invoke(null) + assertThat(factoryInstance::class.java).isEqualTo(factoryClass) + + val providedString = staticMethods.single { it.name == "getString" } + .invoke(null) as String + + assertThat(providedString).isEqualTo("abc") + assertThat((factoryInstance as Factory).get()).isEqualTo("abc") + } + } + + @Test fun `a factory class is generated for provided properties in a companion object module`() { + /* +package com.squareup.test; + +import dagger.internal.Factory; +import dagger.internal.Preconditions; +import javax.annotation.processing.Generated; + +@Generated( + value = "dagger.internal.codegen.ComponentProcessor", + comments = "https://dagger.dev" +) +@SuppressWarnings({ + "unchecked", + "rawtypes" +}) +public final class DaggerModule1_GetStringFactory implements Factory { + @Override + public String get() { + return getString(); + } + + public static DaggerModule1_GetStringFactory create() { + return InstanceHolder.INSTANCE; + } + + public static String getString() { + return Preconditions.checkNotNull(DaggerModule1.INSTANCE.getString(), "Cannot return null from a non-@Nullable @Provides method"); + } + + private static final class InstanceHolder { + private static final DaggerModule1_GetStringFactory INSTANCE = new DaggerModule1_GetStringFactory(); + } +} + */ + compile( + """ + package com.squareup.test + + import dagger.Binds + import dagger.Module + import dagger.Provides + + @Module + abstract class DaggerModule1 { + @Binds abstract fun bindString(string: String): CharSequence + + companion object { + @get:Provides val string: String = "abc" + } + } + """ + ) { + val factoryClass = daggerModule1.moduleFactoryClass("getString", companion = true) + + val constructor = factoryClass.declaredConstructors.single() + assertThat(constructor.parameterTypes.toList()).isEmpty() + + val staticMethods = factoryClass.declaredMethods.filter { it.isStatic } + + val factoryInstance = staticMethods.single { it.name == "create" } + .invoke(null) + assertThat(factoryInstance::class.java).isEqualTo(factoryClass) + + val providedString = staticMethods.single { it.name == "getString" } + .invoke(null) as String + + assertThat(providedString).isEqualTo("abc") + assertThat((factoryInstance as Factory).get()).isEqualTo("abc") + } + } + + @Test fun `a factory class is generated for provided nullable properties`() { + /* +package com.squareup.test; + +import dagger.internal.Factory; +import javax.annotation.processing.Generated; +import org.jetbrains.annotations.Nullable; + +@Generated( + value = "dagger.internal.codegen.ComponentProcessor", + comments = "https://dagger.dev" +) +@SuppressWarnings({ + "unchecked", + "rawtypes" +}) +public final class DaggerModule1_GetStringFactory implements Factory { + private final DaggerModule1 module; + + public DaggerModule1_GetStringFactory(DaggerModule1 module) { + this.module = module; + } + + @Override + @Nullable + public String get() { + return getString(module); + } + + public static DaggerModule1_GetStringFactory create(DaggerModule1 module) { + return new DaggerModule1_GetStringFactory(module); + } + + @Nullable + public static String getString(DaggerModule1 instance) { + return instance.getString(); + } +} + */ + compile( + """ + package com.squareup.test + + import dagger.Module + import dagger.Provides + + @Module + class DaggerModule1 { + @get:Provides val string: String? = null + } + """ + ) { + val factoryClass = daggerModule1.moduleFactoryClass("getString") + + val constructor = factoryClass.declaredConstructors.single() + assertThat(constructor.parameterTypes.toList()).containsExactly(daggerModule1) + + val staticMethods = factoryClass.declaredMethods.filter { it.isStatic } + + val module = daggerModule1.newInstance() + + val factoryInstance = staticMethods.single { it.name == "create" } + .invoke(null, module) + assertThat(factoryInstance::class.java).isEqualTo(factoryClass) + + val providedString = staticMethods.single { it.name == "getString" } + .invoke(null, module) as String? + + assertThat(providedString).isNull() + assertThat((factoryInstance as Factory).get()).isNull() + } + } + + @Test fun `a factory class is generated for provided nullable properties in an object module`() { + /* +package com.squareup.test; + +import dagger.internal.Factory; +import javax.annotation.processing.Generated; +import org.jetbrains.annotations.Nullable; + +@Generated( + value = "dagger.internal.codegen.ComponentProcessor", + comments = "https://dagger.dev" +) +@SuppressWarnings({ + "unchecked", + "rawtypes" +}) +public final class DaggerModule1_GetStringFactory implements Factory { + @Override + @Nullable + public String get() { + return getString(); + } + + public static DaggerModule1_GetStringFactory create() { + return InstanceHolder.INSTANCE; + } + + @Nullable + public static String getString() { + return DaggerModule1.INSTANCE.getString(); + } + + private static final class InstanceHolder { + private static final DaggerModule1_GetStringFactory INSTANCE = new DaggerModule1_GetStringFactory(); + } +} + */ + compile( + """ + package com.squareup.test + + import dagger.Module + import dagger.Provides + + @Module + object DaggerModule1 { + @get:Provides val string: String? = null + } + """ + ) { + val factoryClass = daggerModule1.moduleFactoryClass("getString") + + val constructor = factoryClass.declaredConstructors.single() + assertThat(constructor.parameterTypes.toList()).isEmpty() + + val staticMethods = factoryClass.declaredMethods.filter { it.isStatic } + + val factoryInstance = staticMethods.single { it.name == "create" } + .invoke(null) + assertThat(factoryInstance::class.java).isEqualTo(factoryClass) + + val providedString = staticMethods.single { it.name == "getString" } + .invoke(null) as String? + + assertThat(providedString).isNull() + assertThat((factoryInstance as Factory).get()).isNull() + } + } + private fun compile( vararg sources: String, block: Result.() -> Unit = { }