-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
18dbde8
commit 6cc6eb4
Showing
3 changed files
with
130 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Pipelines | ||
map parameters spaces to execution artifacts | ||
""" | ||
import copy | ||
import abc | ||
from typing import Optional, Sequence, Any | ||
from autokoopman.core.hyperparameter import ParameterSpace | ||
from autokoopman.core.trajectory import TrajectoriesData | ||
|
||
|
||
class Pipeline: | ||
this_parameter_space = ParameterSpace("", []) | ||
|
||
def __init__(self, name: str) -> None: | ||
self._param_space = self.this_parameter_space | ||
self._next_stages: Sequence[Pipeline] = [] | ||
self.name = name | ||
|
||
@abc.abstractmethod | ||
def execute(self, inputs, params: Optional[Sequence[Any]]) -> Any: | ||
"""execute this stage only""" | ||
raise NotImplementedError | ||
|
||
def run(self, inputs, params: Optional[Sequence[Any]]): | ||
"""run full pipeline""" | ||
# input checks | ||
assert params in self.parameter_space | ||
params = self._split_inputs(params) | ||
|
||
# run the current stage | ||
results = self.execute(inputs, params[0]) | ||
|
||
# if no other stages, return the results, else flow them through the following stages | ||
if len(self._next_stages) == 0: | ||
return results | ||
else: | ||
rem = [ | ||
stage.run(results, p) for stage, p in zip(self._next_stages, params[1:]) | ||
] | ||
if len(rem) == 1: | ||
return rem[0] | ||
else: | ||
return tuple(rem) | ||
|
||
def add_post_pipeline(self, next_pipeline: "Pipeline"): | ||
assert isinstance( | ||
next_pipeline, Pipeline | ||
), f"next pipeline must be a Pipeline object" | ||
self._next_stages.append(next_pipeline) | ||
|
||
def __or__(self, next: Any): | ||
if not isinstance(next, Pipeline): | ||
raise ValueError(f"{next} must be a Pipeline") | ||
|
||
# create a new instance | ||
n = copy.deepcopy(self) | ||
n.add_post_pipeline(next) | ||
return n | ||
|
||
def _split_inputs(self, inputs: Sequence[Any]): | ||
idx = self._param_space.dimension | ||
inps = [inputs[0:idx]] | ||
for stage in self._next_stages: | ||
inps.append(inputs[idx : (idx + stage.parameter_space.dimension)]) | ||
idx += stage.parameter_space.dimension | ||
return inps | ||
|
||
@property | ||
def parameter_space(self): | ||
return ParameterSpace.from_parameter_spaces( | ||
self.name, | ||
[ | ||
self._param_space, | ||
*[stage.parameter_space for stage in self._next_stages], | ||
], | ||
) | ||
|
||
|
||
class TrajectoryPreprocessor(Pipeline): | ||
def run(self, inputs, params: Optional[Sequence[Any]]) -> TrajectoriesData: | ||
return super().run(inputs, params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import autokoopman.core.pipeline as apipe | ||
import autokoopman.core.hyperparameter as ahyp | ||
|
||
|
||
def test_compositionality(): | ||
"""test that the compositionality works""" | ||
|
||
class Succ(apipe.Pipeline): | ||
def execute(self, x, _): | ||
return x + 1 | ||
|
||
class Source(apipe.Pipeline): | ||
def execute(self, x, __): | ||
return x | ||
|
||
class Identity(apipe.Pipeline): | ||
def execute(self, x, _): | ||
return x | ||
|
||
head = Source("head") | ||
fork1 = Succ("fork1") | ||
|
||
fork1 |= Identity("stage3") | ||
fork1 |= Succ("stage4") | ||
|
||
head |= Succ("stage2") | ||
head |= fork1 | ||
|
||
assert head.run(0, []) == (1, (1, 2)) | ||
assert head.run(5, []) == (6, (6, 7)) |