Skip to content

Commit

Permalink
[python-package] simplify Dataset processing of label (#5456)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Sep 9, 2022
1 parent 1444a74 commit 2e9848c
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
_DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_LabelType = Union[
list,
np.ndarray,
pd_Series,
pd_DataFrame
]

ZERO_THRESHOLD = 1e-35

Expand Down Expand Up @@ -605,15 +611,6 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
return data, feature_name, categorical_feature, pandas_categorical


def _label_from_pandas(label):
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
label = np.ravel(label.values.astype(np.float32, copy=False))
return label


def _dump_pandas_categorical(pandas_categorical, file_name=None):
categorical_json = json.dumps(pandas_categorical, default=json_default_with_numpy)
pandas_str = f'\npandas_categorical:{categorical_json}\n'
Expand Down Expand Up @@ -1200,7 +1197,7 @@ class Dataset:
def __init__(
self,
data,
label=None,
label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
Expand Down Expand Up @@ -1505,7 +1502,7 @@ def _set_init_score_by_predictor(self, predictor, data, used_indices=None):
def _lazy_init(
self,
data,
label=None,
label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
Expand All @@ -1525,7 +1522,6 @@ def _lazy_init(
feature_name,
categorical_feature,
self.pandas_categorical)
label = _label_from_pandas(label)

# process for args
params = {} if params is None else params
Expand Down Expand Up @@ -1936,7 +1932,7 @@ def construct(self) -> "Dataset":
def create_valid(
self,
data,
label=None,
label: Optional[_LGBM_LabelType] = None,
weight=None,
group=None,
init_score=None,
Expand Down Expand Up @@ -2276,7 +2272,7 @@ def set_feature_name(self, feature_name: List[str]) -> "Dataset":
ctypes.c_int(len(feature_name))))
return self

def set_label(self, label) -> "Dataset":
def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
"""Set label of Dataset.
Parameters
Expand All @@ -2291,8 +2287,14 @@ def set_label(self, label) -> "Dataset":
"""
self.label = label
if self.handle is not None:
label = list_to_1d_numpy(_label_from_pandas(label), name='label')
self.set_field('label', label)
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
label_array = np.ravel(label.values.astype(np.float32, copy=False))
else:
label_array = list_to_1d_numpy(label, name='label')
self.set_field('label', label_array)
self.label = self.get_field('label') # original values can be modified at cpp side
return self

Expand Down

0 comments on commit 2e9848c

Please sign in to comment.