diff --git a/batchflow/sampler.py b/batchflow/sampler.py index 59413d59c..f7a779c9a 100644 --- a/batchflow/sampler.py +++ b/batchflow/sampler.py @@ -151,8 +151,7 @@ def __and__(self, other): result of the multiplication. """ if isinstance(other, (float, int)): - self.weight *= other - return self + return WeightedSampler(self, self.weight * other) return AndSampler(self, other) @@ -346,6 +345,20 @@ def sample(self, size): return np.concatenate(samples)[:size] +class WeightedSampler(Sampler): + """ Class for implementing `&` (weighting) operation on a number and a `Sampler` instance. + """ + def __init__(self, base, weight, *args, **kwargs): + super().__init__(*args, **kwargs) + self.bases = [base] + self.weight = weight + + def sample(self, size): + """ Sampling procedure of a product of the number and the sampler instance. Check out the docstring of + `Sampler.sample` for more info. + """ + return self.bases[0].sample(size) + class BaseOperationSampler(Sampler): """ Base class for implementing all arithmetic operations on `Sampler`-instances.