diff --git a/src/treeffuser/samples.py b/src/treeffuser/samples.py index 216f0c4..73f88c6 100644 --- a/src/treeffuser/samples.py +++ b/src/treeffuser/samples.py @@ -4,8 +4,8 @@ from typing import Union import numpy as np -import sklearn from jaxtyping import Float +from sklearn.neighbors import KernelDensity from tqdm import tqdm @@ -118,7 +118,7 @@ def sample_kde( self, bandwidth: Union[float, Literal["scott", "silverman"]] = 1.0, verbose: bool = False, - ) -> List[sklearn.neighbors.KernelDensity]: + ) -> List[KernelDensity]: """ Compute the Kernel Density Estimate (KDE) for each `x`. Estimate: `KDE[Y | X = x]` for each `x` using Gaussian kernels from `sklearn.neighbors`. @@ -135,8 +135,8 @@ def sample_kde( Returns ------- - kdes : list of sklearn.neighbors.KernelDensity - A list of `sklearn.neighbors.KernelDensity` objects, one for each `x`. + kdes : list of KernelDensity + A list of `KernelDensity` objects, one for each `x`. """ kdes = [] for i in tqdm( @@ -148,9 +148,7 @@ def sample_kde( y_i = self._samples[:, i, None] else: y_i = self._samples[:, i, :] - kde = sklearn.neighbors.KernelDensity( - bandwidth=bandwidth, algorithm="auto", kernel="gaussian" - ) + kde = KernelDensity(bandwidth=bandwidth, algorithm="auto", kernel="gaussian") kde.fit(y_i) kdes.append(kde)