From 14d2c5f2692d40a006bdf581d05cfd34c0f78387 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 25 Apr 2024 17:20:59 +0200 Subject: [PATCH] removing unnecessary code --- molpipeline/estimators/chemprop/component_wrapper.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index b2d41068..afab0bf6 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -3,7 +3,6 @@ import abc from typing import Any, Iterable, Self -import torch from chemprop.conf import DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM, DEFAULT_HIDDEN_DIM from chemprop.models.model import MPNN as _MPNN from chemprop.nn.agg import Aggregation @@ -21,7 +20,6 @@ ) from chemprop.nn.transforms import UnscaleTransform from chemprop.nn.utils import Activation, get_activation_function -from chemprop.utils.registry import Factory from sklearn.base import BaseEstimator from torch import Tensor, nn @@ -165,13 +163,6 @@ def __init__( output_transform : UnscaleTransform or None, optional (default=None) Transformations to apply to the output. None defaults to UnscaleTransform. """ - if criterion is None: - task_weights = torch.ones(n_tasks) if task_weights is None else task_weights - criterion = Factory.build( - self._T_default_criterion, - task_weights=task_weights, - threshold=threshold, - ) super().__init__( n_tasks=n_tasks, input_dim=input_dim,