
Commit

update logistic regression model and metric calculation
uSaiPrashanth committed Jun 5, 2024
1 parent 5fc689a commit 720fc4b
Showing 8 changed files with 237 additions and 172 deletions.
1 change: 1 addition & 0 deletions __init__.py
@@ -0,0 +1 @@
+from model_utils import expected_calibration_error
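
This new top-level export makes expected_calibration_error importable from the package root. The function itself lives in model_utils and is not shown in this diff; for reference, a standard binned ECE looks roughly like the sketch below (bin count and binning scheme are assumptions, not the repository's actual implementation):

import numpy as np

def expected_calibration_error_sketch(confidences, correct, n_bins=10):
    # Weighted average gap between mean confidence and accuracy per bin.
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap  # weight by the bin's share of samples
    return ece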
6 changes: 3 additions & 3 deletions calculate_metrics.py
@@ -35,13 +35,13 @@ def parse_cli_args():
default=run_id_args_default,
)

models_args_help = "The Pythia model to get the perplexities for. Valid options are: 70m, 160m, 410m, 1b, 1.4b, 2.8b, 6.9b, 12b"
models_args_help = "The Pythia model to get the metrics for. Valid options are: 70m, 160m, 410m, 1b, 1.4b, 2.8b, 6.9b, 12b, 12b.23000, 12b.43000, 12b.63000, 12b.83000, 12b.103000, 12b.123000"
models_args_default = ["70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b"]
parser.add_argument(
"--models",
type=str,
help=models_args_help,
-choices=models_args_default,
+choices=models_args_default + ["12b.23000", "12b.43000", "12b.63000", "12b.83000", "12b.103000", "12b.123000"],
default=models_args_default,
)

@@ -254,7 +254,7 @@ def load_precomputed_features(
f"{scheme}_templates",
semantic_duplicates_map
))
-else:
+elif not model_name[-1].isdigit():  # we do not have semantic-snowclone results for intermediate checkpoints
hf_dataset_names.append((
PrecomputedFeatureName.SEMANTIC_SNOWCLONES,
f"usvsnsp/semantic-duplicates",
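
The new elif guard works because every intermediate-checkpoint name ends in a digit ("12b.23000") while plain size names end in a letter ("12b", "1.4b"), so model_name[-1].isdigit() cleanly separates the two. A small illustration (the helper name is hypothetical, not part of the repository):

def parse_model_name(model_name: str):
    # Checkpoint-suffixed names look like "12b.23000" (size, training step).
    if model_name[-1].isdigit() and "." in model_name:
        size, step = model_name.split(".", 1)
        return size, int(step)
    return model_name, None

assert parse_model_name("12b") == ("12b", None)
assert parse_model_name("12b.23000") == ("12b", 23000)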
2 changes: 1 addition & 1 deletion filters/__init__.py
@@ -5,7 +5,7 @@
if not _has_registered_all_filters:
# The import here determines the order of the pipeline
from .detokenize import detokenize
-from .pattern_incrementing import incrementing_sequences_filter
+from .pattern import pattern_sequences_filter
from .highly_duplicated_filter import sequence_duplicates_filter
from .token_frequency_statistics_filter import token_frequency_statistics_filter
from .highly_repetitive import highly_repetitive_filter
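
As the retained comment says, import order here determines pipeline order, which implies each filter module appends itself to a shared registry as a side effect of being imported. A minimal sketch of that pattern (names are assumptions; the repository's registry may differ):

PIPELINE = []

def register_filter(func):
    # Appending at import time makes import order define execution order.
    PIPELINE.append(func)
    return func

@register_filter
def some_filter(dataset, _):  # hypothetical filter; real ones are imported above
    return dataset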
6 changes: 3 additions & 3 deletions filters/base.py
@@ -1,14 +1,14 @@
import logging
-from typing import Any, Callable, Dict, List, TypeAlias
+from typing import Any, Callable, Dict, List

from pyspark.sql import DataFrame, SparkSession

from filters.constants import PrecomputedFeatureName
from utils import initialize_logger
from spark.constants import NUM_OUTPUT_PARTITIONS, SPARK_CACHE_DIR

-FilterFunc: TypeAlias = Callable[..., Any]
-PrecomputedFeatures: TypeAlias = Dict[PrecomputedFeatureName, DataFrame]
+FilterFunc = Callable[..., Any]
+PrecomputedFeatures = Dict[PrecomputedFeatureName, DataFrame]

LOGGER: logging.Logger = initialize_logger()

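
Dropping TypeAlias is presumably a compatibility fix: typing.TypeAlias only exists on Python 3.10+, and the bare assignments behave identically at runtime. Projects that still want the explicit annotation on older interpreters can pull it from the typing_extensions backport (shown as an assumption; this repository simply removes it):

from typing import Any, Callable
from typing_extensions import TypeAlias  # backport for Python < 3.10

FilterFunc: TypeAlias = Callable[..., Any]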
44 changes: 17 additions & 27 deletions filters/pattern.py
@@ -7,11 +7,7 @@

from .base import PIPELINE_SINGLETON

-import re
-
-import re
-
-def find_if_incrementing_or_repeating(splits):
+def find_if_incrementing_or_repeating(splits, test_repeating=False):
"""Finds if the given list of words is incrementing
A sequence is incrementing if there is a set of integers or decimals in arithmetic progression
@@ -25,16 +21,16 @@ def find_if_incrementing_or_repeating(splits):
the sequence is incrementing and the calculated difference.
"""
# We need at least 3 integers to define an AP
-if len(splits) < 3:
+if len(splits) < 3 and not test_repeating:
return False, 0
-elif len(splits) == 3:
+elif len(splits) == 3 and not test_repeating:
# Every element has to be a number
if not all([type(i) in [float, int] for i in splits]): return False, 0
return (splits[2] - 2*splits[1] + splits[0]) < 1e-5, splits[2] - splits[1]

# First and last words of a sequence can be partial
# We ignore them if length of splits is more than 4
-if len(splits) > 4:
+if len(splits) > 4 and not test_repeating:
splits = splits[1:-1]


@@ -79,7 +75,14 @@ def find_if_incrementing_or_repeating(splits)

return False, 0

-def split_text(text, handle_decimals = True):
+def split_text(text, split_type = "incrementing"):
+
+    if split_type == "repeating":
+        return list(text)
+
+    elif split_type != "incrementing":
+        raise ValueError("Invalid Split Type")
+
# Check if we have hexadecimal numerals
text = re.sub(r"\s+", " ", text)
splits = []
@@ -122,32 +125,19 @@ def split_text(text, handle_decimals = True):
try:
splits_new.append(int(word))
except ValueError:
-# Handle decimals
-if word == '.' and idx > 0 and (idx+1) < len(splits) and handle_decimals:
-if splits[idx-1].isdigit() and splits[idx+1].isdigit():
-try:
-splits_new[-1] = float(splits[idx-1] + splits[idx] + splits[idx+1])
-to_continue = True
-except ValueError:
-splits_new.append(word)
-else:
-splits_new.append(word)
+splits_new.append(word)

return splits_new

def is_pattern(text):
-splits = split_text(text, handle_decimals = True)
+splits = split_text(text)
is_inc, diff = find_if_incrementing_or_repeating(splits)
if is_inc and diff is not None and diff != 0:
return True, False
elif is_inc:
return False, True

-splits = split_text(text, handle_decimals = False)
+splits = split_text(text, split_type="repeating")
is_inc, diff = find_if_incrementing_or_repeating(splits)
-if is_inc and diff is not None and diff != 0:
-return True, False
-elif is_inc:
+if is_inc:  # we don't have incrementing cases when we split by characters
return False, True
else:
return False, False
@@ -193,4 +183,4 @@ def pattern_sequences_filter(dataset: DataFrame, _) -> DataFrame:
samp = r"""
"A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1"
"""
-print(incrementing_sequences_filter_wrapper(samp))
+print(is_pattern(samp))
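
Taken together, the rewritten filter now makes two passes: a word/number-level split to detect arithmetic progressions (incrementing), then a character-level split to detect repetition. A simplified standalone sketch of the same two-pass idea (the real code above additionally handles hexadecimal numerals, decimals, and partial edge words):

def is_arithmetic(nums, eps=1e-5):
    # An arithmetic progression has a constant first difference.
    diffs = [b - a for a, b in zip(nums, nums[1:])]
    return len(diffs) > 1 and all(abs(d - diffs[0]) < eps for d in diffs)

def is_pattern_sketch(text):
    nums = [int(w) for w in text.split() if w.lstrip("-").isdigit()]
    if is_arithmetic(nums) and nums[1] != nums[0]:
        return True, False   # incrementing, e.g. "1 2 3 4 5"
    chars = list(text.replace(" ", ""))
    if chars and len(set(chars)) == 1:
        return False, True   # repeating, e.g. "aaaaaa" (crude character-level test)
    return False, False

print(is_pattern_sketch("1 2 3 4 5"))  # (True, False)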
4 changes: 2 additions & 2 deletions model_parameters.py
@@ -9,7 +9,7 @@
EXPERIMENT_ROOT = "experiments"
MODEL_SIZE = "12b"
DATA_SCHEME = "deduped"
GENERATION_HF_DATASET_NAME = "usvsnsp/generation-semantic-filters"
GENERATION_HF_DATASET_NAME = "usvsnsp/semantic-filters"

"""
Feature Catalog
@@ -107,7 +107,7 @@ def classify_row(row: pd.Series) -> str:
"""
Model Training Hyper-parameters
"""
-GLOBAL_SEED = 80
+GLOBAL_SEED = 1024

TRAIN_SIZE = 0.8
VALIDATION_SIZE = 0.1
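
The bumped seed feeds the split sizes just below it; per the commit title, these hyperparameters drive a logistic regression. Roughly, the assumed workflow (sklearn usage and the function name are illustrative, not taken from this repository):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

GLOBAL_SEED = 1024
TRAIN_SIZE = 0.8  # remaining 0.2 is halved into validation and test (0.1 each)

def split_and_fit(X, y):
    X_train, X_rest, y_train, y_rest = train_test_split(
        X, y, train_size=TRAIN_SIZE, random_state=GLOBAL_SEED
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_rest, y_rest, train_size=0.5, random_state=GLOBAL_SEED
    )
    model = LogisticRegression(max_iter=1000, random_state=GLOBAL_SEED)
    model.fit(X_train, y_train)
    return model, model.score(X_val, y_val)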
