diff --git a/deeptrack/features.py b/deeptrack/features.py index 6de9f04c..edbc36b2 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -508,7 +508,7 @@ def __iter__(self): def __next__(self): yield self.update().resolve() - def __rshift__(self, other: "Feature") -> "Feature": + def __rshift__(self, other) -> "Feature": # Allows chaining of features. For example, # feature1 >> feature2 >> feature3 @@ -519,12 +519,11 @@ def __rshift__(self, other: "Feature") -> "Feature": return Chain(self, other) # Import here to avoid circular import. - from . import models + # If other is a function, call it on the output of the feature. # For example, feature >> some_function - if isinstance(other, models.KerasModel): - return NotImplemented + if callable(other): return self >> Lambda(lambda: other) diff --git a/deeptrack/test/test_features.py b/deeptrack/test/test_features.py index c062084a..4418c74a 100644 --- a/deeptrack/test/test_features.py +++ b/deeptrack/test/test_features.py @@ -458,6 +458,18 @@ def test_Feature_arithmetic(self): input_2 = [10, 20] self.assertListEqual(pipeline(input_2), [-input_2[0], -input_2[1]]) + def test_Features_chain_lambda(self): + + value = features.Value(value=1) + func = lambda x: x + 1 + + feature = value >> func + feature.store_properties() + + feature.update() + output_image = feature() + self.assertEqual(output_image, 2) + def test_Feature_repeat(self): feature = features.Value(value=0) >> (features.Add(1) ^ iter(range(10)))