Skip to content

Commit

Permalink
support autowire two and more different beans of one type
Browse files Browse the repository at this point in the history
  • Loading branch information
tepa46 committed Jul 17, 2023
1 parent f44dd6f commit fb23875
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ import org.utbot.framework.plugin.api.UtSpringContextModel
import org.utbot.framework.plugin.api.util.SpringModelUtils.getBeanNameOrNull
import org.utbot.framework.plugin.api.util.id
import java.lang.Exception
import java.util.Collections.max

abstract class CgAbstractSpringTestClassConstructor(context: CgContext):
abstract class CgAbstractSpringTestClassConstructor(context: CgContext) :
CgAbstractTestClassConstructor<SpringTestClassModel>(context) {

protected val variableConstructor: CgSpringVariableConstructor =
Expand Down Expand Up @@ -100,22 +101,48 @@ abstract class CgAbstractSpringTestClassConstructor(context: CgContext):

val constructedDeclarations = mutableListOf<CgFieldDeclaration>()
for ((classId, listOfUtModels) in groupedModelsByClassId) {
val modelWrapper = listOfUtModels.firstOrNull() ?: continue
val model = modelWrapper.model
val baseVarName = model.getBeanNameOrNull()

val createdVariable = variableConstructor.getOrCreateVariable(model, baseVarName) as? CgVariable
?: error("`UtCompositeModel` model was expected, but $model was found")
// group [listOfUtModels] by `testSetId` and `executionId`
// to check how many instance of one type used in each execution
val groupedListOfUtModel = listOfUtModels.groupByTo(HashMap()) {
Pair(
it.testSetId,
it.executionId,
)
}

val declaration = CgDeclaration(classId, variableName = createdVariable.name, initializer = null)
constructedDeclarations += CgFieldDeclaration(ownerClassId = currentTestClass, declaration, annotation)
// max count instances of one type in one execution
val instanceMaxCount = max(groupedListOfUtModel.map { (_, modelsList) -> modelsList.size })

listOfUtModels.forEach { key ->
valueByUtModelWrapper[key] = createdVariable
}
// if [instanceCount] is 1, then we mock variable by @Mock annotation
// Otherwise we will mock variable by simple mock later
if (instanceMaxCount == 1) {
val modelWrapper = listOfUtModels.firstOrNull() ?: continue
val model = modelWrapper.model

val baseVarName = model.getBeanNameOrNull()

val createdVariable = variableConstructor.getOrCreateVariable(model, baseVarName) as? CgVariable
?: error("`UtCompositeModel` model was expected, but $model was found")

variableConstructor.annotatedModelGroups
.getOrPut(annotationClassId) { mutableSetOf() } += listOfUtModels
val declaration = CgDeclaration(classId, variableName = createdVariable.name, initializer = null)

constructedDeclarations += CgFieldDeclaration(
ownerClassId = currentTestClass,
declaration,
annotation
)

groupedListOfUtModel
.forEach { (_, modelsList) ->
val currentModel = modelsList.firstOrNull()

currentModel?.let{
valueByUtModelWrapper[currentModel] = createdVariable
variableConstructor.annotatedModelGroups.getOrPut(annotationClassId) { mutableSetOf() } += currentModel
}
}
}
}

return constructedDeclarations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ abstract class CgClassFieldManagerImpl(context: CgContext) :
val variableConstructor: CgSpringVariableConstructor by lazy {
CgComponents.getVariableConstructorBy(context) as CgSpringVariableConstructor
}

fun findCgValueByModel(model: UtModel, setOfModels: Set<UtModelWrapper>?): CgValue? {
val key = setOfModels?.find { it == model.wrap() } ?: return null
return valueByUtModelWrapper[key]
}
}

class CgInjectingMocksFieldsManager(val context: CgContext) : CgClassFieldManagerImpl(context) {
Expand All @@ -43,12 +48,15 @@ class CgInjectingMocksFieldsManager(val context: CgContext) : CgClassFieldManage
}

modelFields?.forEach { (fieldId, fieldModel) ->
//creating variables for modelVariable fields
// creating variables for modelVariable fields
val variableForField = variableConstructor.getOrCreateVariable(fieldModel)

// If field model is a mock, it is set in the connected with instance under test automatically via @InjectMocks;
// is variable mocked by @Mock annotation
val isMocked = findCgValueByModel(fieldModel, variableConstructor.annotatedModelGroups[mockClassId]) != null

// If field model is a mock model and is mocked by @Mock annotation in classFields, it is set in the connected with instance under test automatically via @InjectMocks;
// Otherwise we need to set this field manually.
if (!fieldModel.isMockModel()) {
if (!fieldModel.isMockModel() || !isMocked) {
variableConstructor.setFieldValue(modelVariable, fieldId, variableForField)
}
}
Expand Down Expand Up @@ -97,12 +105,13 @@ class ClassFieldManagerFacade(context: CgContext) : CgContextOwner by context {

fun constructVariableForField(
model: UtModel,
annotatedModelGroups: Map<ClassId, Set<UtModelWrapper>>,
): CgValue? {
val annotationManagers = listOf(injectingMocksFieldsManager, mockedFieldsManager, autowiredFieldsManager)

annotationManagers.forEach { manager ->
val alreadyCreatedVariable = findCgValueByModel(model, annotatedModelGroups[manager.annotationType])
val annotatedModelGroups = manager.variableConstructor.annotatedModelGroups

val alreadyCreatedVariable = manager.findCgValueByModel(model, annotatedModelGroups[manager.annotationType])

if (alreadyCreatedVariable != null) {
return manager.constructVariableForField(model, alreadyCreatedVariable)
Expand All @@ -111,9 +120,4 @@ class ClassFieldManagerFacade(context: CgContext) : CgContextOwner by context {

return null
}

private fun findCgValueByModel(model: UtModel, setOfModels: Set<UtModelWrapper>?): CgValue? {
val key = setOfModels?.find { it == model.wrap() } ?: return null
return valueByUtModelWrapper[key]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CgSpringVariableConstructor(context: CgContext) : CgVariableConstructor(co
private val classFieldManager = ClassFieldManagerFacade(context)

override fun getOrCreateVariable(model: UtModel, name: String?): CgValue {
val variable = classFieldManager.constructVariableForField(model, annotatedModelGroups)
val variable = classFieldManager.constructVariableForField(model)

variable?.let { return it }

Expand Down

0 comments on commit fb23875

Please sign in to comment.