diff --git a/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContextAware.kt b/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContextAware.kt index cab56e9..987a1bd 100644 --- a/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContextAware.kt +++ b/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContextAware.kt @@ -1,7 +1,9 @@ package software.amazon.lastmile.kotlin.inject.anvil +import com.google.devtools.ksp.KspExperimental import com.google.devtools.ksp.getDeclaredFunctions import com.google.devtools.ksp.getVisibility +import com.google.devtools.ksp.isAnnotationPresent import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.processing.Resolver import com.google.devtools.ksp.symbol.ClassKind @@ -13,6 +15,7 @@ import com.google.devtools.ksp.symbol.KSFile import com.google.devtools.ksp.symbol.KSFunctionDeclaration import com.google.devtools.ksp.symbol.KSNode import com.google.devtools.ksp.symbol.KSType +import com.google.devtools.ksp.symbol.KSValueArgument import com.google.devtools.ksp.symbol.KSValueParameter import com.google.devtools.ksp.symbol.Visibility import me.tatarka.inject.annotations.Qualifier @@ -164,6 +167,32 @@ interface ContextAware { return origin().parentDeclaration as KSClassDeclaration } + fun KSClassDeclaration.mapKeys(): List { + return annotations + .filter { it.isMapKey() } + .map { annotation -> + val argument = annotation.requireMapKeyArgument() + MapKeyAnnotation(argument = argument) + } + .toList() + } + + private fun KSAnnotation.isMapKey(): Boolean = isTypeAnnotatedWith(MapKey::class) + + @OptIn(KspExperimental::class) + private fun KSAnnotation.isTypeAnnotatedWith(clazz: KClass): Boolean { + return annotationType + .resolve() + .declaration + .isAnnotationPresent(clazz) + } + + private fun KSAnnotation.requireMapKeyArgument(): KSValueArgument { + return requireNotNull(arguments.singleOrNull(), this) { + "MapKey $this must have one argument." + } + } + fun KSClassDeclaration.findAnnotation(annotation: KClass): KSAnnotation = findAnnotations(annotation).single() diff --git a/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/MapKeyAnnotation.kt b/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/MapKeyAnnotation.kt new file mode 100644 index 0000000..c39e1f1 --- /dev/null +++ b/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/MapKeyAnnotation.kt @@ -0,0 +1,22 @@ +package software.amazon.lastmile.kotlin.inject.anvil + +import com.google.devtools.ksp.symbol.KSValueArgument + +/** + * Represents the key under which the contributed + * element will be added to the multi-binding `Map`. + * + * ``` + * @MapKey + * annotation class MyMapKey(val value: String) + * + * @Inject + * @ContributesBinding(AppScope::class, multibinding = true) + * @MyMapKey("foo") + * class Impl : Base + * ``` + * Where `MyMapKey` would represent the "MapKeyAnnotation". + */ +data class MapKeyAnnotation( + val argument: KSValueArgument +) diff --git a/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Util.kt b/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Util.kt index d30dea5..d1bd80c 100644 --- a/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Util.kt +++ b/compiler-utils/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Util.kt @@ -7,6 +7,10 @@ import com.google.devtools.ksp.symbol.KSDeclaration import com.google.devtools.ksp.symbol.KSValueArgument import com.squareup.kotlinpoet.Annotatable import com.squareup.kotlinpoet.AnnotationSpec +import com.squareup.kotlinpoet.ParameterizedTypeName +import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy +import com.squareup.kotlinpoet.TypeName +import com.squareup.kotlinpoet.asTypeName import com.squareup.kotlinpoet.ksp.toClassName import software.amazon.lastmile.kotlin.inject.anvil.internal.Origin import java.util.Locale @@ -97,3 +101,9 @@ fun KSDeclaration.requireQualifiedName(contextAware: ContextAware): String = fun KClass<*>.requireQualifiedName(): String = requireNotNull(qualifiedName) { "Qualified name was null for $this" } + +fun pairTypeOf(vararg typeNames: TypeName): ParameterizedTypeName { + return Pair::class + .asTypeName() + .parameterizedBy(*typeNames) +} diff --git a/compiler-utils/src/testFixtures/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Asserts.kt b/compiler-utils/src/testFixtures/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Asserts.kt index 2abb7bb..1b58d0e 100644 --- a/compiler-utils/src/testFixtures/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Asserts.kt +++ b/compiler-utils/src/testFixtures/kotlin/software/amazon/lastmile/kotlin/inject/anvil/Asserts.kt @@ -15,6 +15,7 @@ import assertk.assertions.isEqualTo import com.tschuchort.compiletesting.KotlinCompilation.ExitCode import org.jetbrains.kotlin.compiler.plugin.ExperimentalCompilerApi import java.lang.reflect.AnnotatedElement +import java.lang.reflect.ParameterizedType import kotlin.reflect.KClass fun Assert.isAnnotatedWith(annotation: KClass<*>) { @@ -44,3 +45,16 @@ fun Assert.isError() { } }.isEqualTo(ExitCode.COMPILATION_ERROR) } + +fun Assert.isPairOf( + first: Class<*>, + second: Class<*>, +) { + transform { element -> + element.rawType + }.isEqualTo(Pair::class.java) + + transform { element -> + element.actualTypeArguments + }.isEqualTo(arrayOf(first, second)) +} diff --git a/compiler/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessor.kt b/compiler/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessor.kt index 9470329..fddc307 100644 --- a/compiler/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessor.kt +++ b/compiler/src/main/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessor.kt @@ -4,26 +4,36 @@ import com.google.devtools.ksp.processing.CodeGenerator import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.processing.Resolver import com.google.devtools.ksp.processing.SymbolProcessor +import com.google.devtools.ksp.symbol.ClassKind import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSAnnotation import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSDeclaration import com.google.devtools.ksp.symbol.KSType import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.ParameterSpec +import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy +import com.squareup.kotlinpoet.STAR +import com.squareup.kotlinpoet.TypeName import com.squareup.kotlinpoet.TypeSpec +import com.squareup.kotlinpoet.asTypeName import com.squareup.kotlinpoet.ksp.addOriginatingKSFile import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeName import com.squareup.kotlinpoet.ksp.writeTo +import me.tatarka.inject.annotations.IntoMap import me.tatarka.inject.annotations.IntoSet import me.tatarka.inject.annotations.Provides import software.amazon.lastmile.kotlin.inject.anvil.ContextAware import software.amazon.lastmile.kotlin.inject.anvil.ContributesBinding import software.amazon.lastmile.kotlin.inject.anvil.LOOKUP_PACKAGE +import software.amazon.lastmile.kotlin.inject.anvil.MapKeyAnnotation import software.amazon.lastmile.kotlin.inject.anvil.addOriginAnnotation import software.amazon.lastmile.kotlin.inject.anvil.argumentOfTypeAt import software.amazon.lastmile.kotlin.inject.anvil.decapitalize +import software.amazon.lastmile.kotlin.inject.anvil.pairTypeOf import software.amazon.lastmile.kotlin.inject.anvil.requireQualifiedName import kotlin.reflect.KClass @@ -82,14 +92,31 @@ internal class ContributesBindingProcessor( val annotations = clazz.findAnnotationsAtLeastOne(ContributesBinding::class) checkNoDuplicateBoundTypes(clazz, annotations) + val mapKeys = clazz.mapKeys() + val boundTypes = annotations .map { - GeneratedFunction( - boundType = boundType(clazz, it), - multibinding = it.argumentOfTypeAt(this, "multibinding") ?: false, - ) + val boundType = boundType(clazz, it) + val multibinding = it.argumentOfTypeAt(this, "multibinding") ?: false + if (multibinding && mapKeys.isNotEmpty()) { + mapKeys.map { mapKey -> + GeneratedFunction( + boundType = boundType, + multibinding = true, + mapKey = mapKey, + ) + } + } else { + listOf( + GeneratedFunction( + boundType = boundType, + multibinding = multibinding, + ), + ) + } } - .distinctBy { it.bindingMethodReturnType.canonicalName + it.multibinding } + .flatten() + .distinctBy { it.bindingMethodReturnType.canonicalName + it.multibinding + it.mapKey } val fileSpec = FileSpec.builder(componentClassName) .addType( @@ -99,23 +126,13 @@ internal class ContributesBindingProcessor( .addOriginAnnotation(clazz) .addFunctions( boundTypes.map { function -> - val multibindingSuffix = if (function.multibinding) { - "Multibinding" - } else { - "" - } FunSpec .builder( "provide${clazz.innerClassNames()}" + function.bindingMethodReturnType.simpleName + - multibindingSuffix, + function.multiBindingSuffix, ) .addAnnotation(Provides::class) - .apply { - if (function.multibinding) { - addAnnotation(IntoSet::class) - } - } .apply { val parameterName = clazz.innerClassNames().decapitalize() addParameter( @@ -127,9 +144,31 @@ internal class ContributesBindingProcessor( .build(), ) - addStatement("return $parameterName") + when { + function.multibinding && function.mapKey != null -> { + addAnnotation(IntoMap::class) + val (format, value) = function.mapKey.value() + addStatement("return $format to $parameterName", value) + returns( + pairTypeOf( + function.mapKey.type(), + function.bindingMethodReturnType, + ), + ) + } + + function.multibinding -> { + addAnnotation(IntoSet::class) + addStatement("return $parameterName") + returns(function.bindingMethodReturnType) + } + + else -> { + addStatement("return $parameterName") + returns(function.bindingMethodReturnType) + } + } } - .returns(function.bindingMethodReturnType) .build() }, ) @@ -214,9 +253,131 @@ internal class ContributesBindingProcessor( private inner class GeneratedFunction( boundType: KSType, val multibinding: Boolean, + val mapKey: MapKeyAnnotation? = null, ) { val bindingMethodReturnType by lazy { boundType.toClassName() } + + val multiBindingSuffix = if (multibinding) { + val mapKeySuffix = mapKey + ?.multiBindingSuffix() + ?.let { "_$it" } ?: "" + "Multibinding$mapKeySuffix" + } else { + "" + } } + + private fun MapKeyAnnotation.type(): TypeName { + val type = when (val value = argument.value) { + is Byte -> Byte::class.asTypeName() + is Short -> Short::class.asTypeName() + is Int -> Int::class.asTypeName() + is Long -> Long::class.asTypeName() + is Float -> Float::class.asTypeName() + is Double -> Double::class.asTypeName() + is Char -> Char::class.asTypeName() + is String -> String::class.asTypeName() + is Boolean -> Boolean::class.asTypeName() + + is KSType -> KClass::class.asTypeName().parameterizedBy(STAR) + is KSClassDeclaration -> when (value.classKind) { + ClassKind.ENUM_CLASS -> value.toClassName() + ClassKind.ENUM_ENTRY -> (value.parent as? KSClassDeclaration)?.toClassName() + else -> null + } + + else -> null + } + return requireNotNull(type, argument) { + "The argument type could not be determined for " + + "${argument.name?.asString()} = ${argument.value}." + } + } + + private fun MapKeyAnnotation.value(): Pair { + val value = argument.value + + val format = when (value) { + is Byte, + is Short, + is Int, + is Long, + is Float, + is Double, + is Boolean, + -> "%L" + + is Char -> "'%L'" + is String -> "%S" + + is KSType -> "%T::class" + is KSClassDeclaration -> "%L" + + else -> { + val message = "The argument value could not be determined for " + + "${argument.name?.asString()} = ${argument.value}." + logger.error(message, argument) + throw IllegalArgumentException(message) + } + } + + val argValue = when (value) { + is Byte -> "$value.toByte()" + is Short -> "$value.toShort()" + is Float -> "$value.toFloat()" + + is Int, + is Long, + is Double, + is Boolean, + is Char, + is String, + -> value + + is KSType -> value.toTypeName() + is KSClassDeclaration -> value + + else -> { + val message = "The argument value could not be determined for " + + "${argument.name?.asString()} = ${argument.value}." + logger.error(message, argument) + throw IllegalArgumentException(message) + } + } + + return format to argValue + } + + private fun MapKeyAnnotation.multiBindingSuffix(): String { + return when (val value = argument.value) { + is Byte -> "${value}b" + is Short -> "${value}s" + + is Int, + is Long, + is Char, + is String, + is Boolean, + -> value.toString() + + is Float, + is Double, + -> value.toString().replace(".", "_") + + is KSType -> value.declaration.simpleName.asString() + is KSClassDeclaration -> value.safeRequiredQualifiedName + + else -> { + val message = "The argument value could not be determined for " + + "${argument.name?.asString()} = ${argument.value}." + logger.error(message, argument) + throw IllegalArgumentException(message) + } + } + } + + private val KSDeclaration.safeRequiredQualifiedName: String + get() = qualifiedName!!.asString().replace(".", "_") } diff --git a/compiler/src/test/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessorTest.kt b/compiler/src/test/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessorTest.kt index bcd52d1..8c2ca87 100644 --- a/compiler/src/test/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessorTest.kt +++ b/compiler/src/test/kotlin/software/amazon/lastmile/kotlin/inject/anvil/processor/ContributesBindingProcessorTest.kt @@ -8,6 +8,7 @@ import assertk.assertions.hasSize import assertk.assertions.isEqualTo import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation.ExitCode.COMPILATION_ERROR +import me.tatarka.inject.annotations.IntoMap import me.tatarka.inject.annotations.IntoSet import me.tatarka.inject.annotations.Provides import org.jetbrains.kotlin.compiler.plugin.ExperimentalCompilerApi @@ -18,7 +19,10 @@ import software.amazon.lastmile.kotlin.inject.anvil.generatedComponent import software.amazon.lastmile.kotlin.inject.anvil.inner import software.amazon.lastmile.kotlin.inject.anvil.isAnnotatedWith import software.amazon.lastmile.kotlin.inject.anvil.isNotAnnotatedWith +import software.amazon.lastmile.kotlin.inject.anvil.isPairOf import software.amazon.lastmile.kotlin.inject.anvil.origin +import java.lang.reflect.Method +import java.lang.reflect.ParameterizedType class ContributesBindingProcessorTest { @@ -314,6 +318,150 @@ class ContributesBindingProcessorTest { } } + @Test + fun `it's an error when a map key does not have an argument`() { + compile( + """ + package software.amazon.test + + import software.amazon.lastmile.kotlin.inject.anvil.ContributesBinding + import software.amazon.lastmile.kotlin.inject.anvil.MapKey + import me.tatarka.inject.annotations.Inject + + interface Base + + @MapKey + annotation class MyMapKey + + @Inject + @ContributesBinding(Unit::class, multibinding = true) + @MyMapKey + class Impl : Base + """, + exitCode = COMPILATION_ERROR, + ) { + assertThat(messages).contains( + "MapKey @MyMapKey must have one argument.", + ) + } + } + + @Test + fun `map keys are repeatable`() { + compile( + """ + package software.amazon.test + + import software.amazon.lastmile.kotlin.inject.anvil.ContributesBinding + import software.amazon.lastmile.kotlin.inject.anvil.StringKey + import me.tatarka.inject.annotations.Inject + + interface Base + + @Inject + @ContributesBinding(Unit::class, multibinding = true) + @StringKey("foo") + @StringKey("bar") + class Impl : Base + """, + ) { + val generatedComponent = impl.generatedComponent + + assertThat(generatedComponent.packageName).isEqualTo(LOOKUP_PACKAGE) + assertThat(generatedComponent.origin).isEqualTo(impl) + + with(generatedComponent.declaredMethods.single { it.name == "provideImplBaseMultibinding_foo" }) { + assertThat(parameters.single().type).isEqualTo(impl) + assertThat(parameterizedReturnType).isPairOf(String::class.java, base) + assertThat(this).isAnnotatedWith(IntoMap::class) + } + + with(generatedComponent.declaredMethods.single { it.name == "provideImplBaseMultibinding_bar" }) { + assertThat(parameters.single().type).isEqualTo(impl) + assertThat(parameterizedReturnType).isPairOf(String::class.java, base) + assertThat(this).isAnnotatedWith(Provides::class) + assertThat(this).isAnnotatedWith(IntoMap::class) + } + } + } + + @Test + fun `a component interface is generated in the lookup package for a contributed map multibinding`() { + compile( + """ + package software.amazon.test + + import software.amazon.lastmile.kotlin.inject.anvil.ContributesBinding + import software.amazon.lastmile.kotlin.inject.anvil.StringKey + import me.tatarka.inject.annotations.Inject + + interface Base + + @Inject + @ContributesBinding(Unit::class, multibinding = true) + @StringKey("foo") + class Impl : Base + """, + ) { + val generatedComponent = impl.generatedComponent + + assertThat(generatedComponent.packageName).isEqualTo(LOOKUP_PACKAGE) + assertThat(generatedComponent.origin).isEqualTo(impl) + + val method = generatedComponent.declaredMethods.single() + assertThat(method.name).isEqualTo("provideImplBaseMultibinding_foo") + assertThat(method.parameters.single().type).isEqualTo(impl) + assertThat(method.parameterizedReturnType).isPairOf(String::class.java, base) + assertThat(method).isAnnotatedWith(Provides::class) + assertThat(method).isAnnotatedWith(IntoMap::class) + } + } + + @Test + fun `both binding and multibinding component interfaces can be generated in the lookup package for a contributed map multibinding`() { + compile( + """ + package software.amazon.test + + import software.amazon.lastmile.kotlin.inject.anvil.ContributesBinding + import software.amazon.lastmile.kotlin.inject.anvil.StringKey + import me.tatarka.inject.annotations.Inject + + interface Base + + @Inject + @ContributesBinding(Unit::class, multibinding = false) + @ContributesBinding(Unit::class, multibinding = true) + @StringKey("foo") + class Impl : Base + """, + ) { + val generatedComponent = impl.generatedComponent + + assertThat(generatedComponent.packageName).isEqualTo(LOOKUP_PACKAGE) + assertThat(generatedComponent.origin).isEqualTo(impl) + + assertThat(generatedComponent.declaredMethods).hasSize(2) + + val bindingMethod = generatedComponent.declaredMethods.first { + it.name == "provideImplBase" + } + assertThat(bindingMethod.parameters.single().type).isEqualTo(impl) + assertThat(bindingMethod.returnType).isEqualTo(base) + assertThat(bindingMethod).isAnnotatedWith(Provides::class) + assertThat(bindingMethod).isNotAnnotatedWith(IntoMap::class) + + val multibindingBindingMethod = generatedComponent.declaredMethods.first { + it.name == "provideImplBaseMultibinding_foo" + } + assertThat(multibindingBindingMethod.parameters.single().type).isEqualTo(impl) + assertThat(multibindingBindingMethod.parameterizedReturnType) + .isPairOf(String::class.java, base) + assertThat(multibindingBindingMethod).isAnnotatedWith(Provides::class) + assertThat(multibindingBindingMethod).isAnnotatedWith(IntoMap::class) + } + } + private val JvmCompilationResult.base: Class<*> get() = classLoader.loadClass("software.amazon.test.Base") @@ -325,4 +473,7 @@ class ContributesBindingProcessorTest { private val JvmCompilationResult.impl2: Class<*> get() = classLoader.loadClass("software.amazon.test.Impl2") + + private val Method.parameterizedReturnType: ParameterizedType + get() = (genericReturnType as ParameterizedType) } diff --git a/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContributesBinding.kt b/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContributesBinding.kt index c021ee2..7282e2d 100644 --- a/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContributesBinding.kt +++ b/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/ContributesBinding.kt @@ -64,6 +64,19 @@ import kotlin.reflect.KClass * @ContributesBinding(AppScope::class, boundType = Base2::class, multibinding = true) * class Impl : Base, Base2 * ``` + * + * If the class is annotated with a [MapKey] annotation, then the binding will be contributed + * to a multi-binding `Map` instead of a `Set`. + * + * ``` + * @MapKey + * annotation class MyMapKey(val value: String) + * + * @Inject + * @ContributesBinding(AppScope::class, multibinding = true) + * @MyMapKey("foo") + * class Impl : Base + * ``` */ @Target(CLASS) @Repeatable diff --git a/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/MapKey.kt b/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/MapKey.kt new file mode 100644 index 0000000..8ef8e15 --- /dev/null +++ b/runtime/src/commonMain/kotlin/software/amazon/lastmile/kotlin/inject/anvil/MapKey.kt @@ -0,0 +1,50 @@ +package software.amazon.lastmile.kotlin.inject.anvil + +import kotlin.reflect.KClass + +/** + * Marks an annotation class as a key for a multi-binding `Map`. + * + * When used with a [ContributesBinding] annotation, this annotation specifies the key under + * which the contributed element will be added to the map. + * + * ``` + * @MapKey + * annotation class MyMapKey(val value: String) + * + * @Inject + * @ContributesBinding(AppScope::class, multibinding = true) + * @MyMapKey("foo") + * class Impl : Base + * ``` + */ +@Target(AnnotationTarget.ANNOTATION_CLASS) +public annotation class MapKey + +@MapKey +@Target(AnnotationTarget.CLASS) +@Repeatable +public annotation class StringKey( + val value: String, +) + +@MapKey +@Target(AnnotationTarget.CLASS) +@Repeatable +public annotation class IntKey( + val value: Int, +) + +@MapKey +@Target(AnnotationTarget.CLASS) +@Repeatable +public annotation class LongKey( + val value: Long, +) + +@MapKey +@Target(AnnotationTarget.CLASS) +@Repeatable +public annotation class ClassKey( + val value: KClass<*>, +)