diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index ba704a06c360..4c32add1d3e6 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -23,6 +23,12 @@ _DatasetHandle = ctypes.c_void_p _LGBM_EvalFunctionResultType = Tuple[str, float, bool] _LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool] +_LGBM_LabelType = Union[ + list, + np.ndarray, + pd_Series, + pd_DataFrame +] ZERO_THRESHOLD = 1e-35 @@ -605,15 +611,6 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica return data, feature_name, categorical_feature, pandas_categorical -def _label_from_pandas(label): - if isinstance(label, pd_DataFrame): - if len(label.columns) > 1: - raise ValueError('DataFrame for label cannot have multiple columns') - _check_for_bad_pandas_dtypes(label.dtypes) - label = np.ravel(label.values.astype(np.float32, copy=False)) - return label - - def _dump_pandas_categorical(pandas_categorical, file_name=None): categorical_json = json.dumps(pandas_categorical, default=json_default_with_numpy) pandas_str = f'\npandas_categorical:{categorical_json}\n' @@ -1200,7 +1197,7 @@ class Dataset: def __init__( self, data, - label=None, + label: Optional[_LGBM_LabelType] = None, reference: Optional["Dataset"] = None, weight=None, group=None, @@ -1505,7 +1502,7 @@ def _set_init_score_by_predictor(self, predictor, data, used_indices=None): def _lazy_init( self, data, - label=None, + label: Optional[_LGBM_LabelType] = None, reference: Optional["Dataset"] = None, weight=None, group=None, @@ -1525,7 +1522,6 @@ def _lazy_init( feature_name, categorical_feature, self.pandas_categorical) - label = _label_from_pandas(label) # process for args params = {} if params is None else params @@ -1936,7 +1932,7 @@ def construct(self) -> "Dataset": def create_valid( self, data, - label=None, + label: Optional[_LGBM_LabelType] = None, weight=None, group=None, init_score=None, @@ -2276,7 +2272,7 @@ def set_feature_name(self, feature_name: List[str]) -> "Dataset": ctypes.c_int(len(feature_name)))) return self - def set_label(self, label) -> "Dataset": + def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset": """Set label of Dataset. Parameters @@ -2291,8 +2287,14 @@ def set_label(self, label) -> "Dataset": """ self.label = label if self.handle is not None: - label = list_to_1d_numpy(_label_from_pandas(label), name='label') - self.set_field('label', label) + if isinstance(label, pd_DataFrame): + if len(label.columns) > 1: + raise ValueError('DataFrame for label cannot have multiple columns') + _check_for_bad_pandas_dtypes(label.dtypes) + label_array = np.ravel(label.values.astype(np.float32, copy=False)) + else: + label_array = list_to_1d_numpy(label, name='label') + self.set_field('label', label_array) self.label = self.get_field('label') # original values can be modified at cpp side return self