diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py index e5f984d..792cc9e 100644 --- a/concept_erasure/quadratic.py +++ b/concept_erasure/quadratic.py @@ -146,10 +146,13 @@ def update_single(self, x: Tensor, z: int) -> "QuadraticFitter": return self - def editor(self) -> QuadraticEditor: + def editor(self, device: str | None = None) -> QuadraticEditor: """Quadratic editor for the concept.""" sigma = self.sigma_xx - return QuadraticEditor(self.mean_x, ot_map(sigma[:, None], sigma)) + device = device or sigma.device + + T = ot_map(sigma[:, None], sigma).to(device) + return QuadraticEditor(self.mean_x.to(device), T) @cached_property def eraser(self) -> QuadraticEraser: diff --git a/pyproject.toml b/pyproject.toml index 793ae94..ddec8ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ license = {text = "MIT License"} dependencies = [ "torch", ] -version = "0.2.3" +version = "0.2.4" [project.optional-dependencies] dev = [