Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Aug 12, 2024
1 parent 244b894 commit 74c1e01
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 0 additions & 2 deletions fedeca/algorithms/torch_webdisco_algo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Implement webdisco algorithm with Torch."""
import copy
from copy import deepcopy
from math import sqrt
from pathlib import Path
from typing import Any, List, Optional, Union

Expand All @@ -11,7 +10,6 @@
from autograd import elementwise_grad
from autograd import numpy as anp
from lifelines.utils import StepSizer
from pandas.api.types import is_numeric_dtype
from scipy.linalg import norm
from scipy.linalg import solve as spsolve
from substrafl.algorithms.pytorch import weight_manager
Expand Down
17 changes: 11 additions & 6 deletions fedeca/utils/survival_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,9 @@ def robust_sandwich_variance_pooled(
model. The sandwich variance estimator is a robust estimator of the variance
which accounts for the lack of dependence between the samples due to the
introduction of weights for example.
Parameters
----------
X_norm : np.ndarray or torch.Tensor
Input feature matrix of shape (n_samples, n_features).
y : np.ndarray or torch.Tensor
Expand All @@ -1309,6 +1312,11 @@ def robust_sandwich_variance_pooled(
Weights associated with each sample, with shape (n_samples,)
scaled_variance_matrix : np.ndarray or torch.Tensor
Classical scaled variance of the Cox model estimator.
Returns
-------
np.ndarray
The robust sandwich variance estimator.
"""
n_samples, n_features = X_norm.shape

Expand Down Expand Up @@ -1357,8 +1365,7 @@ def robust_sandwich_variance_pooled(


def km_curve(t, n, d, tmax=5000):
"""Computes Kaplan-Meier (KM) curve based on unique event times, number of
individuals at risk and number of deaths.
"""Compute Kaplan-Meier (KM) curve.
This function is typically used in conjunction with
`compute_events_statistics`. Note that the variance is computed
Expand Down Expand Up @@ -1482,8 +1489,7 @@ def compute_events_statistics(times, events):


def aggregate_events_statistics(list_t_n_d):
"""Aggregates (sums) events statistics from different centers, returning a single
tuple with the same format.
"""Aggregate (sums) events statistics from different centers.
Parameters
----------
Expand Down Expand Up @@ -1514,8 +1520,7 @@ def aggregate_events_statistics(list_t_n_d):


def extend_events_to_common_grid(list_t_n_d, t_common):
"""Extends a list of heterogeneous times, number of people at risk and number of
death on a common grid.
"""Extend a list of heterogeneous times, number of people at risk on common grid.
This method is an internal utility for `aggregate_events_statistics`.
Expand Down

0 comments on commit 74c1e01

Please sign in to comment.