[GSProcessing] Enforce re-order for node label processing during classification #1136

Merged (4 commits) on Jan 29, 2025
@@ -18,6 +18,8 @@
import logging
from typing import Any, Dict, Optional

from graphstorm_processing.constants import VALID_TASK_TYPES


class LabelConfig(abc.ABC):
"""Basic class for label config"""
@@ -55,6 +57,9 @@ def __init__(self, config_dict: Dict[str, Any]):
self._mask_field_names = None

def _sanity_check(self):
assert (
self._task_type in VALID_TASK_TYPES
), f"Invalid task type {self._task_type}, must be one of {VALID_TASK_TYPES}"
if self._label_column == "":
assert self._task_type == "link_prediction", (
"When no label column is specified, the task type must be link_prediction, "
@@ -83,6 +88,25 @@ def _sanity_check(self):
assert all(isinstance(x, str) for x in self._mask_field_names)
assert len(self._mask_field_names) == 3

def __repr__(self) -> str:
"""Formal object representation for debugging"""
return (
f"{self.__class__.__name__}(label_column={self._label_column!r}, "
f"task_type={self._task_type!r}, separator={self._separator!r}, "
f"multilabel={self._multilabel!r}, split={self._split!r}, "
f"custom_split_filenames={self._custom_split_filenames!r}, "
f"mask_field_names={self._mask_field_names!r})"
)

def __str__(self) -> str:
"""Informal object representation for readability"""
task_desc = f"{self._task_type} task"
if self._label_column:
task_desc += f" on column '{self._label_column}'"
if self._multilabel:
task_desc += f" (multilabel with separator '{self._separator}')"
return task_desc

@property
def label_column(self) -> str:
"""The name of the column storing the target label property value."""
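The assertion added above gates every label config on VALID_TASK_TYPES before any Spark work starts. A minimal standalone sketch of that validation pattern, assuming a plain config dict with illustrative "column" and "type" keys rather than the real GSProcessing config classes:

```python
# Sketch only: mirrors the task-type sanity check added above.
# VALID_TASK_TYPES matches graphstorm_processing.constants; the dict keys are assumptions.
VALID_TASK_TYPES = {"classification", "regression", "link_prediction"}


def check_label_config(config: dict) -> None:
    task_type = config.get("type", "")
    label_column = config.get("column", "")
    assert (
        task_type in VALID_TASK_TYPES
    ), f"Invalid task type {task_type}, must be one of {VALID_TASK_TYPES}"
    if label_column == "":
        # Only link prediction may omit the label column.
        assert (
            task_type == "link_prediction"
        ), "When no label column is specified, the task type must be link_prediction"


check_label_config({"column": "rating", "type": "classification"})  # passes
# check_label_config({"column": "rating", "type": "klassification"})  # raises AssertionError
```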
10 changes: 10 additions & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
@@ -66,6 +66,16 @@
NODE_MAPPING_STR = "orig"
NODE_MAPPING_INT = "new"

################# Reserved columns ################
DATA_SPLIT_SET_MASK_COL = "GSP-SAMPLE-SET-MASK"

################# Supported task types ##############
VALID_TASK_TYPES = {
"classification",
"regression",
"link_prediction",
}


################# Supported execution envs ##############
class ExecutionEnv(Enum):
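DATA_SPLIT_SET_MASK_COL reserves a column name that GSProcessing can attach to intermediate DataFrames without colliding with user columns. A hedged sketch of how such a reserved split-mask column could be used; the 0/1/2 train/val/test encoding and the collision check are illustrative assumptions, not the library's documented behavior:

```python
from pyspark.sql import SparkSession, functions as F

DATA_SPLIT_SET_MASK_COL = "GSP-SAMPLE-SET-MASK"

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("n0", "A"), ("n1", "B"), ("n2", "A")], ["node_id", "label"])

# A reserved name lets the pipeline refuse inputs that would otherwise be silently overwritten.
assert (
    DATA_SPLIT_SET_MASK_COL not in df.columns
), f"Input data must not contain the reserved column {DATA_SPLIT_SET_MASK_COL}"

# Attach an assumed train(0)/val(1)/test(2) assignment under the reserved name.
with_mask = (
    df.withColumn("rand", F.rand(seed=42))
    .withColumn(
        DATA_SPLIT_SET_MASK_COL,
        F.when(F.col("rand") < 0.8, F.lit(0))
        .when(F.col("rand") < 0.9, F.lit(1))
        .otherwise(F.lit(2)),
    )
    .drop("rand")
)
with_mask.show()
```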
@@ -16,9 +16,10 @@

from dataclasses import dataclass
from math import fsum
from typing import Optional

from pyspark.sql import DataFrame, SparkSession
- from pyspark.sql.types import FloatType
+ from pyspark.sql.types import NumericType

from graphstorm_processing.config.label_config_base import LabelConfig
from graphstorm_processing.data_transformations.dist_transformations import (
@@ -100,11 +101,33 @@ class DistLabelLoader:
The SparkSession to use for processing.
"""

- def __init__(self, label_config: LabelConfig, spark: SparkSession) -> None:
+ def __init__(
+     self, label_config: LabelConfig, spark: SparkSession, order_col: Optional[str] = None
+ ) -> None:
self.label_config = label_config
self.label_column = label_config.label_column
self.spark = spark
self.label_map: dict[str, int] = {}
self.order_col = order_col

def __str__(self) -> str:
"""Informal object representation for readability"""
return (
f"DistLabelLoader(label_column='{self.label_column}', "
f"task_type='{self.label_config.task_type}', "
f"multilabel={self.label_config.multilabel}, "
f"order_col={self.order_col!r})"
)

def __repr__(self) -> str:
"""Formal object representation for debugging"""
return (
f"DistLabelLoader("
f"label_config={self.label_config!r}, "
f"spark={self.spark!r}, "
f"order_col={self.order_col!r}, "
f"label_map={self.label_map!r})"
)

def process_label(self, input_df: DataFrame) -> DataFrame:
"""Transforms the label column in the input DataFrame to conform to GraphStorm expectations.
@@ -134,23 +157,32 @@ def process_label(self, input_df: DataFrame) -> DataFrame:
label_type = input_df.schema[self.label_column].dataType

if self.label_config.task_type == "classification":
assert self.order_col, "order_col must be provided for classification tasks"
if self.label_config.multilabel:
assert self.label_config.separator
label_transformer = DistMultiLabelTransformation(
[self.label_config.label_column], self.label_config.separator
)
else:
label_transformer = DistSingleLabelTransformation(
-     [self.label_config.label_column], self.spark
+     [self.label_config.label_column],
+     self.spark,
)

- transformed_label = label_transformer.apply(input_df).select(self.label_column)
+ transformed_label = label_transformer.apply(input_df)
if self.order_col:
assert self.order_col in transformed_label.columns, (
f"{self.order_col=} needs to be part of transformed "
f"label DF, got {transformed_label.columns=}"
)
transformed_label = transformed_label.sort(self.order_col).cache()

self.label_map = label_transformer.value_map
return transformed_label
elif self.label_config.task_type == "regression":
- if not isinstance(label_type, FloatType):
+ if not isinstance(label_type, NumericType):
raise RuntimeError(
-     "Data type for regression should be FloatType, "
+     "Data type for regression should be a NumericType, "
f"got {label_type} for {self.label_column}"
)
return input_df.select(self.label_column)
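The new order_col argument exists because Spark transformations make no row-order guarantee, so transformed classification labels could come back shuffled relative to the integer node-id mapping. A self-contained sketch of the re-ordering step; the column name "new" is borrowed from NODE_MAPPING_INT and the data is made up:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# "new" stands in for the integer node id assigned during node mapping.
labels = spark.createDataFrame(
    [(2, "paper"), (0, "author"), (1, "paper")], ["new", "label"]
)

# Wide transformations (the StringIndexer round trip, joins, groupBys) can emit rows
# in any order, so sort by the node-id column before the labels are written out next
# to the node data, and cache so later actions reuse the ordered result.
ordered = labels.sort("new").cache()
ordered.show()  # rows now come back in node-id order: 0, 1, 2
```

This mirrors what DistLabelLoader(label_config, spark, order_col=...) now does inside process_label for classification tasks.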
@@ -313,8 +313,12 @@ def __init__(self, cols: Sequence[str], separator: str) -> None:
def get_transformation_name() -> str:
return "DistMultiCategoryTransformation"

- def apply(self, input_df: DataFrame) -> DataFrame:
+ def apply(self, input_df: DataFrame, return_all_cols: bool = False) -> DataFrame:
col_datatype = input_df.schema[self.multi_column].dataType
if return_all_cols:
original_cols = {*input_df.columns} - {self.multi_column}
else:
original_cols = set()
is_array_col = False
if col_datatype.typeName() == "array":
assert isinstance(col_datatype, ArrayType)
@@ -326,13 +330,19 @@ def apply(self, input_df: DataFrame) -> DataFrame:

is_array_col = True

# Parquet input might come with arrays already, CSV will need splitting
if is_array_col:
-     list_df = input_df.select(self.multi_column).alias(self.multi_column)
+     multi_column = F.col(self.multi_column)
else:
-     list_df = input_df.select(
-         F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
+     multi_column = F.split(F.col(self.multi_column), self.separator).alias(
+         self.multi_column
      )

+ list_df = input_df.select(
+     multi_column,
+     *original_cols,
+ )

distinct_category_counts = (
list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column)))
.groupBy(SINGLE_CATEGORY_COL)
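return_all_cols lets the multi-category transformation carry the non-label columns (for example the node-id order column) through its select instead of dropping them. A short sketch of that select pattern under assumed column names:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [(0, "cat;dog"), (1, "dog"), (2, "fish;cat")], ["new", "genres"]
)

multi_column = "genres"
separator = ";"
return_all_cols = True

# Mirrors the new apply() logic: keep every other column only when requested.
original_cols = set(df.columns) - {multi_column} if return_all_cols else set()

list_df = df.select(
    F.split(F.col(multi_column), separator).alias(multi_column),
    *original_cols,
)
list_df.show()  # "genres" is now an array column and "new" is preserved for re-ordering
```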
@@ -45,6 +45,7 @@ def __init__(self, cols: Sequence[str], spark: SparkSession) -> None:

def apply(self, input_df: DataFrame) -> DataFrame:
assert self.spark
original_cols = {*input_df.columns} - {self.label_column}
processed_col_name = self.label_column + "_processed"

str_indexer = StringIndexer(
@@ -63,13 +64,15 @@ def apply(self, input_df: DataFrame) -> DataFrame:

# Labels that were missing and were assigned the value numLabels by the StringIndexer
# are converted to None
- long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select(
+ long_class_label = indexed_df.select(
F.when(
F.col(self.label_column) == len(str_indexer_model.labelsArray[0]),  # type: ignore
F.lit(None),
)
.otherwise(F.col(self.label_column))
- .alias(self.label_column)
+ .cast("long")
+ .alias(self.label_column),
+ *original_cols,
)

# Get a mapping from original label to encoded value
@@ -112,7 +115,7 @@ def __init__(self, cols: Sequence[str], separator: str) -> None:
super().__init__(cols, separator)
self.label_column = cols[0]

- def apply(self, input_df: DataFrame) -> DataFrame:
-     multi_cat_df = super().apply(input_df)
+ def apply(self, input_df: DataFrame, return_all_cols=True) -> DataFrame:
+     multi_cat_df = super().apply(input_df, return_all_cols=return_all_cols)

return multi_cat_df
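The single-label transformation now threads every non-label column through its final select, so downstream code can still sort by the order column. A condensed, standalone sketch of the StringIndexer round trip it wraps; handleInvalid="keep" and the null-out step mirror the code above, while the column names and data are illustrative:

```python
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(0, "A"), (1, None), (2, "B")], ["new", "label"])

processed = "label_processed"
indexer = StringIndexer(inputCol="label", outputCol=processed, handleInvalid="keep")
model = indexer.fit(df)
indexed = model.transform(df).drop("label").withColumnRenamed(processed, "label")

original_cols = set(indexed.columns) - {"label"}

# Missing labels get index == numLabels under handleInvalid="keep"; map them to null,
# cast to long, and keep the remaining columns (e.g. the node-id order column).
num_labels = len(model.labelsArray[0])
long_label = indexed.select(
    F.when(F.col("label") == num_labels, F.lit(None))
    .otherwise(F.col("label"))
    .cast("long")
    .alias("label"),
    *original_cols,
)
long_label.show()  # encoded labels plus the untouched "new" column
```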
@@ -264,6 +264,7 @@ def __init__(
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)

self.loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
@@ -287,17 +288,18 @@ def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False
bucket, s3_prefix = s3_utils.extract_bucket_and_key(self.output_prefix)
s3 = boto3.resource("s3")

- output_files = os.listdir(loader.output_path)
+ output_files = os.listdir(loader.local_meta_output_path)
for output_file in output_files:
s3.meta.client.upload_file(
-     f"{os.path.join(loader.output_path, output_file)}",
+     f"{os.path.join(loader.local_meta_output_path, output_file)}",
bucket,
f"{s3_prefix}/{output_file}",
)

def run(self) -> None:
"""
- Executes the Spark processing job.
+ Executes the Spark processing job, optional repartition job, and uploads any metadata files
+ if needed.
"""
logging.info("Performing data processing with PySpark...")

@@ -355,7 +357,7 @@ def run(self) -> None:
# If any of the metadata modification took place, write an updated metadata file
if updated_metadata:
updated_meta_path = os.path.join(
- self.loader.output_path, "updated_row_counts_metadata.json"
+ self.loader.local_meta_output_path, "updated_row_counts_metadata.json"
)
with open(
updated_meta_path,
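The upload helper now reads from the loader's local metadata directory (local_meta_output_path) rather than the whole output path. A minimal boto3 sketch of the same upload loop; the directory, bucket, and prefix are placeholders:

```python
import os

import boto3


def upload_metadata_files(local_meta_dir: str, bucket: str, s3_prefix: str) -> None:
    """Upload every file in the local metadata directory under the given S3 prefix."""
    s3 = boto3.resource("s3")
    for output_file in os.listdir(local_meta_dir):
        s3.meta.client.upload_file(
            os.path.join(local_meta_dir, output_file),
            bucket,
            f"{s3_prefix}/{output_file}",
        )


# upload_metadata_files("/tmp/gsprocessing-meta", "my-bucket", "gsprocessing/output")
```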