Skip to content

Commit

Permalink
working on transformations one to many
Browse files Browse the repository at this point in the history
  • Loading branch information
ftomassetti committed May 30, 2023
1 parent 79f9c1f commit 39b910b
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,22 @@ open class ParseTreeToASTTransformer(
allowGenericNode: Boolean = true,
val source: Source? = null
) : ASTTransformer(issues, allowGenericNode) {

/**
* Performs the transformation of a node and, recursively, its descendants. In addition to the overridden method,
* it also assigns the parseTreeNode to the AST node so that it can keep track of its position.
* However, a node factory can override the parseTreeNode of the nodes it creates (but not the parent).
*/
override fun transform(source: Any?, parent: Node?): Node? {
val node = super.transform(source, parent)
override fun transform(source: Any?, parent: Node?): List<Node> {
val node = super.transform(source, parent) as Node?
if (node != null && source is ParserRuleContext) {
if (node.origin == null) {
node.withParseTreeNode(source, this.source)
} else if (node.position != null && node.source == null) {
node.position!!.source = this.source
}
}
return node
return if (node == null) emptyList() else listOf(node)
}

override fun getSource(node: Node, source: Any): Any {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,43 @@ annotation class Mapped(val path: String = "")
/**
* Factory that, given a tree node, will instantiate the corresponding transformed node.
*/
class NodeFactory<Source, Output : Node>(
val constructor: (Source, ASTTransformer, NodeFactory<Source, Output>) -> Output?,
val children: MutableMap<String, ChildNodeFactory<Source, *, *>?> = mutableMapOf(),
var finalizer: (Output) -> Unit = {},
var skipChildren: Boolean = false,
class NodeFactory<Source, Output : Node> {
var constructor: (Source, ASTTransformer, NodeFactory<Source, Output>) -> List<Output>
var children: MutableMap<String, ChildNodeFactory<Source, *, *>?> = mutableMapOf()
var finalizer: (Output) -> Unit = {}
var skipChildren: Boolean = false
var childrenSetAtConstruction: Boolean = false
) {

constructor(
constructor: (Source, ASTTransformer, NodeFactory<Source, Output>) -> List<Output>,
children: MutableMap<String, ChildNodeFactory<Source, *, *>?> = mutableMapOf(),
finalizer: (Output) -> Unit = {},
skipChildren: Boolean = false,
childrenSetAtConstruction: Boolean = false
) {
this.constructor = constructor
this.children = children
this.finalizer = finalizer
this.skipChildren = skipChildren
this.childrenSetAtConstruction = childrenSetAtConstruction
}

constructor(
singleConstructor: (Source, ASTTransformer, NodeFactory<Source, Output>) -> Output?,
children: MutableMap<String, ChildNodeFactory<Source, *, *>?> = mutableMapOf(),
finalizer: (Output) -> Unit = {},
skipChildren: Boolean = false,
childrenSetAtConstruction: Boolean = false
) {
this.constructor = { source, at, nf ->
val result = singleConstructor(source, at, nf)
if (result == null) emptyList() else listOf(result)
}
this.children = children
this.finalizer = finalizer
this.skipChildren = skipChildren
this.childrenSetAtConstruction = childrenSetAtConstruction
}

/**
* Specify how to convert a child. The value obtained from the conversion could either be used
Expand Down Expand Up @@ -231,6 +261,15 @@ open class ASTTransformer(
private val _knownClasses = mutableMapOf<String, MutableSet<KClass<*>>>()
val knownClasses: Map<String, Set<KClass<*>>> = _knownClasses

fun transformToNode(source: Any?, parent: Node? = null): Node? {
val result = transform(source, parent)
return when (result.size) {
0 -> null
1 -> result.first()
else -> throw IllegalStateException()
}
}

/**
* Performs the transformation of a node and, recursively, its descendants.
*/
Expand All @@ -243,21 +282,23 @@ open class ASTTransformer(
throw Error("Mapping error: received collection when value was expected")
}
val factory = getNodeFactory<Any, Node>(source::class as KClass<Any>)
val node: Node?
val nodes: List<Node>
if (factory != null) {
node = makeNode(factory, source, allowGenericNode = allowGenericNode)
if (node == null) {
nodes = makeNodes(factory, source, allowGenericNode = allowGenericNode)
if (nodes == null) {
return emptyList()
}
if (!factory.skipChildren && !factory.childrenSetAtConstruction) {
setChildren(factory, source, node)
nodes.forEach { node -> setChildren(factory, source, node) }
}
nodes.forEach { node ->
factory.finalizer(node)
node.parent = parent
}
factory.finalizer(node)
node.parent = parent
} else {
if (allowGenericNode) {
val origin = asOrigin(source)
node = GenericNode(parent).withOrigin(origin)
nodes = listOf(GenericNode(parent).withOrigin(origin))
issues.add(
Issue.semantic(
"Source node not mapped: ${source::class.qualifiedName}",
Expand All @@ -269,7 +310,7 @@ open class ASTTransformer(
throw IllegalStateException("Unable to translate node $source (class ${source.javaClass})")
}
}
return listOf(node)
return nodes
}

private fun setChildren(
Expand Down Expand Up @@ -326,24 +367,26 @@ open class ASTTransformer(
return source
}

protected open fun <S : Any, T : Node> makeNode(
protected open fun <S : Any, T : Node> makeNodes(
factory: NodeFactory<S, T>,
source: S,
allowGenericNode: Boolean = true
): Node? {
val node = try {
): List<Node> {
val nodes = try {
factory.constructor(source, this, factory)
} catch (e: Exception) {
if (allowGenericNode) {
GenericErrorNode(e)
listOf(GenericErrorNode(e))
} else {
throw e
}
}
if (node?.origin == null) {
node?.withOrigin(asOrigin(source))
nodes.forEach { node ->
if (node?.origin == null) {
node?.withOrigin(asOrigin(source))
}
}
return node
return nodes
}

protected open fun <S : Any, T : Node> getNodeFactory(kClass: KClass<S>): NodeFactory<S, T>? {
Expand Down Expand Up @@ -373,6 +416,15 @@ open class ASTTransformer(
return nodeFactory
}

fun <S : Any, T : Node> registerMultipleNodeFactory(
kclass: KClass<S>,
factory: (S, ASTTransformer, NodeFactory<S, T>) -> List<T>
): NodeFactory<S, T> {
val nodeFactory = NodeFactory(factory)
factories[kclass] = nodeFactory
return nodeFactory
}

fun <S : Any, T : Node> registerNodeFactory(
kclass: KClass<S>,
factory: (S, ASTTransformer) -> T?
Expand All @@ -385,6 +437,9 @@ open class ASTTransformer(
fun <S : Any, T : Node> registerNodeFactory(kclass: KClass<S>, factory: (S) -> T?): NodeFactory<S, T> =
registerNodeFactory(kclass) { input, _, _ -> factory(input) }

fun <S : Any, T : Node> registerMultipleNodeFactory(kclass: KClass<S>, factory: (S) -> List<T>): NodeFactory<S, T> =
registerMultipleNodeFactory(kclass) { input, _, _ -> factory(input) }

inline fun <reified S : Any, reified T : Node> registerNodeFactory(): NodeFactory<S, T> {
return registerNodeFactory(S::class, T::class)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ParseTreeToASTTransformerTest {
DisplayIntStatement(value = 456).withParseTreeNode(pt.statement(1))
)
).withParseTreeNode(pt)
val transformedCU = transformer.transform(pt)!!
val transformedCU = transformer.transform(pt).first()
assertASTsAreEqual(cu, transformedCU, considerPosition = true)
assertTrue { transformedCU.hasValidParents() }
assertNull(transformedCU.invalidPositions().firstOrNull())
Expand Down Expand Up @@ -76,7 +76,7 @@ class ParseTreeToASTTransformerTest {
val pt = parser.compilationUnit()

val transformer = ParseTreeToASTTransformer()
assertASTsAreEqual(GenericNode(), transformer.transform(pt)!!)
assertASTsAreEqual(GenericNode(), transformer.transformToNode(pt)!!)
}

@Test
Expand All @@ -97,7 +97,7 @@ class ParseTreeToASTTransformerTest {
DisplayIntStatement(value = 456)
)
)
val transformedCU = transformer.transform(pt)!!
val transformedCU = transformer.transformToNode(pt)!!
assertASTsAreEqual(cu, transformedCU, considerPosition = true)
assertTrue { transformedCU.hasValidParents() }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ class ASTTransformerTest {
}
assertASTsAreEqual(
Mult(IntLiteral(7), IntLiteral(8)),
myTransformer.transform(GenericBinaryExpression(Operator.MULT, IntLiteral(7), IntLiteral(8)))!!
myTransformer.transformToNode(GenericBinaryExpression(Operator.MULT, IntLiteral(7), IntLiteral(8)))!!
)
assertASTsAreEqual(
Sum(IntLiteral(7), IntLiteral(8)),
myTransformer.transform(GenericBinaryExpression(Operator.PLUS, IntLiteral(7), IntLiteral(8)))!!
myTransformer.transformToNode(GenericBinaryExpression(Operator.PLUS, IntLiteral(7), IntLiteral(8)))!!
)
}

Expand All @@ -122,7 +122,7 @@ class ASTTransformerTest {
),
BLangIntLiteral(4)
),
myTransformer.transform(
myTransformer.transformToNode(
ALangMult(
ALangSum(
ALangIntLiteral(1),
Expand Down Expand Up @@ -171,7 +171,7 @@ class ASTTransformerTest {
TypedLiteral("1", Type.INT),
Type.INT
),
myTransformer.transform(
myTransformer.transformToNode(
TypedSum(
TypedLiteral("1", Type.INT),
TypedLiteral("1", Type.INT),
Expand All @@ -186,7 +186,7 @@ class ASTTransformerTest {
TypedLiteral("test", Type.STR),
Type.STR
),
myTransformer.transform(
myTransformer.transformToNode(
TypedConcat(
TypedLiteral("test", Type.STR),
TypedLiteral("test", Type.STR),
Expand All @@ -201,7 +201,7 @@ class ASTTransformerTest {
TypedLiteral("test", Type.STR),
null
),
myTransformer.transform(
myTransformer.transformToNode(
TypedSum(
TypedLiteral("1", Type.INT),
TypedLiteral("test", Type.STR),
Expand All @@ -223,7 +223,7 @@ class ASTTransformerTest {
TypedLiteral("test", Type.STR),
null
),
myTransformer.transform(
myTransformer.transformToNode(
TypedConcat(
TypedLiteral("1", Type.INT),
TypedLiteral("test", Type.STR),
Expand Down Expand Up @@ -285,7 +285,7 @@ class ASTTransformerTest {
fun testTransforingOneNodeToMany() {
val transformer = ASTTransformer()
transformer.registerNodeFactory(BarRoot::class, BazRoot::class)
.withChild(BarRoot::stmts, BazRoot::stmts)
//.withChild(BarRoot::stmts, BazRoot::stmts)
transformer.registerNodeFactory(BarStmt::class) { s ->
listOf(BazStmt("${s.desc}-1"), BazStmt("${s.desc}-2"))
}
Expand All @@ -296,15 +296,19 @@ class ASTTransformerTest {
BarStmt("b")
)
)
val transformed = transformer.transform(original) as BazRoot
val transformed = transformer.transformToNode(original) as BazRoot
assertTrue { transformed.hasValidParents() }
assertEquals(transformed.origin, original)
assertASTsAreEqual(BazRoot(mutableListOf(
BazStmt("a-1"),
BazStmt("a-2"),
BazStmt("b-1"),
BazStmt("b-2")
)))
assertASTsAreEqual(
BazRoot(
mutableListOf(
BazStmt("a-1"),
BazStmt("a-2"),
BazStmt("b-1"),
BazStmt("b-2")
)
)
)
}
}

Expand All @@ -313,4 +317,4 @@ data class BazRoot(var stmts: MutableList<BazStmt> = mutableListOf()) : Node()
data class BazStmt(val desc: String) : Node()

data class BarRoot(var stmts: MutableList<BarStmt> = mutableListOf()) : Node()
data class BarStmt(val desc: String) : Node()
data class BarStmt(val desc: String) : Node()

0 comments on commit 39b910b

Please sign in to comment.