Skip to content

Commit b22762c

Browse files
authored
fix(compiler): Recursively find abilities [LNG-338] (#1086)
1 parent df5eb29 commit b22762c

File tree

5 files changed

+129
-37
lines changed

5 files changed

+129
-37
lines changed

aqua-src/antithesis.aqua

+21-23
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
aqua A
22

3-
import "aqua-src/gen/OneMore.aqua"
3+
export haveFun
44

5-
export main
5+
ability Compute:
6+
job() -> string
67

7-
alias SomeAlias: string
8+
func lift() -> Compute:
9+
job = () -> string:
10+
<- "job done"
11+
<- Compute(job)
812

9-
data NestedStruct:
10-
a: SomeAlias
13+
ability Function:
14+
run() -> string
1115

12-
data SomeStruct:
13-
al: SomeAlias
14-
nested: NestedStruct
16+
func roundtrip{Function}() -> string:
17+
res <- Function.run()
18+
<- res
1519

16-
ability SomeAbility:
17-
someStr: SomeStruct
18-
nested: NestedStruct
19-
al: SomeAlias
20-
someFunc(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> NestedStruct, SomeStruct, SomeAlias
20+
func disjoint_run{Compute}() -> Function:
21+
run = func () -> string:
22+
<- Compute.job()
23+
<- Function(run = run)
2124

22-
service Srv("a"):
23-
check(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> NestedStruct
24-
check2() -> SomeStruct
25-
check3() -> SomeAlias
26-
27-
func withAb{SomeAbility}() -> SomeStruct:
28-
<- SomeAbility.someStr
29-
30-
func main(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> string:
31-
<- ""
25+
func haveFun() -> string:
26+
comp = lift()
27+
fn = disjoint_run{comp}()
28+
res <- roundtrip{fn}()
29+
<- res

integration-tests/aqua/examples/abilitiesClosure.aqua

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
aqua M
22

3-
export bugLNG314
3+
export bugLNG314, bugLNG338
44

55
ability WorkerJob:
66
runOnSingleWorker(w: string) -> string
@@ -20,4 +20,30 @@ func bugLNG314() -> string:
2020
worker_job = WorkerJob(runOnSingleWorker = job2)
2121
subnet_job <- disjoint_run{worker_job}()
2222
res <- runJob(subnet_job)
23-
<- res
23+
<- res
24+
25+
ability Compute:
26+
job() -> string
27+
28+
func lift() -> Compute:
29+
job = () -> string:
30+
<- "job done"
31+
<- Compute(job)
32+
33+
ability Function:
34+
run() -> string
35+
36+
func roundtrip{Function}() -> string:
37+
res <- Function.run()
38+
<- res
39+
40+
func disj{Compute}() -> Function:
41+
run = func () -> string:
42+
<- Compute.job()
43+
<- Function(run = run)
44+
45+
func bugLNG338() -> string:
46+
comp = lift()
47+
fn = disj{comp}()
48+
res <- roundtrip{fn}()
49+
<- res

integration-tests/src/__test__/examples.spec.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import {
4040
multipleAbilityWithClosureCall,
4141
returnSrvAsAbilityCall,
4242
} from "../examples/abilityCall.js";
43-
import { bugLNG314Call } from "../examples/abilityClosureCall.js";
43+
import { bugLNG314Call, bugLNG338Call } from "../examples/abilityClosureCall.js";
4444
import {
4545
nilLengthCall,
4646
nilLiteralCall,
@@ -665,6 +665,11 @@ describe("Testing examples", () => {
665665
expect(result).toEqual("strstrstr");
666666
});
667667

668+
it("abilitiesClosure.aqua bug LNG-338", async () => {
669+
let result = await bugLNG338Call();
670+
expect(result).toEqual("job done");
671+
});
672+
668673
it("functors.aqua LNG-119 bug", async () => {
669674
let result = await bugLng119Call();
670675
expect(result).toEqual([1]);
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import {
2-
bugLNG314
2+
bugLNG314, bugLNG338
33
} from "../compiled/examples/abilitiesClosure.js";
44

55
export async function bugLNG314Call(): Promise<string> {
66
return await bugLNG314();
77
}
8+
9+
export async function bugLNG338Call(): Promise<string> {
10+
return await bugLNG338();
11+
}

model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala

+69-10
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import cats.syntax.option.*
1717
import cats.syntax.semigroup.*
1818
import cats.syntax.traverse.*
1919
import cats.{Eval, Monoid}
20+
import scala.annotation.tailrec
2021
import scribe.Logging
2122

2223
/**
@@ -82,6 +83,62 @@ object ArrowInliner extends Logging {
8283
arrowsToSave: Map[String, FuncArrow]
8384
)
8485

86+
/**
87+
* Find abilities recursively, because ability can hold arrow with another ability in it.
88+
* @param abilitiesToGather gather all fields for these abilities
89+
* @param varsFromAbs already gathered variables
90+
* @param arrowsFromAbs already gathered arrows
91+
* @param processedAbs already processed abilities
92+
* @return all needed variables and arrows
93+
*/
94+
@tailrec
95+
private def arrowsAndVarsFromAbilities(
96+
abilitiesToGather: Map[String, GeneralAbilityType],
97+
exports: Map[String, ValueModel],
98+
arrows: Map[String, FuncArrow],
99+
varsFromAbs: Map[String, ValueModel] = Map.empty,
100+
arrowsFromAbs: Map[String, FuncArrow] = Map.empty,
101+
processedAbs: Set[String] = Set.empty
102+
): (Map[String, ValueModel], Map[String, FuncArrow]) = {
103+
val varsFromAbilities = abilitiesToGather.flatMap { case (name, at) =>
104+
getAbilityVars(name, None, at, exports)
105+
}
106+
val arrowsFromAbilities = abilitiesToGather.flatMap { case (name, at) =>
107+
getAbilityArrows(name, None, at, exports, arrows)
108+
}
109+
110+
val allProcessed = abilitiesToGather.keySet ++ processedAbs
111+
112+
// find all names that is used in arrows
113+
val namesUsage = arrowsFromAbilities.values.flatMap(_.body.usesVarNames.value).toSet
114+
115+
// check if there is abilities that we didn't gather
116+
val abilitiesUsage = namesUsage.toList
117+
.flatMap(exports.get)
118+
.collect {
119+
case ValueModel.Ability(vm, at) if !allProcessed.contains(vm.name) =>
120+
vm.name -> at
121+
}
122+
.toMap
123+
124+
val allVars = varsFromAbilities ++ varsFromAbs
125+
val allArrows = arrowsFromAbilities ++ arrowsFromAbs
126+
127+
if (abilitiesUsage.isEmpty) {
128+
(allVars, allArrows)
129+
} else {
130+
arrowsAndVarsFromAbilities(
131+
abilitiesUsage,
132+
exports,
133+
arrows,
134+
allVars,
135+
allArrows,
136+
allProcessed
137+
)
138+
}
139+
140+
}
141+
85142
// Apply a callable function, get its fully resolved body & optional value, if any
86143
private def inline[S: Mangler: Arrows: Exports](
87144
fn: FuncArrow,
@@ -104,15 +161,15 @@ object ArrowInliner extends Logging {
104161
exports <- Exports[S].exports
105162
arrows <- Arrows[S].arrows
106163
// gather all arrows and variables from abilities
107-
returnedAbilities = rets.collect { case ValueModel.Ability(vm, at) =>
164+
abilitiesToGather = rets.collect { case ValueModel.Ability(vm, at) =>
108165
vm.name -> at
109166
}
110-
varsFromAbilities = returnedAbilities.flatMap { case (name, at) =>
111-
getAbilityVars(name, None, at, exports)
112-
}.toMap
113-
arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) =>
114-
getAbilityArrows(name, None, at, exports, arrows)
115-
}.toMap
167+
arrsVars = arrowsAndVarsFromAbilities(
168+
abilitiesToGather.toMap,
169+
exports,
170+
arrows
171+
)
172+
(varsFromAbilities, arrowsFromAbilities) = arrsVars
116173

117174
// find and get resolved arrows if we return them from the function
118175
returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet
@@ -172,9 +229,11 @@ object ArrowInliner extends Logging {
172229
abilityType,
173230
exports
174231
)
232+
val abilityExport =
233+
exports.get(abilityName).map(vm => abilityNewName.getOrElse(abilityName) -> vm).toMap
175234

176-
get(_.variables) ++ get(_.arrows).flatMap {
177-
case arrow @ (_, vm @ ValueModel.Arrow(_, _)) =>
235+
abilityExport ++ get(_.variables) ++ get(_.arrows).flatMap {
236+
case arrow @ (_, ValueModel.Arrow(_, _)) =>
178237
arrow.some
179238
case (_, m) =>
180239
internalError(s"($m) cannot be an arrow")
@@ -497,7 +556,7 @@ object ArrowInliner extends Logging {
497556
exports <- Exports[S].exports
498557
streams <- getOutsideStreamNames
499558
arrows = passArrows ++ arrowsFromAbilities
500-
559+
501560
inlineResult <- Exports[S].scope(
502561
Arrows[S].scope(
503562
for {

0 commit comments

Comments
 (0)