Skip to content

Commit

Permalink
Refactor Feature chaining and add test case for Feature chain with la…
Browse files Browse the repository at this point in the history
…mbda (#206)
  • Loading branch information
BenjaminMidtvedt authored Aug 26, 2024
1 parent 35891bd commit 10b564b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
7 changes: 3 additions & 4 deletions deeptrack/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def __iter__(self):
def __next__(self):
yield self.update().resolve()

def __rshift__(self, other: "Feature") -> "Feature":
def __rshift__(self, other) -> "Feature":

# Allows chaining of features. For example,
# feature1 >> feature2 >> feature3
Expand All @@ -519,12 +519,11 @@ def __rshift__(self, other: "Feature") -> "Feature":
return Chain(self, other)

# Import here to avoid circular import.
from . import models


# If other is a function, call it on the output of the feature.
# For example, feature >> some_function
if isinstance(other, models.KerasModel):
return NotImplemented

if callable(other):
return self >> Lambda(lambda: other)

Expand Down
12 changes: 12 additions & 0 deletions deeptrack/test/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,18 @@ def test_Feature_arithmetic(self):
input_2 = [10, 20]
self.assertListEqual(pipeline(input_2), [-input_2[0], -input_2[1]])

def test_Features_chain_lambda(self):

value = features.Value(value=1)
func = lambda x: x + 1

feature = value >> func
feature.store_properties()

feature.update()
output_image = feature()
self.assertEqual(output_image, 2)

def test_Feature_repeat(self):
feature = features.Value(value=0) >> (features.Add(1) ^ iter(range(10)))

Expand Down

0 comments on commit 10b564b

Please sign in to comment.