Skip to content

Commit

Permalink
adding set_hold_out and get_hold_out convenience methods
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Mar 6, 2024
1 parent 4283647 commit 771b128
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/sageworks/core/artifacts/feature_set_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,32 @@ def create_training_view(self, id_column: str, hold_out_ids: list[str]):
# Execute the CREATE VIEW query
self.data_source.execute_statement(create_view_query)

def set_hold_out_ids(self, id_column: str, hold_out_ids: list[str]):
"""Set the hold out ids for this FeatureSet
Args:
id_column (str): The name of the id column in the output DataFrame.
hold_out_ids (list[str]): The list of hold out ids.
"""
self.create_training_view(id_column, hold_out_ids)

def get_hold_out_ids(self, id_column: str) -> list[str]:
"""Get the hold out ids for this FeatureSet
Args:
id_column (str): The name of the id column in the output DataFrame.
Returns:
list[str]: The list of hold out ids.
"""
training_view_table = self.get_training_view_table(create=False)
if training_view_table is not None:
query = f"SELECT {id_column} FROM {training_view_table} WHERE training = 0"
hold_out_ids = self.query(query)[id_column].tolist()
return hold_out_ids
else:
return []

def get_training_view_table(self, create: bool = True) -> Union[str, None]:
"""Get the name of the training view for this FeatureSet
Args:
Expand Down Expand Up @@ -684,12 +710,20 @@ def onboard(self) -> bool:
my_features.create_default_training_view()

# Test the hold out set functionality with ints
print("Setting hold out ids (ints)...")
print("Setting hold out ids...")
table = my_features.get_training_view_table()
df = my_features.query(f"SELECT id, name FROM {table}")
my_hold_out_ids = [id for id in df["id"] if id < 20]
my_features.create_training_view("id", my_hold_out_ids)

# Convenience methods to set and get the hold out ids
print("Setting hold out ids...")
my_features.set_hold_out_ids("id", my_hold_out_ids)
print("Getting hold out ids...")
hold_output = my_features.get_hold_out_ids("id")
print(hold_output)
assert set(hold_output) == set(my_hold_out_ids)

# Test the hold out set functionality with strings
print("Setting hold out ids (strings)...")
my_hold_out_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
Expand Down

0 comments on commit 771b128

Please sign in to comment.