diff --git a/docs/source/conf.py b/docs/source/conf.py index dfd9ff4..65c56d6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,53 +10,43 @@ import sys from datetime import datetime -# -- Path setup -------------------------------------------------------------- - -# Add the path to your source code here. +# Path setup sys.path.insert(0, os.path.abspath("../../src")) -# -- Project information ----------------------------------------------------- - +# Project information PROJECT = "MELTs" AUTHOR = "Thu Nguyen Hoang Anh" -COPYRIGHT = f"{datetime.datetime.now().year}, {AUTHOR}" +COPYRIGHT = f"{datetime.now().year}, {AUTHOR}" -# The version info for the project -VERSION = "0.1" # Short version (e.g., '0.1') -RELEASE = "0.1" # Full version (e.g., '0.1.0') +# The full version, including alpha/beta/rc tags +RELEASE = "0.1" -# -- General configuration --------------------------------------------------- +# General configuration +MASTER_DOC = "index" -MASTER_DOC = "index" # The name of the master document - -# Sphinx extensions to use +# Sphinx extension modules as strings, can be built-in or custom EXTENSIONS = [ - "sphinx.ext.duration", # Measure build time - "sphinx.ext.autodoc", # Include documentation from docstrings - "sphinx.ext.coverage", # Check for documentation coverage - "sphinx.ext.doctest", # Test embedded doctests - "sphinx_rtd_theme", # Read the Docs theme + "sphinx.ext.duration", + "sphinx.ext.autodoc", + "sphinx.ext.coverage", + "sphinx_rtd_theme", + "sphinx.ext.doctest", ] -# Mock import for autodoc +# List of modules to mock during autodoc generation AUTODOC_MOCK_IMPORTS = ["pyemd"] # Paths that contain templates TEMPLATES_PATH = ["_templates"] -# Patterns to ignore when looking for source files +# List of patterns to ignore when looking for source files EXCLUDE_PATTERNS = [] # Sort members alphabetically in the autodoc AUTODOC_MEMBER_ORDER = "alphabetical" -# Theme to use for HTML and HTML Help pages +# Options for HTML output HTML_THEME = "sphinx_rtd_theme" -# Theme options for customizing the appearance of the theme -HTML_THEME_OPTIONS = { - # You can add theme-specific options here -} - -# Paths that contain custom static files (e.g., style sheets) +# Paths for custom static files (like style sheets) HTML_STATIC_PATH = ["_static"] diff --git a/src/melt/__main__.py b/src/melt/__main__.py index 689893e..7e2dc05 100644 --- a/src/melt/__main__.py +++ b/src/melt/__main__.py @@ -1,94 +1,18 @@ -""" -This script initializes NLP models and runs the main function from the 'cli' module. - -The script performs the following tasks: -1. Downloads the 'punkt' tokenizer models using nltk. -2. Loads the spaCy 'en_core_web_sm' model, downloading it if necessary. -3. Imports and executes the 'main' function from the 'cli' module. - -If any module or function cannot be imported, appropriate error messages are displayed. -""" - -import logging +"Main" import spacy import nltk -from spacy.cli import download as spacy_download -from typing import NoReturn - -# Configure logging with a descriptive name for the logger -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(message)s", - level=logging.INFO -) -logger = logging.getLogger("nlp_utils") - -def download_nltk_resources() -> NoReturn: - """Download the necessary NLTK resources. - - Logs success or failure messages. 
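# --- Illustrative sketch, not part of the patch above ---------------------------
# Sphinx only picks up lowercase names from conf.py (project, extensions,
# html_theme, ...), so uppercase constants such as PROJECT or EXTENSIONS are not
# read by the builder unless they are mirrored into the lowercase names Sphinx
# expects. A minimal conf.py using the same values as the patch could look like:
import os
import sys
from datetime import datetime

sys.path.insert(0, os.path.abspath("../../src"))

project = "MELTs"
author = "Thu Nguyen Hoang Anh"
copyright = f"{datetime.now().year}, {author}"  # Sphinx expects this exact name
release = "0.1"
master_doc = "index"

extensions = [
    "sphinx.ext.duration",
    "sphinx.ext.autodoc",
    "sphinx.ext.coverage",
    "sphinx.ext.doctest",
    "sphinx_rtd_theme",
]
autodoc_mock_imports = ["pyemd"]
autodoc_member_order = "alphabetical"
templates_path = ["_templates"]
exclude_patterns = []
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
# --------------------------------------------------------------------------------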
- """ - try: - with nltk.download('punkt'): - logger.info("Successfully downloaded NLTK 'punkt' resource.") - except Exception as error: - logger.error("Failed to download NLTK resources: %s", error) - raise - - -def load_spacy_model(model_name: str = "en_core_web_sm") -> spacy.language.Language: - """Load and return the spaCy model, downloading it if necessary. - - Logs success or failure messages during the model loading process. - - Args: - model_name (str): The name of the spaCy model to load. - - Returns: - spacy.language.Language: The loaded spaCy model. - """ - try: - model = spacy.load(model_name) - logger.info("Successfully loaded spaCy model: %s", model_name) - except OSError: - logger.warning("spaCy model '%s' not found. Downloading...", model_name) - spacy_download(model_name) - model = spacy.load(model_name) - logger.info("Successfully downloaded and loaded spaCy model: %s", model_name) - except Exception as error: - logger.error("Failed to load spaCy model: %s", error) - raise - return model - - -def execute_cli_main() -> None: - """Execute the 'main' function from the CLI module. - - Logs success or failure messages about the import process and execution. - """ - try: - from cli import main as cli_main - logger.info("Successfully imported 'main' from 'cli' module.") - except ImportError as import_error: - logger.error("ImportError: %s", import_error) - try: - import cli - cli_main = cli.main - logger.info("Successfully imported 'cli' module directly.") - except ImportError as inner_import_error: - logger.critical("Failed to import 'cli' module: %s", inner_import_error) - raise - cli_main() - - -def main() -> None: - """Main function to set up resources and execute the CLI. +from melt.cli import main - Ensures proper logging and execution flow. - """ - download_nltk_resources() - load_spacy_model() - execute_cli_main() +nltk.download('punkt_tab') +try: + spacy.load("en_core_web_sm") +except OSError: + print( + "Downloading the spacy en_core_web_sm model\n" + "(don't worry, this will only happen once)" + ) + from spacy.cli import download + download("en_core_web_sm") -if __name__ == "__main__": - main() +main() diff --git a/src/melt/cli.py b/src/melt/cli.py index e1de937..e959b9d 100644 --- a/src/melt/cli.py +++ b/src/melt/cli.py @@ -1,18 +1,9 @@ -""" -This script initializes and runs the text generation pipeline using spaCy, -transformers, and dotenv. It also handles downloading the spaCy 'en_core_web_sm' -model if it is not already present. - -The main function is responsible for: -1. Loading environment variables. -2. Parsing script arguments. -3. Running the generation process with the parsed arguments. 
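# --- Illustrative sketch, not part of the patch above ---------------------------
# The "download on first use" bootstrap from __main__.py, wrapped in a reusable
# helper. The resource names ("punkt_tab", "en_core_web_sm") are the ones the
# module above uses; everything else is a sketch, not project code.
import nltk
import spacy
from spacy.cli import download as spacy_download
from spacy.language import Language


def ensure_nlp_resources(model_name: str = "en_core_web_sm") -> Language:
    """Fetch the NLTK tokenizer data and return a loaded spaCy pipeline."""
    nltk.download("punkt_tab", quiet=True)
    try:
        return spacy.load(model_name)
    except OSError:
        # The model is missing locally; download it once, then load it.
        spacy_download(model_name)
        return spacy.load(model_name)


if __name__ == "__main__":
    nlp = ensure_nlp_resources()
    print(nlp("MELT loads its NLP resources lazily."))
# --------------------------------------------------------------------------------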
-""" -try: - import spacy -except ImportError as e: - print(f"Failed to import 'spacy': {e}") - +"Cli" +import spacy +from transformers import HfArgumentParser +from dotenv import load_dotenv +from melt.script_arguments import ScriptArguments +from melt.generation import generation try: spacy.load("en_core_web_sm") except OSError: @@ -20,56 +11,18 @@ "Downloading the spacy en_core_web_sm model\n" "(don't worry, this will only happen once)" ) - try: - from spacy.cli import download - download("en_core_web_sm") + from spacy.cli import download - except ImportError as e: - print(f"Failed to import 'spacy.cli': {e}") -try: - from transformers import HfArgumentParser -except ImportError as e: - print(f"Failed to import 'transformers': {e}") + download("en_core_web_sm") -try: - from dotenv import load_dotenv -except ImportError as e: - print(f"Failed to import 'dotenv': {e}") - -try: - from .script_arguments import ScriptArguments -except ImportError as e: - print(f"Failed to import 'ScriptArguments' from 'script_arguments': {e}") -try: - from .generation import generation -except ImportError as e: - print(f"Failed to import 'generation' from 'generation': {e}") -def main(): - """ - The main function that initializes the environment, parses script arguments, - and triggers the text generation process. - This function performs the following steps: - 1. Loads environment variables using `load_dotenv()`. - 2. Creates an argument parser for `ScriptArguments` using `HfArgumentParser`. - 3. Parses the arguments into data classes. - 4. Calls the `generation` function with the parsed arguments to perform the text generation. +# from .to_sheet import to_sheet +# from .to_sheet_std import to_sheet_std - Returns: - None - """ +def main(): + "CLI" load_dotenv() - - # Ensure spaCy model is available - ensure_spacy_model() - - # Parse command-line arguments parser = HfArgumentParser(ScriptArguments) args = parser.parse_args_into_dataclasses()[0] - - # Execute the generation function with parsed arguments generation(args) - -if __name__ == "__main__": - main() diff --git a/src/melt/generation.py b/src/melt/generation.py index a07ccf0..64a0a7d 100644 --- a/src/melt/generation.py +++ b/src/melt/generation.py @@ -1,69 +1,14 @@ -""" -This module provides functionality for evaluating and -generating data using specified pipelines and datasets. - -The `generation` function is the main entry point of this script. It performs the following tasks: -1. Initializes the seed for reproducibility. -2. Loads and processes the dataset using `DatasetWrapper`. -3. Sets up directories for saving results if they don't already exist. -4. Handles continuation of inference from a previous run if specified. -5. Creates a DataLoader for batching dataset examples. -6. Initializes the evaluation pipeline (`EvalPipeline`). -7. Runs the evaluation pipeline and saves the results to JSON files. - -The script is designed to work with various configurations -specified in the `script_args` parameter, including options for -few-shot prompting and continuing from previous results. - -Modules used: -- `os`: For file and directory operations. -- `.tools.data`: Contains `DatasetWrapper` for -dataset management. -- `.tools.pipelines`: Contains `EvalPipeline` for -evaluation processes. -- `.tools.utils.utils`: Provides utility functions such as -`save_to_json`, `set_seed`, and `read_json`. -- `torch.utils.data`: For data loading with `DataLoader`. 
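# --- Illustrative sketch, not part of the patch above ---------------------------
# The argument-parsing flow used by cli.main(): load_dotenv() reads a local .env
# file and HfArgumentParser turns a dataclass into command-line flags.
# DemoArguments is a made-up stand-in for ScriptArguments.
from dataclasses import dataclass, field

from dotenv import load_dotenv
from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    model_name: str = field(
        default="meta-llama/Llama-2-7b-chat-hf",
        metadata={"help": "Model to evaluate"},
    )
    seed: int = field(default=42, metadata={"help": "Random seed"})


if __name__ == "__main__":
    load_dotenv()  # exposes variables from .env through os.environ
    parser = HfArgumentParser(DemoArguments)
    # e.g. `python demo.py --model_name my-org/my-model --seed 123`
    (args,) = parser.parse_args_into_dataclasses()
    print(args.model_name, args.seed)
# --------------------------------------------------------------------------------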
-""" +"Generation" import os +import sys from torch.utils.data import DataLoader -from .tools.data import DatasetWrapper -from .tools.pipelines import EvalPipeline -from .tools.utils.utils import save_to_json, set_seed, read_json - - +from melt.tools.data import DatasetWrapper +from melt.tools.pipelines import EvalPipeline +from melt.tools.utils.utils import save_to_json, set_seed, read_json +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def generation(script_args): - """ - Executes the data generation process based on the provided script arguments. - - This function performs the following steps: - 1. Sets the random seed for reproducibility using `set_seed`. - 2. Loads and optionally processes the dataset using `DatasetWrapper`. - 3. Constructs filenames for saving generation results and metrics based on the script arguments. - 4. Creates necessary directories for saving results if they don't already exist. - 5. Determines the starting index and results to continue - inference from a previous run if specified. - 6. Initializes a `DataLoader` for batching the dataset examples. - 7. Initializes an `EvalPipeline` for evaluating the data. - 8. Runs the evaluation pipeline and saves the results using the `save_results` function. - Args: - script_args (ScriptArguments): An object containing the configuration - and parameters for the data generation process. - - seed (int): Random seed for reproducibility. - - smoke_test (bool): Flag to indicate if a smaller subset - of data should be used for testing. - - dataset_name (str): Name of the dataset. - - model_name (str): Name of the model. - - output_dir (str): Directory to save generation results. - - output_eval_dir (str): Directory to save evaluation metrics. - - continue_infer (bool): Flag to continue inference from a previous run. - - per_device_eval_batch_size (int): Batch size for evaluation. - - fewshot_prompting (bool): Flag for few-shot prompting. - - Returns: - None - """ + "Generation" set_seed(script_args.seed) # Load dataset (you can process it here) @@ -76,19 +21,29 @@ def generation(script_args): dataset_wrapper.dataset_testing.select(range(n_examples)) ) ds_exact_name = ( - script_args.dataset_name.split("/")[-1] + script_args.lang + "_" - + script_args.model_name.split("/")[-1] - + f"_pt{dataset_wrapper.prompting_strategy}" - + ("_fewshot" if script_args.fewshot_prompting else "") + + dataset_wrapper.dataset_info.task + + "_" + + script_args.dataset_name.split("/")[-1].replace("_", "-") + + "_" + + script_args.model_name.split("/")[-1].replace("_", "-") + + "_" + + script_args.prompt_type + + "_" + + script_args.category + + "_" + + str(script_args.num_fs_shot) + + "_pt" + dataset_wrapper.prompting_strategy + f"_seed{script_args.seed}" - ) +) + json_file = os.path.join( script_args.output_dir, f"generations_{ds_exact_name}.json" ) metric_file = os.path.join( - script_args.output_eval_dir, f"metrics_{ds_exact_name}.json" + script_args.output_eval_dir, f"{ds_exact_name}.json" ) # Save results diff --git a/src/melt/script_arguments.py b/src/melt/script_arguments.py index e1abfc0..d46a878 100644 --- a/src/melt/script_arguments.py +++ b/src/melt/script_arguments.py @@ -1,68 +1,11 @@ -""" -This module defines the `ScriptArguments` class used for configuring script parameters. - -The `ScriptArguments` class utilizes Python's `dataclass` to provide a -structured way to handle various configuration settings -needed for running the script. 
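# --- Illustrative sketch, not part of the patch above ---------------------------
# The ds_exact_name construction from generation(), pulled into a helper so the
# individual pieces are easier to see. The argument values below are invented;
# the field names mirror the script arguments used above.
def build_run_name(lang, task, dataset_name, model_name, prompt_type,
                   category, num_fs_shot, prompting_strategy, seed):
    """Compose the file stem used for generations_<name>.json and <name>.json."""
    dataset = dataset_name.split("/")[-1].replace("_", "-")
    model = model_name.split("/")[-1].replace("_", "-")
    return "_".join([
        lang, task, dataset, model, prompt_type, category,
        str(num_fs_shot), f"pt{prompting_strategy}", f"seed{seed}",
    ])


print(build_run_name("vi", "question-answering", "org/my_dataset",
                     "meta-llama/Llama-2-7b-chat-hf", "normal", "general",
                     5, 0, 42))
# -> vi_question-answering_my-dataset_Llama-2-7b-chat-hf_normal_general_5_pt0_seed42
# --------------------------------------------------------------------------------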
The fields within this -class include parameters for model and dataset configuration, -precision and quantization settings, output directories, and inference parameters. - -Class: - ScriptArguments: A data class that encapsulates various - configuration parameters for the script. - - -Attributes: - model_name (str): The model name to train or use, typically from the Hugging Face hub. - dataset_name (str): The dataset name to use for training or evaluation. - use_4bit (Optional[bool]): Whether to use 4-bit precision for model loading. - bnb_4bit_compute_dtype (Optional[str]): Data type for 4-bit model computation. - bnb_4bit_quant_type (Optional[str]): Quantization type (e.g., fp4 or nf4). - use_nested_quant (Optional[bool]): Whether to use nested quantization. - cpu_offload_gb (int): Amount of memory to offload to CPU. - lang (str): Language of the dataset (e.g., vi, ind, kr). - dataset_dir (str): Directory for loading datasets. - config_dir (str): Directory for configuration files. - output_dir (str): Directory for saving model predictions and checkpoints. - output_eval_dir (str): Directory for saving evaluation metrics. - per_device_eval_batch_size (Optional[int]): Batch size per GPU for evaluation. - dtype (str): Data type for model loading. - ms_hub_token (Optional[str]): Token for Microsoft Hub. - hf_hub_token (Optional[str]): Token for Hugging Face Hub. - smoke_test (Optional[bool]): Whether to run a smoke test on a small dataset. - fewshot_prompting (Optional[bool]): Whether to enable few-shot prompting. - num_fs (Optional[int]): Number of samples for few-shot learning. - seed (Optional[int]): Random seed for reproducibility. - continue_infer (Optional[bool]): Whether to continue a previous inference process. - wtype (str): Type of wrapper to use (e.g., hf, tgi, azuregpt, gemini). - ptemplate (Optional[str]): Prompting template to use (e.g., llama-2, mistral). - device (str): CUDA device to use. - n_bootstrap (int): Number of bootstrap samples. - p_bootstrap (float): Probability for bootstrap sampling. - bs (int): Bias metric. - -This class serves as a configuration container to manage and pass -parameters throughout the script efficiently. -""" - +"script" from dataclasses import dataclass, field -from typing import Optional -from typing import Dict +from typing import Optional, Dict, Union @dataclass class ModelConfig: """ Configuration class for model settings. - - Attributes: - model_name (str): The name of the model to train from the Hugging Face hub. - dataset_name (str): The instruction dataset to use. - lang (str): Language of the dataset (e.g., vi, ind, kr, ...). - dataset_dir (str): Default directory for loading datasets. - config_dir (str): Directory containing LLM template, - prompt template, and generation configuration. - output_dir (str): Directory for storing model predictions and checkpoints. - output_eval_dir (str): Directory for saving metric scores. """ model_name: str = field( default="meta-llama/Llama-2-7b-chat-hf", @@ -99,22 +42,7 @@ class ModelConfig: class BitsAndBytesConfig: """ Configuration class for bits and bytes parameters. - - This class contains settings related to the precision and quantization of - base models, including activation of 4-bit precision, compute data type, - quantization type, nested quantization, and CPU offloading settings. - - Attributes: - use_4bit (Optional[bool]): Whether to activate 4-bit precision base model loading. - bnb_4bit_compute_dtype (Optional[str]): Compute data - type for 4-bit base models (e.g., 'bfloat16'). 
- bnb_4bit_quant_type (Optional[str]): Quantization type - used for 4-bit models (e.g., 'fp4' or 'nf4'). - use_nested_quant (Optional[bool]): Whether to activate - nested quantization for 4-bit base models. - cpu_offload_gb (int): Amount of memory to offload to CPU, in gigabytes. """ - use_4bit: Optional[bool] = field( default=False, metadata={"help": "Activate 4-bit precision base model loading"} @@ -140,15 +68,6 @@ class BitsAndBytesConfig: class InferenceConfig: """ Configuration class for inference settings. - - Attributes: - tokens (Dict[str, Optional[str]]): Configuration for tokens - including Microsoft Hub and Hugging Face Hub tokens. - settings (Dict[str, Optional]): Inference settings including - smoke test, few-shot prompting, number of few-shot samples, - random seed, and whether to continue previous inference. - wrapper (Dict[str, str]): Wrapper configuration - including the type of wrapper and prompting template. """ tokens: Dict[str, Optional[str]] = field( default_factory=lambda: { @@ -157,7 +76,7 @@ class InferenceConfig: }, metadata={"help": "Token configuration"} ) - settings: Dict[str, Optional] = field( + settings: Dict[str, Union[bool, int]] = field( default_factory=lambda: { "smoke_test": False, "fewshot_prompting": False, @@ -175,22 +94,9 @@ class InferenceConfig: metadata={"help": "Wrapper configuration"} ) -def default_general_config(): +def default_general_config() -> Dict[str, Union[int, str]]: """ Returns a dictionary with default configuration values for general settings. - - This function provides default values for various configuration parameters - related to general settings, such as batch size, data type, device, and - other metrics. - - Returns: - dict: A dictionary containing default values for: - - per_device_eval_batch_size: The batch size per GPU for evaluation. - - dtype: The data type for model loading. - - device: The CUDA device to be used. - - n_bootstrap: The number of bootstrap iterations. - - p_bootstrap: The probability for bootstrap sampling. - - bs: Bias metric. """ return { "per_device_eval_batch_size": 1, @@ -205,15 +111,44 @@ def default_general_config(): class ScriptArguments: """ Configuration class for script arguments. - - Attributes: - model_config (ModelConfig): Configuration for model settings. - bits_and_bytes (BitsAndBytesConfig): Configuration for bits and bytes parameters. - inference_config (InferenceConfig): Configuration for inference settings. - general_config (Dict[str, Optional]): General configuration settings including - batch size, data type, device, and other metrics. 
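# --- Illustrative sketch, not part of the patch above ---------------------------
# How settings such as use_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type and
# use_nested_quant typically feed into transformers' quantization config when a
# model is loaded. This is an assumption about downstream usage, not code taken
# from this repository; transformers' BitsAndBytesConfig is a separate class from
# the local dataclass of the same name defined above. Requires the bitsandbytes
# package and a CUDA device.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # use_4bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # bnb_4bit_compute_dtype
    bnb_4bit_quant_type="nf4",              # bnb_4bit_quant_type
    bnb_4bit_use_double_quant=False,        # use_nested_quant
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    quantization_config=quant_config,
    device_map="auto",
)
# --------------------------------------------------------------------------------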
""" model_config: ModelConfig = field(default_factory=ModelConfig) bits_and_bytes: BitsAndBytesConfig = field(default_factory=BitsAndBytesConfig) inference_config: InferenceConfig = field(default_factory=InferenceConfig) - general_config: Dict[str, Optional] = field(default_factory=default_general_config) + general_config: Dict[str, Union[int, str, float]] = field( + default_factory=default_general_config + ) + + @property + def seed(self) -> int: + "seed" + return self.inference_config.settings['seed'] + @seed.setter + def seed(self, value: int): + "seed" + self.inference_config.settings['seed'] = value + + # Add methods to access nested attributes if needed + @property + def dataset_name(self) -> str: + "dataset" + return self.model_config.dataset_name + + @property + def lang(self) -> str: + "lang" + return self.model_config.lang + + # You can add similar properties for other nested attributes if needed + @property + def dataset_dir(self) -> str: + "dataset" + return self.model_config.dataset_dir + @property + def output_eval_dir(self) -> str: + "output" + return self.model_config.output_eval_dir + @property + def config_dir(self) -> str: + "config" + return self.model_config.config_dir diff --git a/src/melt/tools/data/__init__.py b/src/melt/tools/data/__init__.py index e8c4201..c9c16be 100644 --- a/src/melt/tools/data/__init__.py +++ b/src/melt/tools/data/__init__.py @@ -1,5 +1,5 @@ -"""Module providing a function printing python version.""" -from .dataset import DatasetWrapper +"init" +from melt.tools.data.dataset import DatasetWrapper __all__ = [ "DatasetWrapper", diff --git a/src/melt/tools/data/dataset.py b/src/melt/tools/data/dataset.py index 1dc16a7..1171b4f 100644 --- a/src/melt/tools/data/dataset.py +++ b/src/melt/tools/data/dataset.py @@ -1,34 +1,26 @@ -""" -This module provides the DatasetWrapper class for loading and managing datasets, -as well as generating prompts based on a configured strategy. -""" - +"WRAPPER" import os +import sys import json import ast -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple from argparse import Namespace -from .parser import get_dataset_list +from melt.tools.data.parser import get_dataset_list -def load_a_dataset(): +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +def load_a_dataset() -> Tuple[Any, Any]: """ Placeholder function for loading a dataset. - - Returns: - tuple: (training_data, testing_data) + Returns a tuple of (training_dataset, testing_dataset). """ - # Implement the actual dataset loading logic here - return None, None + # Implement dataset loading logic here + training_dataset = None # Replace with actual loading logic + testing_dataset = None # Replace with actual loading logic + return training_dataset, testing_dataset def eval_keys(keys: str | list[str]) -> callable: """ Returns a function that evaluates the provided keys in the dictionary. - - Args: - keys (str | list[str]): A key or list of keys to evaluate in the dictionary. - - Returns: - callable: A function to evaluate the keys in the dictionary. """ def eval_x(x: Dict[str, Any]) -> Dict[str, Any]: if isinstance(keys, str): @@ -37,20 +29,15 @@ def eval_x(x: Dict[str, Any]) -> Dict[str, Any]: for key in keys: x[key] = ast.literal_eval(x[key]) return x - return eval_x class DatasetWrapper: """ - A wrapper class for loading datasets, configuring them, and generating prompts - based on the prompting strategy. + A wrapper class for managing datasets and generating prompts. 
""" def __init__(self, args: Namespace) -> None: """ Initializes the DatasetWrapper with the provided arguments. - - Args: - args (Namespace): The arguments containing dataset name and configuration. """ self.args = args self.datasets: Dict[str, Optional[Any]] = { @@ -62,14 +49,9 @@ def __init__(self, args: Namespace) -> None: self.get_dataset_config() self.prompting_strategy: int = self.dataset_info['prompting_strategy'] self.get_prompt() - def get_prompt(self) -> None: """ - Loads the prompt template and calibration instructions based on the dataset - and prompting strategy. - - Raises: - ValueError: If the prompting strategy is not supported. + Get the prompt template and calibration instruction based on the prompting strategy. """ prompt_config_path = os.path.join( self.args.config_dir, self.args.lang, "prompt_template.json" @@ -78,16 +60,13 @@ def get_prompt(self) -> None: prompt_config = json.load(f) prompt_template = prompt_config["PROMPT_TEMPLATE"] calibration_instruction = prompt_config["CALIBRATION_INSTRUCTION"] - if self.prompting_strategy not in [0, 1, 2, 3]: raise ValueError("Prompting strategy is not supported") - task = self.dataset_info['task'] self.prompt = prompt_template[task][self.prompting_strategy] self.calibration_prompt = ( calibration_instruction.get(task, {}).get(self.prompting_strategy, None) ) - def get_dataset_config(self) -> None: """ Loads the dataset configuration and sets up the training and testing datasets. @@ -97,31 +76,24 @@ def get_dataset_config(self) -> None: dataset_dir=os.path.join(self.args.config_dir, self.args.lang), )[0] self.datasets['training'], self.datasets['testing'] = load_a_dataset() - def get_dataset_testing(self) -> Any: """ Returns the testing dataset if available. - - Raises: - ValueError: If the testing dataset is not available. - - Returns: - Any: The testing dataset. """ if self.datasets['testing'] is None: raise ValueError("Testing dataset is not available") return self.datasets['testing'] - def get_dataset_training(self) -> Any: """ Returns the training dataset if available. - + Raises: ValueError: If the training dataset is not available. - + Returns: Any: The training dataset. """ if self.datasets['training'] is None: raise ValueError("Training dataset is not available") return self.datasets['training'] + diff --git a/src/melt/tools/data/loader.py b/src/melt/tools/data/loader.py index 2e25509..fa4ccaf 100644 --- a/src/melt/tools/data/loader.py +++ b/src/melt/tools/data/loader.py @@ -1,130 +1,84 @@ -"""Module for loading datasets from various sources.""" - +"Loader" import os from pathlib import Path -from typing import Tuple, Any - -# Third-party imports -try: - from transformers.utils.versions import require_version -except ImportError: - require_version = None - -try: - from modelscope import MsDataset - from modelscope.utils.config_ds import MS_DATASETS_CACHE -except ImportError: - MsDataset = None - MS_DATASETS_CACHE = None - -try: - from datasets import load_dataset -except ImportError: - load_dataset = None - -# First-party imports -try: - from melt.utils.constants import FILEEXT2TYPE -except ImportError: - FILEEXT2TYPE = {} - -def _load_single_dataset(dataset_attr, args, mode) -> Tuple[Any, Any]: - """ - Load a single dataset based on the given attributes and mode. - - Args: - dataset_attr: Attributes of the dataset to load. - args: Arguments containing configuration options. - mode: The mode of the dataset (e.g., 'train', 'test'). - - Returns: - A tuple containing the loaded dataset and its attributes. 
+from transformers.utils.versions import require_version +from modelscope import MsDataset +from modelscope.utils.config_ds import MS_DATASETS_CACHE +from datasets import load_dataset +from melt.tools.utils.constants import FILEEXT2TYPE + +def load_a_dataset(dataset_attr, args): + """Load dataset for training and testing""" + dataset_training, _ = _load_single_dataset( + dataset_attr, args, dataset_attr.train_split + ) + dataset_testing, _ = _load_single_dataset( + dataset_attr, args, dataset_attr.test_split + ) + return dataset_training, dataset_testing - Raises: - NotImplementedError: If the load type is unknown. - ImportError: If required modules are not available. - """ +def _load_single_dataset(dataset_attr, args, mode): print(f"Loading {mode} dataset {dataset_attr}...") - - load_functions = { - "hf_hub": _load_from_hf_hub, - "ms_hub": _load_from_ms_hub, - "file": _load_from_file + load_config = _get_load_config(dataset_attr, args, mode) + if dataset_attr.load_from == "ms_hub": + dataset = _load_from_ms_hub(load_config, args, mode) + else: + dataset = _load_from_hf_hub(load_config, args, mode) + return dataset, dataset_attr +def _get_load_config(dataset_attr, args, mode): + config = { + "data_path": None, + "data_name": None, + "data_dir": None, + "data_files": None, } - - load_func = load_functions.get(dataset_attr.load_from) - if not load_func: + if dataset_attr.load_from in ["hf_hub", "ms_hub"]: + config["data_path"] = dataset_attr.dataset_name + config["data_name"] = dataset_attr.subset + config["data_dir"] = dataset_attr.folder + elif dataset_attr.load_from == "file": + config["data_files"], config["data_path"] = _get_file_config(dataset_attr, args, mode) + else: raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.") - - return load_func(dataset_attr, args, mode) - -def _load_from_hf_hub(dataset_attr, args, mode): - if load_dataset is None: - raise ImportError("The 'datasets' library is not installed.") - return load_dataset( - path=dataset_attr.dataset_name, - name=dataset_attr.subset, - data_dir=dataset_attr.folder, - split=mode, - token=args.hf_hub_token, - trust_remote_code=True, - ), dataset_attr - -def _load_from_ms_hub(dataset_attr, args, mode): - if MsDataset is None or MS_DATASETS_CACHE is None: - raise ImportError("ModelScope packages are not installed or not available.") - - if require_version is None: - raise ImportError("The 'transformers' library is not installed.") - - require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") - - dataset = MsDataset.load( - dataset_name=dataset_attr.dataset_name, - subset_name=dataset_attr.subset, - data_dir=dataset_attr.folder, - split=mode, - cache_dir=MS_DATASETS_CACHE, - token=args.ms_hub_token, - ) - - if isinstance(dataset, MsDataset): - dataset = dataset.to_hf_dataset() - - return dataset, dataset_attr - -def _load_from_file(dataset_attr, args, mode): + return config +def _get_file_config(dataset_attr, args, mode): local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name) if not os.path.isdir(local_path): raise ValueError(f"Directory {local_path} not found.") - data_files = {} data_path = None - for file_name in os.listdir(local_path): if Path(file_name).stem.split("_")[-1] == mode: data_files[mode] = os.path.join(local_path, file_name) - file_ext = file_name.split(".")[-1] - current_data_path = FILEEXT2TYPE.get(file_ext) - + file_type = FILEEXT2TYPE.get(file_name.split(".")[-1], None) if data_path is None: - data_path = current_data_path - elif data_path != 
current_data_path: + data_path = file_type + elif data_path != file_type: raise ValueError("File types should be identical.") - if not data_files: - raise ValueError("No appropriate file found.") - + raise ValueError("No matching files found.") if data_path is None: - raise ValueError(f"Allowed file types: {', '.join(FILEEXT2TYPE.keys())}.") - - if load_dataset is None: - raise ImportError("The 'datasets' library is not installed.") - + raise ValueError(f"Unable to determine file type for {local_path}.") + return data_files, data_path +def _load_from_ms_hub(config, args, mode): + require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + dataset = MsDataset.load( + dataset_name=config["data_path"], + subset_name=config["data_name"], + data_dir=config["data_dir"], + data_files=config["data_files"], + split=mode, + cache_dir=MS_DATASETS_CACHE, + token=args.ms_hub_token, + ) + return dataset.to_hf_dataset() if isinstance(dataset, MsDataset) else dataset +def _load_from_hf_hub(config, args, mode): return load_dataset( - path=data_path, - data_files=data_files, + path=config["data_path"], + name=config["data_name"], + data_dir=config["data_dir"], + data_files=config["data_files"], split=mode, token=args.hf_hub_token, trust_remote_code=True, - ), dataset_attr + ) diff --git a/src/melt/tools/data/parser.py b/src/melt/tools/data/parser.py index 26af8a1..27a3754 100644 --- a/src/melt/tools/data/parser.py +++ b/src/melt/tools/data/parser.py @@ -1,151 +1,86 @@ -""" -Module for parsing and managing dataset attributes and configurations. - -This module provides functionality to load dataset configurations from -a JSON file and manage attributes related to datasets. -""" - +"parser" import json import os from dataclasses import dataclass, field from typing import Any, Dict, List, Literal, Optional, Sequence - -# Assuming this is the correct import path, adjust if necessary -try: - from melt.utils.constants import DATA_CONFIG -except ImportError: - DATA_CONFIG = "data_config.json" # Fallback value - +from melt.tools.utils.constants import DATA_CONFIG @dataclass -class ColumnGroup: - """Group of related column attributes.""" - query: str = "input" - response: str = "output" - history: Optional[str] = None - context: str = "context" - -@dataclass -class ColumnAttributes: - """Attributes related to dataset columns.""" - primary: ColumnGroup = field(default_factory=ColumnGroup) - answer: str = "answer" - passages: str = "passages" - source: str = "source" - target: str = "target" - options: str = "options" - type_id: str = "type_id" - -@dataclass -class SplitAttributes: - """Attributes related to dataset splits.""" - train_split: str = "train" - test_split: str = "test" - +class SplitConfig: + "class" + train: str = "train" + test: str = "test" @dataclass class DatasetConfig: - """Configuration settings for the dataset.""" - task: Optional[str] = None - prompting_strategy: int = 0 + """Configuration for a dataset.""" subset: Optional[str] = None - label: Optional[List[Any]] = None - random: bool = False folder: Optional[str] = None - num_samples: Optional[int] = None - -@dataclass -class DatasetMeta: - """Metadata for managing and loading datasets.""" - config: DatasetConfig = field(default_factory=DatasetConfig) - columns: ColumnAttributes = field(default_factory=ColumnAttributes) - splits: SplitAttributes = field(default_factory=SplitAttributes) - + task: Optional[str] = None + label: Optional[List] = None + splits: SplitConfig = field(default_factory=SplitConfig) + 
prompting_strategy: int = 0 + sampling: Dict[str, Any] = field(default_factory=lambda: {"random": False, "num_samples": None}) @dataclass class DatasetAttr: - """Dataset attributes for managing and loading datasets.""" + """Dataset attributes.""" load_from: Literal["hf_hub", "ms_hub", "file"] dataset_name: str - meta: DatasetMeta = field(default_factory=DatasetMeta) - extra_attributes: Dict[str, Any] = field(default_factory=dict) - + config: DatasetConfig = field(default_factory=DatasetConfig) + columns: Dict[str, str] = field(default_factory=lambda: { + "query": "input", + "response": "output", + "history": None, + "context": "context", + "answer": "answer", + "passages": "passages", + "source": "source", + "target": "target", + "options": "options", + "type_id": "type_id" + }) def __repr__(self) -> str: return self.dataset_name - - def set_attr(self, key: str, obj: Dict[str, Any], default: Any = None) -> None: - """Set attribute value from a dictionary or use default.""" - if hasattr(self.meta, key): - setattr(self.meta, key, obj.get(key, default)) - else: - self.extra_attributes[key] = obj.get(key, default) - +def load_dataset_config(config_path: str) -> Dict[str, Any]: + "function" + try: + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError as err: + raise FileNotFoundError(f"Config file not found: {config_path}") from err + except json.JSONDecodeError as err: + raise ValueError(f"Invalid JSON in config file: {config_path}") from err +def create_dataset_attr(info: Dict[str, Any]) -> DatasetAttr: + "create" + if "ms_hub_url" in info or ("hf_hub_url" not in info and "file_name" not in info): + dataset_attr = DatasetAttr("ms_hub", dataset_name=info.get("ms_hub_url", "")) + elif "hf_hub_url" in info: + dataset_attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"]) + else: + dataset_attr = DatasetAttr("file", dataset_name=info["file_name"]) + config = dataset_attr.config + config.subset = info.get("subset") + config.folder = info.get("folder") + config.task = info.get("task") + config.label = info.get("label") + config.prompting_strategy = info.get("prompting_strategy", 0) + config.splits.train = info.get("train_split", "train") + config.splits.test = info.get("test_split", "test") + config.sampling["random"] = info.get("random", False) + config.sampling["num_samples"] = info.get("num_samples") + if "columns" in info: + for column in dataset_attr.columns: + dataset_attr.columns[column] = info["columns"].get(column, column) + return dataset_attr def get_dataset_list( dataset_names: Optional[Sequence[str]], dataset_dir: str ) -> List[DatasetAttr]: - """ - Get the attributes of the datasets. - - Args: - dataset_names: Sequence of dataset names to process. - dataset_dir: Directory containing the dataset configurations. - - Returns: - List of DatasetAttr objects. - - Raises: - ValueError: If the config file cannot be opened or a dataset is undefined. 
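# --- Illustrative sketch, not part of the patch above ---------------------------
# The assumed shape of one entry in the dataset config file read by
# get_dataset_list() and create_dataset_attr(). The dataset name, URL and column
# names are invented; the recognised keys (hf_hub_url / ms_hub_url / file_name,
# task, train_split, columns, ...) come from the parser code shown here.
example_info = {
    "hf_hub_url": "some-org/my-qa-dataset",  # alternatively "ms_hub_url" or "file_name"
    "subset": None,
    "task": "question-answering",
    "prompting_strategy": 0,
    "train_split": "train",
    "test_split": "test",
    "random": False,
    "num_samples": None,
    "columns": {"query": "question", "response": "answer", "context": "context"},
}

# Mirrors the source selection in create_dataset_attr():
if "ms_hub_url" in example_info or (
    "hf_hub_url" not in example_info and "file_name" not in example_info
):
    source = "ms_hub"
elif "hf_hub_url" in example_info:
    source = "hf_hub"
else:
    source = "file"
print(source, example_info["task"], example_info["columns"])
# --------------------------------------------------------------------------------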
- """ - dataset_names = dataset_names or [] + """Gets the attributes of the datasets.""" + if not dataset_names: + return [] config_path = os.path.join(dataset_dir, DATA_CONFIG) - - try: - with open(config_path, "r", encoding="utf-8") as f: - dataset_info = json.load(f) - except (IOError, json.JSONDecodeError) as err: - if dataset_names: - raise ValueError( - f"Cannot open or parse {config_path} due to {str(err)}" - ) from err - dataset_info = {} - - dataset_list: List[DatasetAttr] = [] + dataset_info = load_dataset_config(config_path) + dataset_list = [] for name in dataset_names: if name not in dataset_info: raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}") - - dataset_attr = create_dataset_attr(name, dataset_info[name]) - set_dataset_attributes(dataset_attr, dataset_info[name]) - dataset_list.append(dataset_attr) - + dataset_list.append(create_dataset_attr(dataset_info[name])) return dataset_list - -def create_dataset_attr(name: str, info: Dict[str, Any]) -> DatasetAttr: - """Create a DatasetAttr object based on the dataset information.""" - load_from = "ms_hub" if "ms_hub_url" in info or "hf_hub_url" not in info else "hf_hub" - dataset_name = info.get("ms_hub_url", info.get("hf_hub_url", name)) - return DatasetAttr(load_from=load_from, dataset_name=dataset_name) - -def set_dataset_attributes(dataset_attr: DatasetAttr, info: Dict[str, Any]) -> None: - """Set attributes for a DatasetAttr object.""" - config_attributes = [ - 'task', 'prompting_strategy', 'subset', 'label', 'random', - 'folder', 'num_samples' - ] - for attr in config_attributes: - dataset_attr.set_attr(attr, info, default=getattr(dataset_attr.meta.config, attr)) - - # Set column attributes if present - if "columns" in info: - for column in ColumnAttributes.__annotations__.keys(): - dataset_attr.set_attr( - column, - info["columns"], - default=getattr(dataset_attr.meta.columns, column) - ) - - # Set split attributes if present - if "splits" in info: - for split in SplitAttributes.__annotations__.keys(): - dataset_attr.set_attr( - split, - info["splits"], - default=getattr(dataset_attr.meta.splits, split) - ) diff --git a/src/melt/tools/metrics/base.py b/src/melt/tools/metrics/base.py index 7dfd1ec..10ce971 100644 --- a/src/melt/tools/metrics/base.py +++ b/src/melt/tools/metrics/base.py @@ -2,7 +2,7 @@ This module contains base classes for metrics processing. """ -from .post_process import get_answer_auto_from_text +from melt.tools.metrics.post_process import get_answer_auto_from_text class BaseMetric: """ @@ -12,10 +12,6 @@ class BaseMetric: def __init__(self, data=None, args=None): """ Initializes the BaseMetric with optional data and arguments. - - Args: - data (optional): Data related to the metric. Defaults to None. - args (optional): Arguments for processing. Defaults to None. """ self.data = data self.args = args @@ -23,15 +19,6 @@ def __init__(self, data=None, args=None): def _get_answer(self, text: str, args) -> str: """ Process a text and extract an answer based on certain arguments. - - Args: - text (str): A string containing the text from which the answer is \ - to be extracted. - args: Arguments containing 'key_answer', 'class_names', and other \ - parameters required for extraction. - - Returns: - str: The extracted answer. """ return get_answer_auto_from_text( text=text, @@ -43,17 +30,11 @@ def _get_answer(self, text: str, args) -> str: def set_data(self, data): """ Sets the data for the metric. - - Args: - data: The data to be set. 
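# --- Illustrative sketch, not part of the patch above ---------------------------
# A worked example of the token-overlap F1 computed by f1_score() in
# basic_metrics.py (further below): nltk's f_measure takes the *sets* of
# normalized tokens, so repeated tokens only count once. normalize_text here is
# a simplified stand-in for the project's helper.
from nltk.metrics.scores import f_measure


def normalize_text(text: str) -> str:  # simplified stand-in
    return text.lower().strip()


gold = "The cat sat on the mat"
pred = "the cat is on a mat"

gold_tokens = set(normalize_text(gold).split())  # {'the', 'cat', 'sat', 'on', 'mat'}
pred_tokens = set(normalize_text(pred).split())  # {'the', 'cat', 'is', 'on', 'a', 'mat'}

# precision = 4/6, recall = 4/5  ->  F1 = 2*p*r / (p + r) ~= 0.727
print(f_measure(gold_tokens, pred_tokens))
# --------------------------------------------------------------------------------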
""" self.data = data def get_data(self): """ Gets the data for the metric. - - Returns: - The current data. """ return self.data diff --git a/src/melt/tools/metrics/basic_metrics.py b/src/melt/tools/metrics/basic_metrics.py index 68abc42..ae02e15 100644 --- a/src/melt/tools/metrics/basic_metrics.py +++ b/src/melt/tools/metrics/basic_metrics.py @@ -1,19 +1,7 @@ -""" -This module provides basic metrics for evaluating text similarity and overlap. +"basic_metrics" +from nltk.metrics.scores import f_measure +from melt.tools.metrics.utils import normalize_text -It includes functions for exact match and F1 score calculations between -predicted text and gold standard text. -""" - -from .utils import normalize_text - -try: - from nltk.tokenize import word_tokenize - import nltk - nltk.download('punkt', quiet=True) -except ImportError as e: - print(f"Error importing NLTK: {e}") - # Handle the error or raise an exception def exact_match(gold: str, pred: str) -> float: """Calculates whether the predicted text (pred) @@ -31,10 +19,11 @@ def exact_match(gold: str, pred: str) -> float: if the normalized pred string exactly matches the normalized gold string, and 0.0 otherwise. """ - if not gold or not pred: - return 0.0 + if not pred: + return 0 + + return 1 if normalize_text(gold) == normalize_text(pred) else 0 - return 1.0 if normalize_text(gold) == normalize_text(pred) else 0.0 def f1_score(gold: str, pred: str) -> float: """Computes the F1 score for the overlap between @@ -50,20 +39,10 @@ def f1_score(gold: str, pred: str) -> float: float: The F1 score, ranging from 0.0 to 1.0, where 0.0 indicates no overlap and 1.0 indicates perfect overlap between gold and pred. """ - if not gold or not pred: + ret = f_measure( + set(normalize_text(gold).split()), set(normalize_text(pred).split()) + ) + if ret is None: # answer is the empty string after normalizing return 0.0 - gold_tokens = set(word_tokenize(normalize_text(gold))) - pred_tokens = set(word_tokenize(normalize_text(pred))) - - if not gold_tokens and not pred_tokens: - return 1.0 - - intersection = gold_tokens.intersection(pred_tokens) - if not intersection: - return 0.0 - precision = len(intersection) / len(pred_tokens) - recall = len(intersection) / len(gold_tokens) - if precision + recall == 0: - return 0.0 - return 2 * (precision * recall) / (precision + recall) + return ret diff --git a/src/melt/tools/metrics/bias.py b/src/melt/tools/metrics/bias.py index 305c2de..4479ff6 100644 --- a/src/melt/tools/metrics/bias.py +++ b/src/melt/tools/metrics/bias.py @@ -1,16 +1,11 @@ -""" -This module defines the BiasMetric class for evaluating biases in text data, -particularly with respect to demographic categories such as race and gender. 
-""" - +"Bias" import os -import string from collections import defaultdict from typing import Dict, List, Optional, Tuple - +import string import numpy as np -from .base import BaseMetric -from .name_detector import NameDetector +from melt.tools.metrics.base import BaseMetric +from melt.tools.metrics.name_detector import NameDetector RACE_CATEGORY = "race" GENDER_CATEGORY = "gender" @@ -36,84 +31,73 @@ GENDER_CATEGORY: GENDER_TO_WORD_LISTS, } - class BiasMetric(BaseMetric): """Evaluate biases in text data, particularly with demographic categories such as race and gender.""" def __init__(self, data: dict, args): - self.demographic_category = None - self.target_category = None + super().__init__(data, args) + self.demographic_category = "" + self.target_category = "" self.demographic_category_list = {} - self.target_category_word_list = {} + self.target_category_word_list = [] self._load_word_list(args) texts = [self._get_answer(pred, args) for pred in data["predictions"]] self.set_demographic_group_to_words(texts, args) - super().__init__(data, args) def _load_word_list(self, args): """Loads the word lists for the demographic and target categories.""" - with open( - os.path.join(args.config_dir, args.lang, "words", "female.txt"), - encoding="utf-8" - ) as f: - female_words = f.read().splitlines() - with open( - os.path.join(args.config_dir, args.lang, "words", "male.txt"), - encoding="utf-8" - ) as f: - male_words = f.read().splitlines() - with open( - os.path.join(args.config_dir, args.lang, "words", "adjective.txt"), - encoding="utf-8" - ) as f: - adjective_list = f.read().splitlines() - with open( - os.path.join(args.config_dir, args.lang, "words", "profession.txt"), - encoding="utf-8" - ) as f: - profession_list = f.read().splitlines() - GENDER_TO_WORD_LISTS["female"] = female_words - GENDER_TO_WORD_LISTS["male"] = male_words - TARGET_CATEGORY_TO_WORD_LIST["adjective"] = adjective_list - TARGET_CATEGORY_TO_WORD_LIST["profession"] = profession_list - + word_files = { + "female": "female.txt", + "male": "male.txt", + "adjective": "adjective.txt", + "profession": "profession.txt" + } + for category, filename in word_files.items(): + file_path = os.path.join(args.config_dir, args.lang, "words", filename) + with open(file_path, 'r', encoding='utf-8') as f: + words = f.read().splitlines() + if category in ["female", "male"]: + GENDER_TO_WORD_LISTS[category] = words + else: + TARGET_CATEGORY_TO_WORD_LIST[category] = words def set_demographic_group_to_words(self, texts: List[str], args): - """Sets demographic and target category attributes based on the provided arguments.""" - local_demographic_category_to_word_dict = { + """Sets demographic and target category + attributes based on the provided arguments.""" + demographic_category_to_word_dict = { RACE_CATEGORY: RACE_TO_NAME_LISTS, GENDER_CATEGORY: GENDER_TO_WORD_LISTS, } - race_to_name_lists = local_demographic_category_to_word_dict[RACE_CATEGORY] + race_to_name_lists = demographic_category_to_word_dict[RACE_CATEGORY] detector = NameDetector(args) names = detector.detect_batch(texts) - for group, name_list in names.items(): - race_to_name_lists[group].extend(name_list) - race_to_name_lists[group] = list(set(race_to_name_lists[group])) - del detector + for group, name_list in race_to_name_lists.items(): + name_list.extend(names[group]) + race_to_name_lists[group] = list(set(name_list)) def get_group_to_words(self, args): - """Sets the demographic and target category attributes based on the arguments passed.""" + """Sets the demographic and 
target category + attributes based on the arguments passed.""" self.demographic_category = args.demographic_category self.target_category = args.target_category - self.demographic_category_list = DEMOGRAPHIC_CATEGORY_TO_WORD_DICT[ args.demographic_category ] - self.target_category_word_list = TARGET_CATEGORY_TO_WORD_LIST[ args.target_category ] - def evaluate_stereotypical_associations(self, texts: List[str]): - """Computes a bias score for demographic representation within a list of texts.""" + """Computes a bias score for demographic representation + within a list of texts.""" demographic_groups = self.demographic_category_list.keys() target_words = self.target_category_word_list pair_to_count: Dict[Tuple[str, str], int] = defaultdict(int) for text in texts: - for group, group_words in self.demographic_category_list.items(): + for group in demographic_groups: + group_words = self.demographic_category_list[group] num_group_tokens = sum( - self.count_word_from_text(text, word) for word in group_words + self.count_word_from_text(text, word) + for word in group_words ) for target_word in target_words: num_target_tokens = self.count_word_from_text( @@ -127,44 +111,31 @@ def evaluate_stereotypical_associations(self, texts: List[str]): ) for t in target_words ] - bias_scores = [score for score in bias_scores if score is not None] - if not bias_scores: - return None - return np.array(bias_scores).mean() + return np.mean(bias_scores) if bias_scores else None def count_word_from_text(self, text: str, word: str): """Counts occurrences of a specific word in a given text.""" w = word.lower() t = text.lower() - if len(w.split(" ")) > 1: + if len(w.split()) > 1: return t.count(w) - return ( - t.translate(str.maketrans("", "", string.punctuation)) - .split(" ") - .count(w) - ) - - def evaluate_demographic_representation( - self, texts: List[str] - ) -> Optional[float]: + return t.translate(str.maketrans("", "", string.punctuation)).split().count(w) + + def evaluate_demographic_representation(self, texts: List[str]) -> Optional[float]: """Compute the score measuring the bias in demographic representation.""" demographic_groups = list(self.demographic_category_list.keys()) - group_to_count: Dict[str, int] = defaultdict(int) - for group in demographic_groups: - group_to_count[group] = 0 - group_words = self.demographic_category_list[group] + group_to_count = {group: 0 for group in demographic_groups} + for group, group_words in self.demographic_category_list.items(): for text in texts: text = text.lower() - for word in set(group_words): - c = self.count_word_from_text(text, word) - group_to_count[group] += c + group_to_count[group] += sum( + self.count_word_from_text(text, word) + for word in set(group_words) + ) counts = list(group_to_count.values()) - - bias_score = self.group_counts_to_bias(counts) - - return bias_score + return self.group_counts_to_bias(counts) def group_counts_to_bias(self, group_counts: List[int]) -> Optional[float]: """Compute bias score given group counts.""" @@ -172,20 +143,15 @@ def group_counts_to_bias(self, group_counts: List[int]) -> Optional[float]: len(self.demographic_category_list[group]) for group in self.demographic_category_list.keys() ] - normalized_counts: np.ndarray = ( - np.array(group_counts) / num_group_words - ) - + normalized_counts: np.ndarray = np.array(group_counts) / num_group_words normalized_counts_sum = np.sum(normalized_counts) if normalized_counts_sum == 0: return None - probability_distribution = normalized_counts / normalized_counts_sum 
uniform_probability = 1 / probability_distribution.size diff = uniform_probability - probability_distribution l1_distance = sum(np.abs(diff)) tv_distance = l1_distance / 2 - return tv_distance def get_bias_score(self, texts: List[str], args) -> Dict: @@ -197,13 +163,9 @@ def get_bias_score(self, texts: List[str], args) -> Dict: f"{self.demographic_category}_{self.target_category}_demographic": self.evaluate_demographic_representation, } - results = {} - for key, func in evaluation_funcs.items(): - results[key] = func(texts) + return {key: func(texts) for key, func in evaluation_funcs.items()} - return results - - def evaluate(self, data: dict, args) -> Dict: + def evaluate(self, data: dict, args) -> Tuple[dict, Dict]: """Main method for external calls to compute and return bias scores.""" result = {} texts = [self._get_answer(pred, args) for pred in data["predictions"]] @@ -212,7 +174,6 @@ def evaluate(self, data: dict, args) -> Dict: for target_category in ["profession"]: # adjective args.demographic_category = demographic_category args.target_category = target_category - bias_score = self.get_bias_score(texts, args) print(bias_score) result.update(bias_score) diff --git a/src/melt/tools/metrics/calibration_metric.py b/src/melt/tools/metrics/calibration_metric.py index d242570..a0b87eb 100644 --- a/src/melt/tools/metrics/calibration_metric.py +++ b/src/melt/tools/metrics/calibration_metric.py @@ -1,60 +1,48 @@ -"""Module for evaluating the calibration of probabilistic models.""" - - -from typing import Dict, List +"calibration_metric" +from typing import Dict, List, Any +import calibration as cal import numpy as np -try: - from melt.calibration import get_ece_em, get_ece, get_selective_stats, get_platt_scaler - print("Import successful") -except ImportError as e: - print(f"Import error: {e}") -from .utils import normalize_text -from .base import BaseMetric -from .post_process import softmax_options_prob - +from melt.tools.metrics.utils import normalize_text +from melt.tools.metrics.base import BaseMetric +from melt.tools.metrics.post_process import softmax_options_prob class CalibrationMetric(BaseMetric): - """Evaluate the calibration of probabilistic models.""" + """Evaluate the calibration of probabilistic models""" - - def get_cal_score(self, max_probs: List[float], correct: List[int]) -> Dict[str, float]: + def get_cal_score(self, max_probs: List[float], correct: List[int]): """Calculates various calibration scores based on the predicted probabilities (max_probs) and the ground truth labels (correct). - Args: max_probs (List[float]): A list of the maximum probabilities predicted by the model for each instance. - correct (List[int]): A binary list where each element corresponds to whether the prediction was correct (1) or not (0). - Returns: - Dict[str, float]: A dictionary containing ECE scores for 10 bins and 1 bin, + A dictionary containing ECE scores for 10 bins and 1 bin, coverage accuracy area, accuracy in the top 10 percentile, and Platt ECE scores for 10 bins and 1 bin. 
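# --- Illustrative sketch, not part of the patch above ---------------------------
# The arithmetic inside group_counts_to_bias() on made-up numbers: raw counts are
# normalized by the size of each group's word list, turned into a distribution,
# and the score is the total-variation distance to the uniform distribution
# (0 means perfectly balanced representation).
import numpy as np

group_counts = np.array([30, 10])           # mentions counted for two groups
group_word_list_sizes = np.array([60, 40])  # words/names per group list

normalized = group_counts / group_word_list_sizes        # [0.5, 0.25]
distribution = normalized / normalized.sum()             # [0.667, 0.333]
uniform = 1 / distribution.size                          # 0.5
tv_distance = np.abs(uniform - distribution).sum() / 2   # ~0.167
print(tv_distance)
# --------------------------------------------------------------------------------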
""" - max_probs_array = np.array(max_probs) - correct_array = np.array(correct) - - - ece_10_bin = get_ece_em(max_probs_array, correct_array, num_bins=10) - ece_1_bin = get_ece(max_probs_array, correct_array, num_bins=1) - coverage_acc_area, acc_top_10_percentile = get_selective_stats( - max_probs_array, correct_array + ece_10_bin = cal.get_ece_em(max_probs, correct, num_bins=10) + ece_1_bin = cal.get_ece(max_probs, correct, num_bins=1) + coverage_acc_area, acc_top_10_percentile = cal.get_selective_stats( + max_probs, correct ) - if np.sum(correct_array) == 0 or np.sum(correct_array) == len(correct_array): + if np.sum(correct) == 0 or np.sum(correct) == len(correct): platt_ece_10_bin = 0.0 platt_ece_1_bin = 0.0 else: - platt_scaler, _ = get_platt_scaler(max_probs_array, correct_array, get_clf=False) - cal_max_probs = platt_scaler(max_probs_array) - platt_ece_10_bin = get_ece_em(cal_max_probs, correct_array, num_bins=10) - platt_ece_1_bin = get_ece(cal_max_probs, correct_array, num_bins=1) - + platt_scaler, _ = cal.get_platt_scaler( + np.array(max_probs), np.array(correct), get_clf=True + ) + cal_max_probs = platt_scaler(np.array(max_probs)) + platt_ece_10_bin = cal.get_ece_em( + cal_max_probs, correct, num_bins=10 + ) + platt_ece_1_bin = cal.get_ece(cal_max_probs, correct, num_bins=1) return { "ece_10_bin": ece_10_bin, @@ -65,20 +53,18 @@ def get_cal_score(self, max_probs: List[float], correct: List[int]) -> Dict[str, "platt_ece_1_bin": platt_ece_1_bin, } - - def evaluate(self, data: Dict, args) -> (Dict, Dict): + def evaluate(self, data: Dict[str, Any], args: Any) -> tuple[Dict[str, Any], Dict[str, Any]]: """Evaluates the given predictions against the references in the dictionary. - Args: - data (Dict): A dictionary that must contain the keys + data (Dict[str, Any]): A dictionary that must contain the keys "predictions" and "references"; "option_probs" is also used if present. - + args (Any): Arguments passed to the evaluation function. Returns: - Tuple[Dict, Dict]: Returns a tuple of two dictionaries: + tuple[Dict[str, Any], Dict[str, Any]]: A tuple of two dictionaries: - The first dictionary is the updated data with additional key "max_probs". 
- The second dictionary result contains the mean of @@ -92,37 +78,31 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): ] references = data["references"] - accuracy = [ int(normalize_text(str(pred)) == normalize_text(str(ref))) for pred, ref in zip(predictions, references) ] - option_probs = data.get("option_probs", []) - if option_probs: - sum_option_probs = [ - [np.array(x).sum() for x in option_probs[i]] - for i in range(len(option_probs)) - ] - else: - sum_option_probs = [] - + sum_option_probs = [] + for i in range(len(data["option_probs"])): + sum_option_probs.append( + [np.array(x).sum() for x in data["option_probs"][i]] + ) if "gpt" in args.filepath: probs = softmax_options_prob(sum_option_probs) probs = np.zeros_like(probs) - labels = np.array([args.class_names.index(str(ref)) for ref in references]) - + labels = np.array( + [args.class_names.index(str(ref)) for ref in references] + ) for i, label in enumerate(labels): probs[i][label] = 1 else: probs = softmax_options_prob(sum_option_probs) - max_probs = np.max(probs, axis=1) data["max_probs"] = list(max_probs) result["max_probs"] = max_probs.mean() result.update(self.get_cal_score(max_probs, accuracy)) - return data, result diff --git a/src/melt/tools/metrics/data_stats_metric/__init__.py b/src/melt/tools/metrics/data_stats_metric/__init__.py index 3f160a3..6680d5b 100644 --- a/src/melt/tools/metrics/data_stats_metric/__init__.py +++ b/src/melt/tools/metrics/data_stats_metric/__init__.py @@ -1,4 +1,3 @@ -"""Module providing a function printing python version.""" -from .data_stats_metric import DataStatsMetric - +"init" +from melt.tools.metrics.data_stats_metric.data_stats_metric import DataStatsMetric __all__ = ["DataStatsMetric"] diff --git a/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py b/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py index 82f5af0..6118dde 100644 --- a/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py +++ b/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py @@ -1,142 +1,94 @@ -""" -This module provides the DataStatsMetric class for evaluating coverage, density, and compression -of summaries based on tokenized input text. -""" - +"data_stats_metric" +# pylint: disable=C0103,W0221,W0106,W0212 from collections import Counter from multiprocessing import Pool -import subprocess -import sys -import pkg_resources - -# Import statements -try: - import gin -except ImportError: - print("gin-config package is not installed.") - subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'gin-config']) - import gin +import gin +import spacy +from melt.tools.metrics.utils import Fragments -try: - import spacy - from spacy.cli import download -except ImportError: - print("spacy package is not installed.") - subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'spacy']) - import spacy - from spacy.cli import download - -from ..utils import Fragments - -# Ensure required packages are installed -def install_packages(): - """ - Check for and install required packages if they are missing. 
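# --- Illustrative sketch, not part of the patch above ---------------------------
# What the one-bin ECE reported by get_cal_score() measures: the gap between the
# model's average confidence and its accuracy. With a single bin the expected
# calibration error reduces to |mean(max_probs) - mean(correct)|; the library
# calls (cal.get_ece / cal.get_ece_em) generalize this to several bins. The
# numbers below are made up.
import numpy as np

max_probs = np.array([0.9, 0.8, 0.7, 0.6])  # confidence of each prediction
correct = np.array([1, 1, 0, 0])            # 1 if the prediction was right

one_bin_ece = abs(max_probs.mean() - correct.mean())
print(one_bin_ece)  # |0.75 - 0.5| = 0.25
# --------------------------------------------------------------------------------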
- """ - required_packages = ['gin-config', 'spacy'] - installed_packages = {pkg.key for pkg in pkg_resources.working_set} - missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages] - - if missing_packages: - subprocess.check_call([sys.executable, '-m', 'pip', 'install', *missing_packages]) - -install_packages() - -# Load spacy model try: _en = spacy.load("en_core_web_sm") except OSError: + print( + "Downloading the spacy en_core_web_sm model\n" + "(don't worry, this will only happen once)" + ) + from spacy.cli import download download("en_core_web_sm") _en = spacy.load("en_core_web_sm") - def find_ngrams(input_list, n): - """Return n-grams from input list.""" + "function" return zip(*[input_list[i:] for i in range(n)]) @gin.configurable class DataStatsMetric: - """Class for calculating data statistics on text.""" - + "class" def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True): self.n_gram = n_gram self.n_workers = n_workers self.case = case self.tokenize = tokenize - def evaluate_example(self, summary, input_text): - """Evaluate a single summary against input text.""" + "function" if self.tokenize: - input_text, summary = self.tokenize_text(input_text, summary) - + input_text = _en( + input_text, disable=["tagger", "parser", "ner", "textcat"] + ) + input_text = [tok.text for tok in input_text] + summary = _en( + summary, disable=["tagger", "parser", "ner", "textcat"] + ) + summary = [tok.text for tok in summary] fragments = Fragments(summary, input_text, case=self.case) - score_dict = self.calculate_scores(fragments) - - for i in range(1, self.n_gram + 1): - self.calculate_ngram_scores(fragments, i, score_dict) - - return score_dict - - def tokenize_text(self, input_text, summary): - """Tokenize the input text and summary.""" - input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"]) - input_text = [tok.text for tok in input_text] - summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"]) - summary = [tok.text for tok in summary] - return input_text, summary - - def calculate_scores(self, fragments): - """Calculate coverage, density, and compression scores.""" coverage = fragments.coverage() density = fragments.density() compression = fragments.compression() - tokenized_summary = fragments.get_summary() # Ensure Fragments has this method - return { + score_dict = { "coverage": coverage, "density": density, "compression": compression, - "summary_length": len(tokenized_summary), } - - def calculate_ngram_scores(self, fragments, n, score_dict): - """Calculate n-gram related scores.""" - tokenized_summary = fragments.get_summary() # Ensure Fragments has this method - tokenized_text = fragments.get_text() # Ensure Fragments has this method - - input_ngrams = list(find_ngrams(tokenized_text, n)) - summ_ngrams = list(find_ngrams(tokenized_summary, n)) + # pylint: disable=protected-access + tokenized_summary = fragments._norm_summary + tokenized_text = fragments._norm_text + # pylint: enable=protected-access + score_dict["summary_length"] = len(tokenized_summary) + for i in range(1, self.n_gram + 1): + self._compute_ngram_stats(tokenized_summary, tokenized_text, i, score_dict) + return score_dict + def _compute_ngram_stats(self, tokenized_summary, tokenized_text, i, score_dict): + input_ngrams = list(find_ngrams(tokenized_text, i)) + summ_ngrams = list(find_ngrams(tokenized_summary, i)) input_ngrams_set = set(input_ngrams) summ_ngrams_set = set(summ_ngrams) intersect = summ_ngrams_set.intersection(input_ngrams_set) - 
- if len(summ_ngrams_set) > 0: - score_dict[f"percentage_novel_{n}-gram"] = ( + try: + score_dict[f"percentage_novel_{i}-gram"] = ( len(summ_ngrams_set) - len(intersect) ) / float(len(summ_ngrams_set)) - ngram_counter = Counter(summ_ngrams) - repeated = [key for key, val in ngram_counter.items() if val > 1] - score_dict[f"percentage_repeated_{n}-gram_in_summ"] = ( - len(repeated) / float(len(summ_ngrams_set)) - ) - else: - score_dict[f"percentage_novel_{n}-gram"] = 0.0 - score_dict[f"percentage_repeated_{n}-gram_in_summ"] = 0.0 - + ngramCounter = Counter() + ngramCounter.update(summ_ngrams) + repeated = [ + key for key, val in ngramCounter.items() if val > 1 + ] + score_dict[f"percentage_repeated_{i}-gram_in_summ"] = len( + repeated + ) / float(len(summ_ngrams_set)) + except ZeroDivisionError: + pass def evaluate_batch(self, summaries, input_texts, aggregate=True): - """Evaluate multiple summaries against input texts.""" - corpus_score_dict = Counter() + "function" with Pool(processes=self.n_workers) as p: results = p.starmap(self.evaluate_example, zip(summaries, input_texts)) - if aggregate: + corpus_score_dict = Counter() for result in results: corpus_score_dict.update(result) - if len(input_texts) > 0: - for key in corpus_score_dict.keys(): - corpus_score_dict[key] /= float(len(input_texts)) - return corpus_score_dict + for key in corpus_score_dict.keys(): + corpus_score_dict[key] /= float(len(input_texts)) + return dict(corpus_score_dict) return results - @property def supports_multi_ref(self): - """Check if multiple references are supported.""" + "function" return False diff --git a/src/melt/tools/metrics/ir.py b/src/melt/tools/metrics/ir.py index ce229aa..e6f81e7 100644 --- a/src/melt/tools/metrics/ir.py +++ b/src/melt/tools/metrics/ir.py @@ -1,115 +1,125 @@ -"""Module for evaluating information retrieval systems.""" - +"ir" from typing import Dict, List import numpy as np -try: - from ranx import Qrels, Run, evaluate as ranx_evaluate -except ImportError as e: - raise ImportError( - "Failed to import 'ranx'. Ensure that 'ranx' is installed in your environment. " - "You can install it using 'pip install ranx'. Original error: " + str(e) - ) from e +from ranx import Qrels, Run, evaluate as ranx_evaluate +from melt.tools.metrics.base import BaseMetric -from .base import BaseMetric # Local import class InformationRetrievalMetric(BaseMetric): """Evaluate information retrieval systems.""" def _get_qrel(self, references: List[Dict]) -> Qrels: - """Processes a list of reference dictionaries to create a Qrels object. + """Processes a list of reference dictionaries to create + a Qrels object, which represents the relevance judgments + (i.e., which documents are relevant to which queries). Args: - references (List[Dict]): List of dictionaries with "id" and "references" keys. - - Returns: - Qrels: An object representing relevance judgments. + references (List[Dict]): A list of dictionaries, + each containing an "id" key representing the query ID + and a "references" key containing + a list of document IDs that are relevant to the query. 
""" relevant_dict = {} for reference in references: query_id = str(reference["id"]) - relevant_dict.setdefault(query_id, {}) + if query_id not in relevant_dict: + relevant_dict[query_id] = {} for doc_id in reference["references"]: relevant_dict[query_id][str(doc_id)] = 1 - return Qrels(relevant_dict) + qrels = Qrels(relevant_dict) + return qrels - def _get_prob_from_log_prob(self, score: float, is_positive_predict: bool) -> float: + def _get_prob_from_log_prob( + self, + score: float, + is_positive_predict: bool, + ) -> float: """Converts a log probability score into a regular probability. Args: score (float): The log probability score. - is_positive_predict (bool): Whether the prediction is positive. + + is_positive_predict (bool): A boolean indicating whether + the prediction is positive. Returns: - float: Adjusted probability. + float: If the prediction is not positive, the probability + is adjusted by subtracting it from 1. """ prob = np.exp(score) - return prob if is_positive_predict else 1 - prob + prob = 1 - prob if not is_positive_predict else prob + return prob def _get_run(self, predictions: List[Dict], k: int, args) -> Run: - """Processes predictions to create a Run object. + """Processes a list of prediction dictionaries to create + a Run object, which represents the system's ranked + list of documents for each query. Args: - predictions (List[Dict]): List of dictionaries with "query_id", "prediction", - and "calib_probs" keys. - k (int): Number of top documents to consider. - args: Additional arguments. + predictions (List[Dict]): A list of dictionaries, + each containing a "query_id", "prediction", and "calib_probs". - Returns: - Run: An object representing the ranked list of documents. + k (int): An integer representing the number of + top documents to consider for each query. """ run_dict = {} for prediction in predictions: query_id = str(prediction["query_id"]) - run_dict.setdefault(query_id, {}) + if query_id not in run_dict: + run_dict[query_id] = {} predict = self._get_answer(prediction["prediction"], args) is_positive_predict = predict == "yes" - try: log_prob = ( - prediction["calib_probs"][0][0][0] + prediction["calib_probs"][0][0][0] if is_positive_predict else prediction["calib_probs"][1][0][0] ) except (IndexError, KeyError): log_prob = 0 - prob = self._get_prob_from_log_prob(log_prob, is_positive_predict) if len(run_dict[query_id]) < k: run_dict[query_id][str(prediction["passage_id"])] = prob - return Run(run_dict) + run = Run(run_dict) + return run def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict): - """Evaluates predictions and computes various metrics. + """Evaluates the predictions using relevance judgments + and computes various metrics. Args: - data (Dict): Dictionary with predictions to be evaluated. - args: Additional arguments. - **kwargs: Additional keyword arguments including "ref_dataset". - - Returns: - Tuple[Dict, Dict]: Updated data with metrics results. + data (Dict): A dictionary containing predictions to be evaluated. 
""" result = {} - references = kwargs.get("ref_dataset", []) - if not references: - raise ValueError("Reference dataset is missing in kwargs") + refenreces = kwargs["ref_dataset"] + predictions = data["predictions"] - predictions = data.get("predictions", []) - qrels = self._get_qrel(references) + qrels = self._get_qrel(refenreces) for mode in ["regular", "boosted"]: - k = 30 if mode == "regular" else 9999 + if mode == "regular": + k = 30 + else: + k = 9999 run = self._get_run(predictions, k, args) - - for metric in [ - "recall@10", "precision@10", "hit_rate@10", "mrr@10", "ndcg@10" - ]: - result[f"{mode}_{metric}"] = ranx_evaluate( - qrels, run, metric, make_comparable=True - ) - print(result) + result[f"{mode}_recall@10"] = ranx_evaluate( + qrels, run, "recall@10", make_comparable=True + ) + result[f"{mode}_precision@10"] = ranx_evaluate( + qrels, run, "precision@10", make_comparable=True + ) + result[f"{mode}_hit_rate@10"] = ranx_evaluate( + qrels, run, "hit_rate@10", make_comparable=True + ) + result[f"{mode}_mrr@10"] = ranx_evaluate( + qrels, run, "mrr@10", make_comparable=True + ) + result[f"{mode}_ndcg@10"] = ranx_evaluate( + qrels, run, "ndcg@10", make_comparable=True + ) + print(result) return data, result diff --git a/src/melt/tools/metrics/language.py b/src/melt/tools/metrics/language.py index 6f38703..d6b675b 100644 --- a/src/melt/tools/metrics/language.py +++ b/src/melt/tools/metrics/language.py @@ -1,110 +1,76 @@ -"""This module defines metrics for evaluating language generation tasks.""" - -from typing import Dict, List +"language" +from typing import Dict, List, Tuple import math import numpy as np - -# Attempt to import third-party libraries -try: - import evaluate -except ImportError as e: - raise ImportError("The 'evaluate' package is required but could not be imported. " - "Please install it using 'pip install evaluate'.") from e - -try: - import Levenshtein -except ImportError as e: - raise ImportError("The 'Levenshtein' package is required but could not be imported. " - "Please install it using 'pip install python-Levenshtein'.") from e - -from .base import BaseMetric -from .basic_metrics import exact_match -from .utils import normalize_text - +import evaluate +import Levenshtein +from melt.tools.metrics.base import BaseMetric +from melt.tools.metrics.basic_metrics import exact_match +from melt.tools.metrics.utils import normalize_text class LanguageMetric(BaseMetric): """Evaluate language generation tasks.""" def __init__(self, data, args) -> None: - """Initialize the metric with data and arguments.""" self.cer_metrics = evaluate.load("cer") self.wer_metrics = evaluate.load("wer") super().__init__(data, args) def get_num_bytes(self, tokens: List[str]) -> int: - """Calculate the total number of bytes of a list of tokens + """Calculates the total number of bytes of a list of tokens when encoded in UTF-8. Args: tokens (List[str]): A list of string tokens for which the byte length is to be calculated. - - Returns: - int: Total number of bytes. 
""" return sum(len(bytes(token, encoding="utf-8")) for token in tokens) - def _compute_perplexity(self, prediction: str, generation_prob: List[float]) -> tuple: - """Compute perplexity for a given prediction and generation probabilities.""" - logprob = np.array(generation_prob).sum() - num_perplexity_tokens = len(generation_prob) - num_bytes = self.get_num_bytes(prediction.split(" ")) - perplexity = math.e ** (-logprob / num_perplexity_tokens) - bits_per_byte = -logprob / num_bytes / math.log(2) - logprob_per_byte = logprob / num_bytes - return perplexity, bits_per_byte, logprob_per_byte - - def evaluate(self, data: Dict, args) -> tuple: - """Evaluate predictions against references and compute various metrics. - - Args: - data (Dict): A dictionary that must contain keys - "predictions", "references", and "generation_probs". - - Returns: - Tuple[Dict, Dict]: Updated data dictionary with raw metric scores - and a result dictionary with average scores. - """ + def compute_edit_distances(self, predictions: List[str], + references: List[str]) -> Tuple[List[int], List[int]]: + """Compute Character Edit Distance (CED) and Word Edit Distance (WED)""" + ced_scores = [Levenshtein.distance(pred, ref) for pred, ref in zip(predictions, references)] + wed_scores = [Levenshtein.distance(pred.split(), ref.split()) + for pred, ref in zip(predictions, references)] + return ced_scores, wed_scores + + def compute_perplexity_metrics( + self, predictions: List[str], + generation_probs: List[List[float]]) ->Tuple[List[float], List[float], List[float]]: + """Compute perplexity, bits per byte, and log probability per byte""" + perplexity_scores, bits_per_byte, logprob_per_byte = [], [], [] + for prediction, generation_prob in zip(predictions, generation_probs): + logprob = np.array(generation_prob).sum() + num_perplexity_tokens = len(generation_prob) + num_bytes = self.get_num_bytes(prediction.split()) + + perplexity_scores.append(math.e ** (-logprob / num_perplexity_tokens)) + bits_per_byte.append(-logprob / num_bytes / math.log(2)) + logprob_per_byte.append(logprob / num_bytes) + + return perplexity_scores, bits_per_byte, logprob_per_byte + + def evaluate(self, data: Dict, args) -> Tuple[Dict, Dict]: + """Evaluates the predictions against references and + computes various metrics.""" predictions = [self._get_answer(pred, args) for pred in data["predictions"]] references = [normalize_text(ref) for ref in data["references"]] - em_scores = [ - exact_match(pred, ref) - for ref, pred in zip(references, predictions) - ] - cer_score = self.cer_metrics.compute( - predictions=predictions, references=references - ) - wer_score = self.wer_metrics.compute( - predictions=predictions, references=references - ) - - ced_scores = [ - Levenshtein.distance(pred, ref) - for pred, ref in zip(predictions, references) - ] - wed_scores = [ - Levenshtein.distance( - np.array(pred.split(" ")), np.array(ref.split(" ")) - ) - for pred, ref in zip(predictions, references) - ] - - perplexity_scores, bits_per_byte, logprob_per_byte = zip( - *[self._compute_perplexity(pred, gen_prob) - for pred, gen_prob in zip(data["predictions"], data["generation_probs"])] - ) - - data.update( - { - "average_exact_match": em_scores, - "ced": ced_scores, - "wed": wed_scores, - "perplexity": perplexity_scores, - "bits_per_byte": bits_per_byte, - "logprob_per_byte": logprob_per_byte, - } - ) + em_scores = [exact_match(pred, ref) for ref, pred in zip(references, predictions)] + cer_score = self.cer_metrics.compute(predictions=predictions, references=references) 
+ wer_score = self.wer_metrics.compute(predictions=predictions, references=references) + + ced_scores, wed_scores = self.compute_edit_distances(predictions, references) + perplexity_scores, bits_per_byte, logprob_per_byte = ( + self.compute_perplexity_metrics(data["predictions"], data["generation_probs"])) + data.update({ + "average_exact_match": em_scores, + "ced": ced_scores, + "wed": wed_scores, + "perplexity": perplexity_scores, + "bits_per_byte": bits_per_byte, + "logprob_per_byte": logprob_per_byte, + }) result = { "average_exact_match": np.mean(em_scores), "cer": cer_score, @@ -115,5 +81,4 @@ def evaluate(self, data: Dict, args) -> tuple: "bits_per_byte": np.mean(bits_per_byte), "logprob_per_byte": np.mean(logprob_per_byte), } - return data, result diff --git a/src/melt/tools/metrics/name_detector.py b/src/melt/tools/metrics/name_detector.py index 1ee59c7..b8b6339 100644 --- a/src/melt/tools/metrics/name_detector.py +++ b/src/melt/tools/metrics/name_detector.py @@ -1,43 +1,33 @@ -""" -This module provides functionality for detecting names in text using natural -language processing techniques. -""" +"name_detector" import os import re +from transformers import ( + AutoTokenizer, + AutoModelForTokenClassification, + pipeline, +) +from underthesea import sent_tokenize import torch +import spacy -try: - from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline -except ImportError: - print("The 'transformers' library is not installed. Please pip install transformers'.") - -try: - from underthesea import sent_tokenize -except ImportError: - print("The 'underthesea' library is not installed. Please'pip install underthesea'.") - -try: - import spacy -except ImportError: - print("The 'spacy' library is not installed. Please 'pip install spacy'.") - -# Load the core English NLP library +# load core english library nlp = spacy.load("en_core_web_sm") - class NameDetector: """Detect names within texts, categorize them, and potentially process multiple texts in batches.""" + token_pattern = "" # Renamed from TOKEN_PATTERN to token_pattern + def __init__(self, args): - # Use an instance variable instead of a global variable with open( - os.path.join(args.config_dir, args.lang, "words", "token_pattern.txt"), + os.path.join( + args.config_dir, args.lang, "words", "token_pattern.txt" + ), "r", - encoding="utf-8", # Specify the encoding explicitly + encoding="utf-8" ) as f: - self.token_pattern = f.read().strip() # Store in instance variable - + self.token_pattern = f.read().strip() # Updated attribute name here as well tokenizer = AutoTokenizer.from_pretrained( args.metric_config["NERModel"], ) @@ -56,7 +46,19 @@ def __init__(self, args): self.threshold_len = 2 def group_entity(self, text, entities): - """Groups adjacent detected entities belonging to the same entity group.""" + """Groups the detected entities that are adjacent and + belong to the same entity group. + + Args: + text (str): The original text from which entities are extracted. + + entities (list): A list of entity dictionaries + detected in the text. + + Returns: + Returns a new list of entities after grouping + adjacent entities of the same type. 
+ """ if len(entities) == 0: return [] new_entity = entities[0] @@ -67,8 +69,12 @@ def group_entity(self, text, entities): and new_entity["entity_group"] == entities[i]["entity_group"] ): new_entity["end"] = entities[i]["end"] - new_entity["word"] = text[new_entity["start"] : new_entity["end"]] - new_entity["score"] = max(new_entity["score"], entities[i]["score"]) + new_entity["word"] = text[ + new_entity["start"]:new_entity["end"] + ] + new_entity["score"] = max( + new_entity["score"], entities[i]["score"] + ) else: new_entities.append(new_entity) new_entity = entities[i] @@ -77,7 +83,8 @@ def group_entity(self, text, entities): return new_entities def _get_person_tokens(self, all_tokens): - """Filters and retrieves person tokens from detected entities.""" + """Filters and retrieves tokens classified as persons + from the detected entities.""" per_tokens = [] temp = [ entity @@ -90,13 +97,22 @@ def _get_person_tokens(self, all_tokens): return per_tokens def _classify_race(self, per_tokens): - """Classifies names into Vietnamese or Western categories.""" + """Classifies the person tokens into Vietnamese or Western based on + a predefined pattern. + + Args: + per_tokens (list): A list of person name tokens to be classified. + + Returns: + Returns a dictionary with two keys, "vietnamese" and "western", + each containing a list of names classified. + """ results = { "your_race": set(), "western": set(), } for token in per_tokens: - if re.search(self.token_pattern, token) is None: # Use instance variable + if re.search(self.token_pattern, token) is None: # Updated usage here results["western"].add(token) else: results["your_race"].add(token) @@ -106,8 +122,16 @@ def _classify_race(self, per_tokens): return results def detect(self, text): - """Detects and classifies names in a single text.""" + """Detects and classifies names in a single text string. + + Args: + text (str): The input text to process. + + Returns: + Returns a dictionary with classified names. + """ sentences = sent_tokenize(text) + print(len(sentences)) sentences = [ " ".join(sentence.split(" ")[: self.max_words_sentence]) for sentence in sentences @@ -123,13 +147,19 @@ def detect(self, text): return names def detect_batch(self, texts): - """Detects and classifies names in a batch of text strings.""" - all_entities = [] + """Detects and classifies names in a batch of text strings. + + Args: + texts (list): A list of text strings to process in batch. + + Returns: + Returns a dictionary with classified names for the batch. + """ sentences = [] for text in texts: doc = nlp(text) - sentences = [sent.text for sent in doc.sents] + sentences.extend([sent.text for sent in doc.sents]) sentences = [ " ".join(sentence.split(" ")[: self.max_words_sentence]) @@ -137,6 +167,7 @@ def detect_batch(self, texts): ] entities_lst = self.token_classifier(sentences, batch_size=128) + all_entities = [] for sentence, entities in zip(sentences, entities_lst): all_entities += self.group_entity(sentence, entities) diff --git a/src/melt/tools/metrics/post_process.py b/src/melt/tools/metrics/post_process.py index c88e79c..12b8ee8 100644 --- a/src/melt/tools/metrics/post_process.py +++ b/src/melt/tools/metrics/post_process.py @@ -1,71 +1,50 @@ -""" -This module provides functions for processing and extracting information from text. 
-""" -import ast +"post_process" import re -from types import SimpleNamespace from typing import Dict, List -import numpy as np +import ast +from types import SimpleNamespace +import regex from scipy.special import softmax -from .utils import normalize_text - -try: - import regex -except ImportError: - print("The 'regex' library is not installed. Please install it using 'pip install regex'.") - +import numpy as np +from melt.tools.metrics.utils import normalize_text def get_json_from_text(text: str) -> Dict: - """Extracts JSON-like objects from text.""" + "function" pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}") json_objects = pattern.findall(text) - try: - if json_objects: - processed_text = json_objects[0].replace("\n", "\\n") - json_object_done = ast.literal_eval(processed_text) - else: - json_object_done = {} - except (SyntaxError, ValueError) as e: - print(f"Error processing JSON: {e}") - json_object_done = {} - return json_object_done - - + processed_text = json_objects[0].replace("\n", "\\n") + json_object_result = ast.literal_eval(rf"{processed_text}") + except (IndexError, SyntaxError, ValueError): + json_object_result = {} + return json_object_result def get_class_name_from_text(text: str, class_names: List[str]) -> str: - """Finds the class name from the text that matches the provided class names.""" + "function" text = normalize_text(text) - class_names = [normalize_text(name) for name in class_names] + class_names = [normalize_text(str(name)) for name in class_names] matches = [ re.search(rf"\b(?:{class_name})\b", text) for class_name in class_names ] indexes = [match.start() if match else np.inf for match in matches] - return ( - class_names[np.array(indexes).argmin()] + str(class_names[np.array(indexes).argmin()]) if min(np.array(indexes)) < np.inf else "none" ) - - -def softmax_options_prob(options_prob: List) -> np.ndarray: - """Applies softmax to options probabilities.""" +def softmax_options_prob(options_prob: List): + "function" options_prob = np.array(options_prob).reshape(len(options_prob), -1) return softmax(options_prob, axis=1) - - def remove_special_character(text: str) -> str: - """Removes non-alphanumeric characters from the text.""" + "function" return "".join(letter for letter in text if letter.isalnum()) - - def get_answer_auto_from_text( text: str, key_answer: str = None, class_names: List[str] = None, args=SimpleNamespace(), ) -> str: - """Extracts and processes an answer from the text based on the provided arguments.""" + "function" if key_answer: json_data = get_json_from_text(text) if ( @@ -78,7 +57,6 @@ def get_answer_auto_from_text( text = str(json_data[key_answer]) if class_names: text = get_class_name_from_text(text, class_names) - if "math" not in args.filepath: text = text.split("\n\n")[0] text = normalize_text(text, keep_punc="keep_punc") diff --git a/src/melt/tools/metrics/question_answering.py b/src/melt/tools/metrics/question_answering.py index 8286468..2175162 100644 --- a/src/melt/tools/metrics/question_answering.py +++ b/src/melt/tools/metrics/question_answering.py @@ -1,15 +1,9 @@ -""" -This module contains the QAMetric class, which evaluates the performance -of a question-answering (QA) system by calculating F1 scores and exact match scores -between predictions and references. -The QAMetric class inherits from the BaseMetric class and implements the -evaluate method to compute these metrics. 
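`get_json_from_text` relies on the `regex` module's recursive pattern to grab the outermost balanced-brace span and on `ast.literal_eval` to parse it without calling `eval`. A small standalone demonstration (the sample text is invented):

    import ast
    import regex

    text = 'Answer: {"answer": "yes", "confidence": 0.9} -- end of model output'
    pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")

    candidates = pattern.findall(text)        # ['{"answer": "yes", "confidence": 0.9}']
    parsed = ast.literal_eval(candidates[0])  # safe literal parsing into a dict
    print(parsed["answer"])                   # -> yes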
-""" +"question_answering" from typing import Dict import numpy as np -from .basic_metrics import exact_match, f1_score -from .base import BaseMetric -from .utils import normalize_text +from melt.tools.metrics.basic_metrics import exact_match, f1_score +from melt.tools.metrics.base import BaseMetric +from melt.tools.metrics.utils import normalize_text class QAMetric(BaseMetric): diff --git a/src/melt/tools/metrics/reasoning.py b/src/melt/tools/metrics/reasoning.py index 6168ba3..23e2914 100644 --- a/src/melt/tools/metrics/reasoning.py +++ b/src/melt/tools/metrics/reasoning.py @@ -1,17 +1,10 @@ -""" -This module contains the ReasoningMetric class, which evaluates the performance -of a reasoning task by calculating F1 scores, exact match scores, and equality scores -between predictions and references. It includes functions to handle mathematical -expressions and formatting. - -The ReasoningMetric class inherits from the BaseMetric class and implements the -evaluate method to compute these metrics. -""" - +"reasoning" from typing import Dict +import random +import string as string_func import numpy as np -from .basic_metrics import exact_match, f1_score -from .base import BaseMetric +from melt.tools.metrics.basic_metrics import exact_match, f1_score +from melt.tools.metrics.base import BaseMetric escape_dict = { "\a": r"\a", @@ -23,17 +16,7 @@ "\v": r"\v", } - -def _fix_fracs(string: str) -> str: - """ - Fixes fractions in the given string by ensuring proper formatting. - - Args: - string (str): The input string potentially containing fractions. - - Returns: - str: The formatted string with corrected fractions. - """ +def _fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] if len(substrs) > 1: @@ -43,9 +26,7 @@ def _fix_fracs(string: str) -> str: if substr[0] == "{": new_str += substr else: - try: - assert len(substr) >= 2 - except AssertionError: + if len(substr) < 2: return string a = substr[0] b = substr[1] @@ -63,56 +44,27 @@ def _fix_fracs(string: str) -> str: new_str += f"{{{a}}}{b}" return new_str - -def _fix_a_slash_b(string: str) -> str: - """ - Converts a simple fraction in the form of 'a/b' into LaTeX format. - - Args: - string (str): The input string potentially containing a fraction. - - Returns: - str: The LaTeX formatted fraction. - """ +def _fix_a_slash_b(string): if len(string.split("/")) != 2: return string a, b = string.split("/") try: a = int(a) b = int(b) - assert string == f"{a}/{b}" - return f"\\frac{{{a}}}{{{b}}}" - except (ValueError, AssertionError): - return string - - -def _remove_right_units(string: str) -> str: - """ - Removes units from the right side of the string. - - Args: - string (str): The input string potentially containing units. + if string == f"{a}/{b}": + return f"\\frac{{{a}}}{{{b}}}" + except (ValueError, TypeError): + pass + return string - Returns: - str: The string with units removed. - """ +def _remove_right_units(string): if "\\text{ " in string: splits = string.split("\\text{ ") - assert len(splits) == 2 - return splits[0] + if len(splits) == 2: + return splits[0] return string - -def _fix_sqrt(string: str) -> str: - """ - Fixes square roots in the given string by ensuring proper formatting. - - Args: - string (str): The input string potentially containing square roots. - - Returns: - str: The formatted string with corrected square roots. 
- """ +def _fix_sqrt(string): if "\\sqrt" not in string: return string splits = string.split("\\sqrt") @@ -126,151 +78,106 @@ def _fix_sqrt(string: str) -> str: new_string += new_substr return new_string - -def _strip_string(string: str) -> str: - """ - Cleans and formats the input string by removing unnecessary characters and formatting. - - Args: - string (str): The input string to be cleaned. - - Returns: - str: The cleaned and formatted string. - """ - # Line breaks +def _strip_string(string): + # ... (rest of the function remains the same) + # linebreaks string = string.replace("\n", "") + # print(string) - # Remove inverse spaces + # remove inverse spaces string = string.replace("\\!", "") + # print(string) - # Replace \\ with \ + # replace \\ with \ string = string.replace("\\\\", "\\") + # print(string) - # Replace tfrac and dfrac with frac + # replace tfrac and dfrac with frac string = string.replace("tfrac", "frac") string = string.replace("dfrac", "frac") + # print(string) - # Remove \left and \right + # remove \left and \right string = string.replace("\\left", "") string = string.replace("\\right", "") + # print(string) # Remove circ (degrees) string = string.replace("^{\\circ}", "") string = string.replace("^\\circ", "") - # Remove dollar signs + # remove dollar signs string = string.replace("\\$", "") - # Remove units (on the right) + # remove units (on the right) string = _remove_right_units(string) - # Remove percentage + # remove percentage string = string.replace("\\%", "") string = string.replace(r"\%", "") - # " 0." equivalent to " ." and "{0." equivalent to "{." + # " 0." equivalent to " ." and "{0." equivalent to + # "{." Alternatively, add "0" if "." is the start of the string string = string.replace(" .", " 0.") string = string.replace("{.", "{0.") + # if empty, return empty string if len(string) == 0: return string if string[0] == ".": - string = f"0{string}" + string = "0" + string - # Remove "X = " at beginning + # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split("=")) == 2: if len(string.split("=")[0]) <= 2: string = string.split("=")[1] - # Fix sqrt3 --> sqrt{3} + # fix sqrt3 --> sqrt{3} string = _fix_sqrt(string) - # Remove spaces + # remove spaces string = string.replace(" ", "") - # Fix fractions + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with + # \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = _fix_fracs(string) - # Change 0.5 --> \frac{1}{2} + # manually change 0.5 --> \frac{1}{2} if string == "0.5": string = "\\frac{1}{2}" - # Fix simple fractions + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix + # in case the model output is X/Y string = _fix_a_slash_b(string) - return string - - -def is_equiv(str1: str, str2: str, verbose=False) -> bool: - """ - Checks if two strings are equivalent after formatting. - - Args: - str1 (str): The first string to compare. - str2 (str): The second string to compare. - verbose (bool): If True, prints the formatted strings. - - Returns: - bool: True if the strings are equivalent, False otherwise. 
- """ +def is_equiv(str1, str2, verbose=False): + "function" if str1 is None and str2 is None: print("WARNING: Both None") return True if str1 is None or str2 is None: return False - try: ss1 = _strip_string(str1) ss2 = _strip_string(str2) if verbose: print(ss1, ss2) return ss1 == ss2 - except ValueError: + except (ValueError, TypeError, AttributeError): return str1 == str2 - class ReasoningMetric(BaseMetric): - """Metric for evaluating reasoning tasks, including mathematical expressions.""" - - def equal(self, prediction: str, reference: str) -> float: - """ - Checks if a prediction is equal to the reference. - - Args: - prediction (str): The predicted string. - reference (str): The reference string. - - Returns: - float: 1 if equal, 0 otherwise. - """ - if prediction == reference: - return 1 - return 0 + "class" + def equal(self, prediction: str, refenrence: str) -> float: + "equal" + return 1 if prediction == refenrence else 0 - def _has_numbers(self, word: str) -> bool: - """ - Checks if a word contains any digits. - - Args: - word (str): The word to check. - - Returns: - bool: True if the word contains digits, False otherwise. - """ + def _has_numbers(self, word: str): return any(char.isdigit() for char in word) def _clean_word(self, word: str) -> str: - """ - Cleans a word by removing special characters and unnecessary symbols. - - Args: - word (str): The word to clean. - - Returns: - str: The cleaned word. - """ word = word.replace("$", "").split("=")[-1] word = word.replace("'", "") - while len(word) > 0 and word[-1] != "}" and not word[-1].isdigit(): + while len(word) > 0 and word[-1] != "}" and (not word[-1].isdigit()): word = word[:-1] if "{" not in word: word = word.replace("}", "") @@ -278,33 +185,24 @@ def _clean_word(self, word: str) -> str: return word def _get_math_final_result(self, text: str) -> str: - """ - Extracts the final result from mathematical expressions in the text. - - Args: - text (str): The input text containing a mathematical expression. - - Returns: - str: The final result extracted from the text. - """ text = text.replace("\f", "\\f") text = text.replace("\b", "\\b") words = text.split(" ")[::-1] for i, _ in enumerate(words): words[i] = self._clean_word(words[i]) - text = " ".join(words[::-1]) - return text + for word in words: + if "boxed" in word: + return word - def _remove_boxed(self, text: str) -> str: - """ - Removes boxed notation from the text. + for word in words: + if self._has_numbers(word): + return word - Args: - text (str): The input text containing boxed notation. + return "".join( + random.choice(string_func.ascii_uppercase) for _ in range(4) + ) - Returns: - str: The text with boxed notation removed. - """ + def _remove_boxed(self, text: str) -> str: if "oxed" in text: text = text.replace(r'"\boxed{', "") text = text.replace(r"\boxed{", "") @@ -319,18 +217,7 @@ def _remove_boxed(self, text: str) -> str: return text def evaluate(self, data: Dict, args) -> (Dict, Dict): - """ - Evaluates the predictions against references and calculates metrics. - - Args: - data (Dict): A dictionary containing 'predictions' and 'references'. - args: Additional arguments required for evaluation. - - Returns: - Tuple[Dict, Dict]: A tuple where the first element is the updated data - dictionary with added scores, and the second element is a dictionary - containing the F1 score, exact match score, and equality score. 
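All of the helpers above feed `_strip_string`, whose purpose is that differently formatted answers compare equal. Assuming the package is importable as `melt`, the expected behaviour looks roughly like this (inputs are illustrative):

    from melt.tools.metrics.reasoning import is_equiv

    assert is_equiv("\\frac{3}{4}", "3/4")    # a/b is rewritten as \frac{a}{b}
    assert is_equiv("\\sqrt3", "\\sqrt{3}")   # sqrt arguments get braces
    assert is_equiv("0.5", "\\frac{1}{2}")    # hard-coded special case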
- """ + "evaluate" result = {} raw_predictions = data["predictions"] @@ -338,17 +225,15 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): self._get_answer(raw_prediction, args) for raw_prediction in raw_predictions ] - references = data["references"] references = [ self._get_answer(reference, args) - for reference in references + for reference in data["references"] ] f1_scores = [ - f1_score(reference, prediction) for reference,prediction in zip(references, predictions) + f1_score(*batch) for batch in zip(references, predictions) ] - ems=[exact_match(reference,prediction)for - reference,prediction in zip(references,predictions)] + ems = [exact_match(*batch) for batch in zip(references, predictions)] if args.task == "math": predictions = [ @@ -369,8 +254,8 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): data["processed_references"] = references equals = [ - is_equiv(prediction, reference) - for prediction, reference in zip(predictions, references) + is_equiv(prediction, refenrence) + for prediction, refenrence in zip(predictions, references) ] data["equals"] = equals if "fewshot" in data: diff --git a/src/melt/tools/metrics/summary.py b/src/melt/tools/metrics/summary.py index 034b26d..ca78bfc 100644 --- a/src/melt/tools/metrics/summary.py +++ b/src/melt/tools/metrics/summary.py @@ -1,21 +1,16 @@ -""" -This module provides utilities for working with dictionaries. - -Functions: -- function_name: Description of the function's purpose. -""" -import warnings +"summary" from typing import Dict +import warnings from bert_score import BERTScorer import torch import evaluate import numpy as np -from .summac.model_summac import SummaCZS -from .data_stats_metric import DataStatsMetric -from .base import BaseMetric -from .utils import normalize_text - +from melt.tools.metrics.summac.model_summac import SummaCZS +from melt.tools.metrics.data_stats_metric import DataStatsMetric +from melt.tools.metrics.base import BaseMetric +from melt.tools.metrics.utils import normalize_text +warnings.filterwarnings("ignore") class SummaryMetric(BaseMetric): """Evaluate the quality of text summaries.""" @@ -23,8 +18,6 @@ class SummaryMetric(BaseMetric): def __init__(self, data, args): super().__init__(data, args) - warnings.filterwarnings("ignore") - self.rouge = evaluate.load("rouge") self.bert_scorer = BERTScorer( model_type=args.metric_config["BERTScoreModel"]["model_type"], @@ -47,15 +40,14 @@ def __init__(self, data, args): def evaluate(self, data: Dict, args) -> (Dict, Dict): """Evaluates the generated summaries against reference summaries and - computes various metrics to assess \ - the quality of the generated summaries. + computes various metrics to assess the quality of the generated summaries. Args: - data (Dict): A dictionary expected to contain \ + data (Dict): A dictionary expected to contain original_documents, predictions, and references as keys. Returns: - Returns a tuple containing the original data dictionary and \ + Returns a tuple containing the original data dictionary and the result dictionary with all the computed metrics. 
""" inputs = data["original_documents"] @@ -102,9 +94,3 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): ) ) return data, result - def calculate_score(self, summary): - """Calculate the score for the given summary.""" - # Implementation here - def report(self): - """Generate a report based on the calculated scores.""" - # Implementation here diff --git a/src/melt/tools/metrics/text_classification.py b/src/melt/tools/metrics/text_classification.py index 9e87358..9d5bd34 100644 --- a/src/melt/tools/metrics/text_classification.py +++ b/src/melt/tools/metrics/text_classification.py @@ -1,90 +1,57 @@ -"""Module for evaluating text classification models.""" - -from typing import Dict, Tuple +"test_classification" +from typing import Dict import numpy as np +import evaluate from sklearn.metrics import ( f1_score as f1_score_sklearn, accuracy_score, roc_auc_score, ) -from .utils import normalize_text -from .post_process import softmax_options_prob -from .base import BaseMetric - +from melt.tools.metrics.utils import normalize_text +from melt.tools.metrics.post_process import softmax_options_prob +from melt.tools.metrics.base import BaseMetric class TextClassificationMetric(BaseMetric): """Evaluate text classification models.""" - def __init__(self, data, args): super().__init__(data, args) - # Ensure 'evaluate' is correctly installed and used, or remove if not needed - self.roc_auc_score = None # Remove if not used - self.data =data - - def evaluate(self, data: Dict, args) -> Tuple[Dict, Dict]: + self.roc_auc_score = evaluate.load("roc_auc", "multiclass") + def evaluate(self, data: Dict, args) -> tuple[Dict, Dict]: """Evaluates the classification performance given the predictions, references, and additional arguments. - Args: data (Dict): A dictionary expected to contain keys like predictions, references, and option_probs. - - args: Additional arguments including class_names. - Returns: - Tuple[Dict, Dict]: The original data dictionary and + Returns a tuple containing the original data dictionary and the result dictionary with all the computed metrics. 
""" result = {} - raw_predictions = data["predictions"] args.class_names = [normalize_text(str(name)) for name in args.class_names] - predictions = [ - str(self._get_answer(raw_prediction, args)) - for raw_prediction in raw_predictions - ] - references = self._normalize_references(data["references"], args) - + predictions = [str(self._get_answer(raw_prediction, args)) + for raw_prediction in data["predictions"]] + references = self._process_references(data["references"], predictions) result["accuracy"] = accuracy_score(references, predictions) - result["f1_score"] = f1_score_sklearn( - references, predictions, average="macro" - ) - - sum_option_probs = [ - [np.array(x).sum() for x in probs] - for probs in data["option_probs"] - ] - + result["f1_score"] = f1_score_sklearn(references, predictions, average="macro") + sum_option_probs = [[np.array(x).sum() for x in option_prob] + for option_prob in data["option_probs"]] probs = softmax_options_prob(sum_option_probs) if len(args.class_names) == 2: probs = probs[:, 1].reshape(-1, 1) - labels = np.array([ - args.class_names.index(ref) for ref in references - ]) - + labels = np.array([args.class_names.index(ref) for ref in references]) try: - result["roc_auc"] = roc_auc_score( - labels, probs, multi_class="ovr", average="macro" - ) - except (ValueError, TypeError, IndexError) as e: - print(f"Error calculating ROC AUC: {e}") + result["roc_auc"] = roc_auc_score(labels, probs, multi_class="ovr", average="macro") + except ValueError as e: + print(f"ROC AUC calculation failed: {e}") result["roc_auc"] = None return data, result - def reset_data(self, new_data): - """Resets the data with new data.""" - self.data = new_data - def _normalize_references(self, references, args): - """Helper function to normalize references.""" - - normalized_references = [] - for reference in references: + def _process_references(self, references, predictions): + processed_references = [] + for reference, prediction in zip(references, predictions): if isinstance(reference, list): reference = [normalize_text(str(ref)) for ref in reference] - first_ref = str(normalize_text(reference[0])) - answer = self._get_answer(reference, args) - if answer in reference: - normalized_references.append(first_ref) - else: - normalized_references.append(str(reference[0])) + processed_references.append(str(normalize_text(prediction) + if prediction in reference else reference[0])) else: - normalized_references.append(normalize_text(str(reference))) - return list(normalized_references) + processed_references.append(normalize_text(str(reference))) + return processed_references diff --git a/src/melt/tools/metrics/toxicity.py b/src/melt/tools/metrics/toxicity.py index 64f09b0..e38c178 100644 --- a/src/melt/tools/metrics/toxicity.py +++ b/src/melt/tools/metrics/toxicity.py @@ -1,20 +1,14 @@ -""" -This module provides the ToxicityMetric class to evaluate text for toxicity -using a pre-trained classification model. -""" - +"toxicity" from typing import Dict -import numpy as np from transformers import pipeline -from .base import BaseMetric +import numpy as np +from melt.tools.metrics.base import BaseMetric + class ToxicityMetric(BaseMetric): """Evaluate text for toxicity.""" def __init__(self, data, args): - """ - Initializes the ToxicityMetric with a text classification pipeline for toxicity evaluation. 
- """ self.classifier = pipeline( task="text-classification", return_all_scores=True, @@ -56,10 +50,14 @@ def evaluate(self, data: Dict, args): toxicity_scores = self._get_toxicity_score(toxicity_predictions) data["toxicity"] = toxicity_scores + # for i, s in enumerate(toxicity_scores): + # if s > 0.5: + # print('========================================') + # print(i) + # print(s, data["predictions"][i]) + # print(s, data["original_documents"][i]) + # print('========================================') + return data, { "toxicity": np.array(toxicity_scores).mean(), } - - def get_classifier(self): - """Returns the classifier used for toxicity evaluation.""" - return self.classifier diff --git a/src/melt/tools/metrics/translation_metric.py b/src/melt/tools/metrics/translation_metric.py index 40c3a9d..fcc083c 100644 --- a/src/melt/tools/metrics/translation_metric.py +++ b/src/melt/tools/metrics/translation_metric.py @@ -1,8 +1,9 @@ +"translation" +from typing import Dict import evaluate -from .base import BaseMetric from hlepor import hlepor_score -from .utils import normalize_text -from typing import Dict +from melt.tools.metrics.base import BaseMetric +from melt.tools.metrics.utils import normalize_text class TranslationMetric(BaseMetric): diff --git a/src/melt/tools/metrics/utils.py b/src/melt/tools/metrics/utils.py index f0f068f..154076f 100644 --- a/src/melt/tools/metrics/utils.py +++ b/src/melt/tools/metrics/utils.py @@ -1,58 +1,33 @@ -""" -This module provides utilities for text normalization and -fragments matching, particularly for summarization tasks. -""" +"utils" from collections import namedtuple as _namedtuple - - def normalize_text(text: str, keep_punc=False) -> str: """Lower text and remove punctuation, articles and extra whitespace. Copied from the [QuAC](http://quac.ai/) evaluation script found at https://s3.amazonaws.com/my89public/quac/scorer.py""" - def white_space_fix(text: str) -> str: return " ".join(text.split()) - def remove_punc(text: str) -> str: exclude = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" return "".join(ch for ch in text if ch not in exclude) - def lower(text: str) -> str: return text.lower() - if keep_punc: text = white_space_fix(lower(text)) else: text = white_space_fix(remove_punc(lower(text))) - if len(text) == 0: text = "." - return text - - def normalize(tokens, case=False): """ - Lowercases and turns tokens into distinct words. - """ - return [str(t).lower() if not case else str(t) for t in tokens] - - class Fragments: - """ - A class to compute and analyze matches between summary - and reference text, including coverage, density, - and compression metrics. - """ + "class" Match = _namedtuple("Match", ("summary", "text", "length")) - def __init__(self, summary, text, case=False): - # self._tokens = tokenize - if isinstance(summary, str): self.summary = summary.split() else: @@ -61,29 +36,20 @@ def __init__(self, summary, text, case=False): self.text = text.split() else: self.text = text - self._norm_summary = normalize(self.summary, case) self._norm_text = normalize(self.text, case) - self._match(self._norm_summary, self._norm_text) - def overlaps(self): """ - Return a list of Fragments.Match objects between summary and text. 
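`normalize_text` is the small workhorse used throughout these metrics: lowercase, optionally strip punctuation, collapse whitespace, and fall back to "." for empty strings. For example:

    from melt.tools.metrics.utils import normalize_text

    print(normalize_text("  Hello,   World! "))                         # -> "hello world"
    print(normalize_text("  Hello,   World! ", keep_punc="keep_punc"))  # -> "hello, world!"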
This is a list of named tuples of the form (summary, text, length): - - summary (int): the start index of the match in the summary - text (int): the start index of the match in the reference - length (int): the length of the extractive fragment - """ - return self._matches - def strings(self, min_length=0, summary_base=True): """ - Return a list of explicit match strings between the summary and reference. Note that this will be in the same format as the strings are input. @@ -91,34 +57,24 @@ def strings(self, min_length=0, summary_base=True): If tokenization is specified automatically on the raw strings, raw strings will automaticallybe returned rather than SpaCy tokenized sequences. - Arguments: - - min_length (int): filter out overlaps shorter than this (default = 0) - raw (bool): return raw input rather than stringified - (default = False if automatic tokenization, True otherwise) - summary_base (true): strings are based of summary text \ (default = True) - Returns: - - list of overlaps, where overlaps are strings or token sequences - """ - # Compute the strings against the summary or the text? - base = self.summary if summary_base else self.text - # Generate strings, filtering out strings below the minimum length. - strings = [ base[i:i + length] for i, j, length in self.overlaps() if length > min_length ] - # By default, we just return the tokenization being used. # But if they user wants a raw string, then we convert. # Mostly, this will be used along with spacy. @@ -129,141 +85,83 @@ def strings(self, min_length=0, summary_base=True): # strings[i] = str(s) # Return the list of strings. - return strings - def coverage(self, summary_base=True): """ Return the COVERAGE score of the summary and text. - Arguments: - - summary_base (bool): use summary as numerator (default = True) - Returns: - - decimal COVERAGE score within [0, 1] """ - numerator = sum(o.length for o in self.overlaps()) - - if summary_base: - denominator = len(self.summary) - else: - denominator = len(self.text) - + denominator = len(self.summary) if summary_base else len(self.text) if denominator == 0: return 0 return numerator / denominator def density(self, summary_base=True): """ - Return the DENSITY score of summary and text. - Arguments: - - summary_base (bool): use summary as numerator (default = True) - Returns: - - decimal DENSITY score within [0, ...] - """ - numerator = sum(o.length**2 for o in self.overlaps()) - - if summary_base: - denominator = len(self.summary) - else: - denominator = len(self.text) - + denominator = len(self.summary) if summary_base else len(self.text) if denominator == 0: return 0 return numerator / denominator def compression(self, text_to_summary=True): """ - Return compression ratio between summary and text. - Arguments: - - text_to_summary (bool): compute text/summary ratio\ (default = True) - Returns: - - decimal compression score within [0, ...] - """ - ratio = [len(self.text), len(self.summary)] - try: - if text_to_summary: return ratio[0] / ratio[1] return ratio[1] / ratio[0] - except ZeroDivisionError: - return 0 - def _match(self, a, b): """ - Raw procedure for matching summary in text, described in paper. 
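Putting the coverage, density, and compression definitions together, a short usage sketch of `Fragments` on a toy summary/article pair (again assuming the package is importable as `melt`):

    from melt.tools.metrics.utils import Fragments

    summary = "the cat sat on the mat"
    text = "yesterday the cat sat on the mat in the hall"

    frags = Fragments(summary, text)
    print(frags.coverage())     # 1.0   -> every summary token lies in an extractive fragment
    print(frags.density())      # 6.0   -> one fragment of length 6 over 6 summary tokens
    print(frags.compression())  # ~1.67 -> 10 article tokens / 6 summary tokens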
- """ - self._matches = [] - a_start = b_start = 0 - while a_start < len(a): - best_match = None best_match_length = 0 - while b_start < len(b): - if a[a_start] == b[b_start]: - a_end = a_start b_end = b_start - while ( a_end < len(a) and b_end < len(b) and b[b_end] == a[a_end] ): - b_end += 1 a_end += 1 - length = a_end - a_start - if length > best_match_length: best_match = Fragments.Match(a_start, b_start, length) best_match_length = length - b_start = b_end - else: - b_start += 1 - b_start = 0 - if best_match: - if best_match_length > 0: self._matches.append(best_match) - a_start += best_match_length - else: - a_start += 1 diff --git a/src/melt/tools/pipelines/__information_retrieval.py b/src/melt/tools/pipelines/__information_retrieval.py new file mode 100644 index 0000000..614b6ce --- /dev/null +++ b/src/melt/tools/pipelines/__information_retrieval.py @@ -0,0 +1,227 @@ +"information retrieval" +import random +from tqdm import tqdm +from melt.tools.utils.utils import column, format_fewshot + +def __information_retrieval( + self, ds_wrapper, ds_loader, saving_fn, start_idx=0 +): + predictions = [] + idx = 0 + original_few_shot = [] + calib_few_shot = [] + selected_sample = [] + if self.few_shot: + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.passages], + rec[ds_wrapper.dataset_info.query], + rec[ds_wrapper.dataset_info.answer], + ] + + random_sample = list( + random.sample(list(ds_wrapper.dataset_training), 1) + )[0] + first_sample = { + "passages": random_sample["positive"], + "query": random_sample[ds_wrapper.dataset_info.query], + "references": ds_wrapper.dataset_info.label[0], + } + second_sample = { + "passages": random_sample["negative"], + "query": random_sample[ds_wrapper.dataset_info.query], + "references": ds_wrapper.dataset_info.label[1], + } + + selected_sample = [ + preprocessing_a_record(s) + for s in [first_sample, second_sample] + ] + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + batch_passage_size = 10 + # Create few-shot strings + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + for query_with_a_batch_passages in range( + len(batch[ds_wrapper.dataset_info.type_id]) + ): + query_id = batch[ds_wrapper.dataset_info.type_id][ + query_with_a_batch_passages + ] + query = batch[ds_wrapper.dataset_info.query][ + query_with_a_batch_passages + ] + try: + ref_passage_id = batch[ds_wrapper.dataset_info.answer][0][ + query_with_a_batch_passages + ] + except IndexError: + if len(list(batch[ds_wrapper.dataset_info.answer])) < 1: + continue + ref_passage_id = list( + batch[ds_wrapper.dataset_info.answer][0] + )[query_with_a_batch_passages] + batch_passages = batch[ds_wrapper.dataset_info.passages] + + top30_passage_ids = column( + batch_passages["id"], query_with_a_batch_passages + ) + top30_passages = column( + batch_passages["passage"], query_with_a_batch_passages + ) + for psg in range( + 0, len(top30_passage_ids), batch_passage_size + ): + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + p, + query, + ), + }, + ] + for p in top30_passages[psg:psg + batch_passage_size] + ] + calib_prompts = [ + [ + { + "role": 
"system", + "content": ds_wrapper.calibration_prompt[ + "system_prompt" + ], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt[ + "prompt" + ].format( + p, + query, + ), + }, + ] + for p in top30_passages[psg:psg + batch_passage_size] + ] + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + + option_logprobs, _ = ( + self.infer_pipeline.compute_logprob_and_length( + calib_prompts * len(ds_wrapper.dataset_info.label), + [ + choice + for choice in ds_wrapper.dataset_info.label + for _ in range(len(prompts)) + ], + ) + ) + # Use a separate function to avoid cell-var-from-loop warnings + def create_prompt_dict(data): + return { + "query_id": ( + data['query_id'].item() + if not isinstance(data['query_id'], str) + else data['query_id'] + ), + "query": data['query'], + "passage_id": ( + data['passage_id'].item() if not isinstance( + data['passage_id'], str) else data['passage_id'] + ), + "passage": data['passage'], + "label": int( + data['passage_id'].item() == data['ref_passage_id'] + if not isinstance(data['passage_id'], str) + else data['passage_id'] == data['ref_passage_id'] + ), + "prediction": data['prediction'], + "generation_probs": data['generation_probs'], + "calib_probs": [ + data['option_logprobs'][data['q'] + opt * len(data['prompts'])] + for opt in range( + len(ds_wrapper.dataset_info.label) + ) + ], + } + save_each_prompt = [ + create_prompt_dict({ + 'prediction': x, + 'generation_probs': y, + 'passage_id': z, + 'passage': t, + 'q': q, + 'query_id': query_id, + 'query': query, + 'ref_passage_id': ref_passage_id, + 'option_logprobs': option_logprobs, + 'prompts': prompts + }) + for x, y, z, t, q in zip( + results, + logprobs, + top30_passage_ids[psg:psg + batch_passage_size], + top30_passages[psg:psg + batch_passage_size], + range(len(prompts)) + ) + ] + predictions.extend(save_each_prompt) + + idx += 1 + + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "fewshot": selected_sample, + "predictions": predictions, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ref_dataset=ds_wrapper.dataset_testing, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = {"fewshot": selected_sample, "predictions": predictions} + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ref_dataset=ds_wrapper.dataset_testing, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ref_dataset=ds_wrapper.dataset_testing, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__language_modeling.py b/src/melt/tools/pipelines/__language_modeling.py new file mode 100644 index 0000000..96c7e9e --- /dev/null +++ b/src/melt/tools/pipelines/__language_modeling.py @@ -0,0 +1,115 @@ +"language modeling" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot +def __language_modeling( +self, ds_wrapper, ds_loader, saving_fn, start_idx=0 +): + predictions = [] + references = [] + generation_probs = [] + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + 
references.extend(self.continue_infer_data["references"])
+        generation_probs.extend(
+            self.continue_infer_data["generation_probs"]
+        )
+    idx = 0
+    original_few_shot = []
+    selected_sample = []
+    if self.few_shot:
+
+        def preprocessing_a_record(rec):
+            return [
+                rec[ds_wrapper.dataset_info.source],
+                rec[ds_wrapper.dataset_info.target],
+            ]
+
+        selected_sample_idx = list(
+            random.sample(
+                range(len(ds_wrapper.dataset_training)), self.config.num_fs
+            )
+        )
+        selected_sample = [
+            preprocessing_a_record(ds_wrapper.dataset_training[s])
+            for s in selected_sample_idx
+        ]
+        original_few_shot = format_fewshot(
+            selected_sample,
+            query_format=ds_wrapper.prompt["prompt"],
+            answer_format=ds_wrapper.prompt["answer_format"],
+        )
+
+    # Create few-shot strings
+    for batch in tqdm(ds_loader):
+        if idx < start_idx:
+            idx += 1
+            continue
+
+        prompts = [
+            [
+                {
+                    "role": "system",
+                    "content": ds_wrapper.prompt["system_prompt"],
+                },
+                *original_few_shot,
+                {
+                    "role": "user",
+                    "content": ds_wrapper.prompt["prompt"].format(
+                        c,
+                    ),
+                },
+            ]
+            for c in batch[ds_wrapper.dataset_info.source]
+        ]
+
+        results, logprobs, _ = self.infer_pipeline(
+            prompts, return_probs=True
+        )
+        predictions.extend(results)
+        references.extend(
+            list(batch[ds_wrapper.dataset_info.target])
+        )
+        generation_probs.extend(logprobs)
+
+        idx += 1
+        if idx % 100 == 0:
+            print(f"Saving results of {idx} batches")
+            generations = {
+                "predictions": predictions,
+                "references": references,
+                "generation_probs": generation_probs,
+                "fewshot": selected_sample,
+            }
+            saving_fn(generations)
+            mean_result = self.metric_pipeline.run_mean(
+                generations,
+                self.task_name,
+                ds_wrapper.prompt["answer_key"],
+                ds_wrapper.dataset_info.label,
+                self.config,
+            )
+            print(f"Results of {idx} batches: ", mean_result)
+
+    generations = {
+        "predictions": predictions,
+        "references": references,
+        "generation_probs": generation_probs,
+        "fewshot": selected_sample,
+    }
+    mean_result = self.metric_pipeline.run_mean(
+        generations,
+        self.task_name,
+        ds_wrapper.prompt["answer_key"],
+        ds_wrapper.dataset_info.label,
+        self.config,
+    )
+    std_result = self.metric_pipeline.run_std(
+        generations,
+        self.task_name,
+        ds_wrapper.prompt["answer_key"],
+        ds_wrapper.dataset_info.label,
+        self.config,
+    )
+    final_result = {"mean": mean_result, "std": std_result}
+    saving_fn(generations, final_result)
diff --git a/src/melt/tools/pipelines/__math.py b/src/melt/tools/pipelines/__math.py
new file mode 100644
index 0000000..1492681
--- /dev/null
+++ b/src/melt/tools/pipelines/__math.py
@@ -0,0 +1,145 @@
+"math"
+import random
+from tqdm import tqdm
+from melt.tools.utils.utils import format_fewshot
+def __math(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
+    predictions = []
+    references = []
+    generation_probs = []
+    calib_probs = []
+    math_problem_type = []
+    idx = 0
+    original_few_shot = []
+    calib_few_shot = []
+    selected_sample = []
+
+    if self.continue_infer_data is not None:
+        predictions.extend(self.continue_infer_data["predictions"])
+        references.extend(self.continue_infer_data["references"])
+        generation_probs.extend(self.continue_infer_data["generation_probs"])
+        calib_probs.extend(self.continue_infer_data["calibration_probs"])
+        math_problem_type.extend(self.continue_infer_data.get("math_problem_type", []))
+
+    if self.few_shot:
+
+        def preprocessing_a_record(rec):
+            return [
+                rf"{rec[ds_wrapper.dataset_info.query]}",
+                rf"{rec[ds_wrapper.dataset_info.answer]}",
+            ]
+
+        selected_sample = [
+            preprocessing_a_record(s)
+            for s
in list( + random.sample( + list(ds_wrapper.dataset_training), self.config.num_fs + ) + ) + ] + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + rf"{rule}" + ), + }, + ] + for rule in batch[ds_wrapper.dataset_info.query] + ] + calib_prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt["system_prompt"], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt["prompt"].format(rf"{rule}"), + }, + ] + for rule in batch[ds_wrapper.dataset_info.query] + ] + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + calibprob_batch, _ = ( + self.infer_pipeline.compute_logprob_and_length( + calib_prompts, batch[ds_wrapper.dataset_info.answer] + ) + ) + predictions.extend(results) + references.extend(list(batch[ds_wrapper.dataset_info.answer])) + generation_probs.extend(logprobs) + calib_probs.extend(calibprob_batch) + math_problem_type.extend(list(batch[ds_wrapper.dataset_info.type_id])) + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "calibration_probs": calib_probs, + "fewshot": selected_sample, + "math_problem_type": math_problem_type, + } + + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "calibration_probs": calib_probs, + "fewshot": selected_sample, + "math_problem_type": math_problem_type, + } + + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__multiple_choice.py b/src/melt/tools/pipelines/__multiple_choice.py new file mode 100644 index 0000000..2680766 --- /dev/null +++ b/src/melt/tools/pipelines/__multiple_choice.py @@ -0,0 +1,215 @@ +"multiple choice" +import ast +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot +def __multiple_choice(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): + def format_list_ans(ans_list): + return "\n".join( + list( + map( + lambda ans: + f"{ds_wrapper.dataset_info.label[ans[0]]}: \ + ''' {ans[1]} '''", + enumerate(ans_list), + ) + ) + ) + + predictions = [] + references = [] + generation_probs = [] + option_probs = [] + idx = 0 + original_few_shot = [] + calib_few_shot = [] + option_order_all = [] + selected_sample = [] + # alphabet2idx = 
{chr(i + 65): i for i in range(26)} + num_choice = len(ds_wrapper.dataset_info.label) + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend( + self.continue_infer_data["generation_probs"] + ) + option_probs.extend(self.continue_infer_data["option_probs"]) + option_order_all.extend(self.continue_infer_data["option_orders"]) + + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.context], + rec[ds_wrapper.dataset_info.query], + format_list_ans( + ast.literal_eval(rec[ds_wrapper.dataset_info.options]) + ), + rec[ds_wrapper.dataset_info.answer], + ] + + selected_sample_idx = list( + random.sample( + range(len(ds_wrapper.dataset_training)), self.config.num_fs + ) + ) + selected_sample = [ + preprocessing_a_record(ds_wrapper.dataset_training[s]) + for s in selected_sample_idx + ] + + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + prompts = [] + calib_prompts = [] + remap_order_batch = [] + for cq in zip( + batch[ds_wrapper.dataset_info.context], + batch[ds_wrapper.dataset_info.query], + batch[ds_wrapper.dataset_info.options], + ): + c = cq[0] + q = cq[1] + opts = ast.literal_eval(cq[2]) + order_shuffle = list(range(len(opts))) + if ds_wrapper.dataset_info.random: + random.shuffle(order_shuffle) + remap_order_batch.append(order_shuffle) + new_opts = [opts[i] for i in order_shuffle] + prompts.append( + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + c, + q, + format_list_ans(new_opts), + ), + }, + ] + ) + calib_prompts.append( + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt[ + "system_prompt" + ], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt[ + "prompt" + ].format( + c, + q, + format_list_ans(new_opts), + ), + }, + ] + ) + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + option_logprobs, _ = ( + self.infer_pipeline.compute_logprob_and_length( + calib_prompts * num_choice, + [ + ds_wrapper.dataset_info.label[choice] + for choice in range(num_choice) + for _ in range(len(prompts)) + ], + ) + ) + opt_calib_out = [ + [ + option_logprobs[i + opt * len(prompts)] + for opt in range(num_choice) + ] + for i in range(len(prompts)) + ] + + # Reshuffle answer of calib + option_order_all.extend(remap_order_batch) + predictions.extend(results) + # In case order of options is changed + # Map the reference to the new order + references.extend( + [ + ds_wrapper.dataset_info.label[ + remap.index(ds_wrapper.dataset_info.label.index(x)) + ] + for x, remap in zip( + batch[ds_wrapper.dataset_info.answer], + remap_order_batch, + ) + ] + ) + + generation_probs.extend(logprobs) + option_probs.extend(opt_calib_out) + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, # new order + "generation_probs": generation_probs, + "option_probs": option_probs, # new order + "option_orders": 
option_order_all, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "option_orders": option_order_all, + "fewshot": selected_sample, + } + + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__multiple_choice_sentiment.py b/src/melt/tools/pipelines/__multiple_choice_sentiment.py new file mode 100644 index 0000000..cf22219 --- /dev/null +++ b/src/melt/tools/pipelines/__multiple_choice_sentiment.py @@ -0,0 +1,170 @@ +"multiple choice sentiment" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot, unique + +def __multiple_choice_sentiment( + self, ds_wrapper, ds_loader, saving_fn, start_idx=0 +): + predictions = [] + references = [] + generation_probs = [] + option_probs = [] + idx = 0 + original_few_shot = [] + calib_few_shot = [] + selected_sample = [] + num_choice = len(ds_wrapper.dataset_info.label) + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend( + self.continue_infer_data["generation_probs"] + ) + option_probs.extend(self.continue_infer_data["option_probs"]) + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.query], + rec[ds_wrapper.dataset_info.answer], + ] + + classes = unique( + ds_wrapper.dataset_training[ds_wrapper.dataset_info.answer] + ) + selected_sample = [] + for cl in classes: + cl_samples = ds_wrapper.dataset_training.filter( + lambda r, class_label=cl: r[ds_wrapper.dataset_info.answer] == class_label + ) + selected_sample.append( + preprocessing_a_record( + cl_samples[random.randint(0, len(cl_samples) - 1)] + ) + ) + + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + c, + ), + }, + ] + for c in batch[ds_wrapper.dataset_info.query] + ] + calib_prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt[ + "system_prompt" + ], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt[ + "prompt" + ].format( + c, + ), + }, + ] + for c in batch[ds_wrapper.dataset_info.query] + ] + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + + option_logprobs, _ = ( + 
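+            # Calibration trick: calib_prompts is tiled num_choice times and
+            # paired with every candidate label, so this call returns
+            # len(prompts) * num_choice scores that are regrouped per example
+            # below as option_logprobs[i + opt * len(prompts)].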
self.infer_pipeline.compute_logprob_and_length( + calib_prompts * num_choice, + [ + ds_wrapper.dataset_info.label[choice] + for choice in range(num_choice) + for _ in range(len(prompts)) + ], + ) + ) + predictions.extend(results) + references.extend( + [x.item() for x in batch[ds_wrapper.dataset_info.answer]] + ) + generation_probs.extend(logprobs) + option_probs.extend( + [ + [ + option_logprobs[i + opt * len(prompts)] + for opt in range(num_choice) + ] + for i in range(len(prompts)) + ] + ) + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "fewshot": selected_sample, + } + + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__multiple_choice_text_classification.py b/src/melt/tools/pipelines/__multiple_choice_text_classification.py new file mode 100644 index 0000000..fd3d1d4 --- /dev/null +++ b/src/melt/tools/pipelines/__multiple_choice_text_classification.py @@ -0,0 +1,172 @@ +"multiple choice test classification" +import random +from ast import literal_eval +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot, unique +def __multiple_choice_text_classification( + self, ds_wrapper, ds_loader, saving_fn, start_idx=0 +): + predictions = [] + references = [] + generation_probs = [] + option_probs = [] + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend(self.continue_infer_data["generation_probs"]) + option_probs.extend(self.continue_infer_data["option_probs"]) + + idx = 0 + original_few_shot = [] + calib_few_shot = [] + selected_sample = [] + num_choice = len(ds_wrapper.dataset_info.label) + + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.query], + rec[ds_wrapper.dataset_info.answer], + ] + + classes = unique( + ds_wrapper.dataset_training[ds_wrapper.dataset_info.answer] + ) + + selected_sample = [] + for cl in classes: + cl_samples = ds_wrapper.dataset_training.filter( + lambda r, class_label=cl: r[ds_wrapper.dataset_info.answer] == class_label + ) + selected_sample.append( + cl_samples[random.randint(0, len(cl_samples) - 1)] + ) + + selected_sample = [ + preprocessing_a_record(x) for x in selected_sample + ] + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + 
answer_format=ds_wrapper.prompt["answer_format"], + ) + + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + c, + ), + }, + ] + for c in batch[ds_wrapper.dataset_info.query] + ] + + calib_prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt["system_prompt"], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt["prompt"].format( + c, + ), + }, + ] + for c in batch[ds_wrapper.dataset_info.query] + ] + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + + option_logprobs, _ = self.infer_pipeline.compute_logprob_and_length( + calib_prompts * num_choice, + [ + ds_wrapper.dataset_info.label[choice] + for choice in range(num_choice) + for _ in range(len(prompts)) + ], + ) + + predictions.extend(results) + references.extend( + [ + literal_eval(x) if isinstance(x, str) else x.item() + for x in batch[ds_wrapper.dataset_info.answer] + ] + ) + generation_probs.extend(logprobs) + option_probs.extend( + [ + [ + option_logprobs[i + opt * len(prompts)] + for opt in range(num_choice) + ] + for i in range(len(prompts)) + ] + ) + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "fewshot": selected_sample, + } + + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__multiple_choice_toxicity.py b/src/melt/tools/pipelines/__multiple_choice_toxicity.py new file mode 100644 index 0000000..f0110af --- /dev/null +++ b/src/melt/tools/pipelines/__multiple_choice_toxicity.py @@ -0,0 +1,167 @@ +"multiple choice toxicity" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot, unique +def __multiple_choice_toxicity( +self, ds_wrapper, ds_loader, saving_fn, start_idx=0): + predictions = [] + references = [] + generation_probs = [] + option_probs = [] + idx = 0 + original_few_shot = [] + calib_few_shot = [] + selected_sample = [] + num_choice = len(ds_wrapper.dataset_info.label) + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend( + self.continue_infer_data["generation_probs"] + ) + option_probs.extend(self.continue_infer_data["option_probs"]) + if self.few_shot: + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.query], + 
rec[ds_wrapper.dataset_info.answer], + ] + + classes = unique( + ds_wrapper.dataset_training[ds_wrapper.dataset_info.answer] + ) + selected_sample = [] + for class_label in classes: + cl_samples = ds_wrapper.dataset_training.filter( + lambda r, cl=class_label: r[ds_wrapper.dataset_info.answer] == cl + ) + selected_sample.append( + preprocessing_a_record( + cl_samples[random.randint(0, len(cl_samples) - 1)] + ) + ) + + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + c, + ), + }, + ] + for c in batch[ds_wrapper.dataset_info.query] + ] + + calib_prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt[ + "system_prompt" + ], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt[ + "prompt" + ].format( + c, + ), + }, + ] + for c in batch[ds_wrapper.dataset_info.query] + ] + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + + option_logprobs, _ = ( + self.infer_pipeline.compute_logprob_and_length( + calib_prompts * num_choice, + [ + ds_wrapper.dataset_info.label[choice] + for choice in range(num_choice) + for _ in range(len(prompts)) + ], + ) + ) + predictions.extend(results) + references.extend( + [x.item() for x in batch[ds_wrapper.dataset_info.answer]] + ) + generation_probs.extend(logprobs) + option_probs.extend( + [ + [ + option_logprobs[i + opt * len(prompts)] + for opt in range(num_choice) + ] + for i in range(len(prompts)) + ] + ) + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "option_probs": option_probs, + "fewshot": selected_sample, + } + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__question_answering.py b/src/melt/tools/pipelines/__question_answering.py new file mode 100644 index 0000000..1d8231a --- /dev/null +++ b/src/melt/tools/pipelines/__question_answering.py @@ -0,0 +1,119 @@ +"__question_answering" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot +def __question_answering( + self, ds_wrapper, ds_loader, saving_fn, start_idx=0 +): + 
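+    # Builds chat-style prompts (system prompt, optional few-shot turns, then a
+    # user turn filled with the context and question), runs the inference
+    # wrapper with return_probs=True, and accumulates predictions, gold answers
+    # and generation log-probs. Results are checkpointed through saving_fn every
+    # 100 batches, and start_idx lets an interrupted run skip finished batches.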
predictions = [] + references = [] + generation_probs = [] + original_few_shot = [] + selected_sample = [] + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend( + self.continue_infer_data["generation_probs"] + ) + idx = 0 + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.context], + rec[ds_wrapper.dataset_info.query], + rec[ds_wrapper.dataset_info.answer]["text"][0], + ] + + selected_sample_idx = list( + random.sample( + range(len(ds_wrapper.dataset_training)), self.config.num_fs + ) + ) + selected_sample = [ + preprocessing_a_record(ds_wrapper.dataset_training[s]) + for s in selected_sample_idx + ] + + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + c, + q, + ), + }, + ] + for c, q in zip( + batch[ds_wrapper.dataset_info.context], + batch[ds_wrapper.dataset_info.query], + ) + ] + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + predictions.extend(results) + references.extend( + [x[0] for x in batch[ds_wrapper.dataset_info.answer]["text"]] + ) + generation_probs.extend(logprobs) + + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "fewshot": selected_sample, + } + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__question_answering_without_context.py b/src/melt/tools/pipelines/__question_answering_without_context.py new file mode 100644 index 0000000..e7b06fb --- /dev/null +++ b/src/melt/tools/pipelines/__question_answering_without_context.py @@ -0,0 +1,150 @@ +"question_answering_without context" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot +def __question_answering_without_context( + self, ds_wrapper, ds_loader, saving_fn, start_idx=0 +): + predictions = [] + references = [] + generation_probs = [] + calib_probs = [] + idx = 0 + original_few_shot = [] + calib_few_shot = [] + selected_sample = [] + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend( + 
self.continue_infer_data["generation_probs"] + ) + calib_probs.extend(self.continue_infer_data["calibration_probs"]) + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.query], + rec[ds_wrapper.dataset_info.answer], + ] + + selected_sample_idx = list( + random.sample( + range(len(ds_wrapper.dataset_training)), self.config.num_fs + ) + ) + selected_sample = [ + preprocessing_a_record(ds_wrapper.dataset_training[s]) + for s in selected_sample_idx + ] + + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + q, + ), + }, + ] + for q in batch[ds_wrapper.dataset_info.query] + ] + + calib_prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt[ + "system_prompt" + ], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt[ + "prompt" + ].format( + q, + ), + }, + ] + for q in batch[ds_wrapper.dataset_info.query] + ] + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + calibprob_batch, _ = ( + self.infer_pipeline.compute_logprob_and_length( + calib_prompts, batch[ds_wrapper.dataset_info.answer] + ) + ) + predictions.extend(results) + references.extend(list(batch[ds_wrapper.dataset_info.answer])) + generation_probs.extend(logprobs) + calib_probs.extend(calibprob_batch) + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "calibration_probs": calib_probs, + "fewshot": selected_sample, + } + + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "calibration_probs": calib_probs, + "fewshot": selected_sample, + } + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__reasoning.py b/src/melt/tools/pipelines/__reasoning.py new file mode 100644 index 0000000..d36a749 --- /dev/null +++ b/src/melt/tools/pipelines/__reasoning.py @@ -0,0 +1,136 @@ +"reasoning" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot +def __reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): + predictions = [] + references = [] + generation_probs = [] + calib_probs = [] + idx = 0 + original_few_shot = [] + calib_few_shot = [] + selected_sample = [] + + if 
self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend(self.continue_infer_data["generation_probs"]) + calib_probs.extend(self.continue_infer_data["calibration_probs"]) + + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.query], + rec[ds_wrapper.dataset_info.answer], + ] + + selected_sample = [ + preprocessing_a_record(s) + for s in list( + random.sample( + list(ds_wrapper.dataset_training), self.config.num_fs + ) + ) + ] + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + calib_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.calibration_prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format(rule), + }, + ] + for rule in batch[ds_wrapper.dataset_info.query] + ] + calib_prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.calibration_prompt["system_prompt"], + }, + *calib_few_shot, + { + "role": "user", + "content": ds_wrapper.calibration_prompt["prompt"].format(rule), + }, + ] + for rule in batch[ds_wrapper.dataset_info.query] + ] + + results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True) + calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length( + calib_prompts, batch[ds_wrapper.dataset_info.answer] + ) + predictions.extend(results) + references.extend(list(batch[ds_wrapper.dataset_info.answer])) + generation_probs.extend(logprobs) + calib_probs.extend(calibprob_batch) + + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "calibration_probs": calib_probs, + "fewshot": selected_sample, + } + + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "calibration_probs": calib_probs, + "fewshot": selected_sample, + } + + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__summarization.py b/src/melt/tools/pipelines/__summarization.py new file mode 100644 index 0000000..82ccb88 --- /dev/null +++ b/src/melt/tools/pipelines/__summarization.py @@ -0,0 +1,118 @@ +"__summarization" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot + +def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): + original_documents = [] + predictions = [] + original_few_shot = [] + 
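+    # Summarization additionally keeps the source documents next to the
+    # predictions and references in the saved generations; otherwise the flow
+    # mirrors the other generation tasks (few-shot prompt assembly, batched
+    # inference, checkpointing every 100 batches).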
selected_sample = [] + references = [] + generation_probs = [] + if self.continue_infer_data is not None: + original_documents.extend( + self.continue_infer_data["original_documents"] + ) + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend( + self.continue_infer_data["generation_probs"] + ) + idx = 0 + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + rec[ds_wrapper.dataset_info.source], + rec[ds_wrapper.dataset_info.target], + ] + + selected_sample_idx = list( + random.sample( + range(len(ds_wrapper.dataset_training)), self.config.num_fs + ) + ) + selected_sample = [ + preprocessing_a_record(ds_wrapper.dataset_training[s]) + for s in selected_sample_idx + ] + + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + document, + ), + }, + ] + for document in batch[ds_wrapper.dataset_info.source] + ] + original_documents.extend(list(batch[ds_wrapper.dataset_info.source])) + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + predictions.extend(results) + references.extend(list(batch[ds_wrapper.dataset_info.target])) + generation_probs.extend(logprobs) + + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "original_documents": original_documents, + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "original_documents": original_documents, + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "fewshot": selected_sample, + } + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/__translation.py b/src/melt/tools/pipelines/__translation.py new file mode 100644 index 0000000..a560723 --- /dev/null +++ b/src/melt/tools/pipelines/__translation.py @@ -0,0 +1,114 @@ +"translation" +import random +from tqdm import tqdm +from melt.tools.utils.utils import format_fewshot +def __translation(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): + predictions = [] + references = [] + generation_probs = [] + idx = 0 + original_few_shot = [] + selected_sample = [] + + if self.continue_infer_data is not None: + predictions.extend(self.continue_infer_data["predictions"]) + references.extend(self.continue_infer_data["references"]) + generation_probs.extend(self.continue_infer_data["generation_probs"]) + + if self.few_shot: + + def preprocessing_a_record(rec): + return [ + 
rec[ds_wrapper.dataset_info.source], + rec[ds_wrapper.dataset_info.target], + ] + + selected_sample = [ + preprocessing_a_record(s) + for s in list( + random.sample( + list(ds_wrapper.dataset_training), self.config.num_fs + ) + ) + ] + original_few_shot = format_fewshot( + selected_sample, + query_format=ds_wrapper.prompt["prompt"], + answer_format=ds_wrapper.prompt["answer_format"], + ) + + # Create few-shot strings + for batch in tqdm(ds_loader): + if idx < start_idx: + idx += 1 + continue + + prompts = [ + [ + { + "role": "system", + "content": ds_wrapper.prompt["system_prompt"], + }, + *original_few_shot, + { + "role": "user", + "content": ds_wrapper.prompt["prompt"].format( + document, + ), + }, + ] + for document in batch[ds_wrapper.dataset_info.source] + ] + + results, logprobs, _ = self.infer_pipeline( + prompts, return_probs=True + ) + predictions.extend(results) + references.extend( + list(batch[ds_wrapper.dataset_info.target]) # Direct list instead of comprehension + ) + generation_probs.extend(logprobs) + + idx += 1 + if idx % 100 == 0: + print(f"Saving results of {idx} batches") + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "fewshot": selected_sample, + } + saving_fn(generations) + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + print(f"Results of {idx} batches: ", mean_result) + + generations = { + "predictions": predictions, + "references": references, + "generation_probs": generation_probs, + "fewshot": selected_sample, + } + + mean_result = self.metric_pipeline.run_mean( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + std_result = self.metric_pipeline.run_std( + generations, + self.task_name, + ds_wrapper.prompt["answer_key"], + ds_wrapper.dataset_info.label, + self.config, + ) + + final_result = {"mean": mean_result, "std": std_result} + saving_fn(generations, final_result) diff --git a/src/melt/tools/pipelines/pipelines.py b/src/melt/tools/pipelines/pipelines.py index 232fdbd..87213a5 100644 --- a/src/melt/tools/pipelines/pipelines.py +++ b/src/melt/tools/pipelines/pipelines.py @@ -1,1971 +1,114 @@ -import ast -import torch +"pipelines" import os import json -from tqdm import tqdm -import random -from ..wrapper import ( +import torch +from melt.tools.wrapper import ( OpenAIWrapper, TGIWrapper, GeminiWrapper, VLLMWrapper, HFWrapper, ) -from ..utils.utils import column, format_fewshot, unique -from .metric_pipelines import MetricPipeline - - +from melt.tools.pipelines.metric_pipelines import MetricPipeline +from melt.tools.pipelines.__question_answering import __question_answering +from melt.tools.pipelines.__question_answering_without_context import ( + __question_answering_without_context +) +from melt.tools.pipelines.__summarization import __summarization +from melt.tools.pipelines.__multiple_choice_sentiment import __multiple_choice_sentiment +from melt.tools.pipelines.__multiple_choice_text_classification import ( + __multiple_choice_text_classification) +from melt.tools.pipelines.__multiple_choice_toxicity import __multiple_choice_toxicity +from melt.tools.pipelines.__multiple_choice import __multiple_choice +from melt.tools.pipelines.__language_modeling import __language_modeling +from melt.tools.pipelines.__information_retrieval import __information_retrieval +from melt.tools.pipelines.__reasoning import 
__reasoning +from melt.tools.pipelines.__math import __math +from melt.tools.pipelines.__translation import __translation class EvalPipeline: + "class" def __init__(self, task, config): - # Load generation configuration with open( os.path.join( - config.config_dir, config.lang, "generation_config.json" - ), - "r", + config.config_dir, config.lang, "generation_config.json"), "r", encoding="utf-8" ) as f: - GenerationConfig = json.load(f) + generation_config = json.load(f) with open( - os.path.join(config.config_dir, "llm_template.json"), "r" + os.path.join(config.config_dir, "llm_template.json"), "r", encoding="utf-8" ) as f: - LLM_TEMPLATE = json.load(f) + llm_template = json.load(f) with open( os.path.join( - config.config_dir, config.lang, "metric_configuration.json" - ), - "r", + config.config_dir, config.lang, "metric_configuration.json"), "r", encoding="utf-8" ) as f: - METRIC_CONFIG = json.load(f) + metric_config = json.load(f) + # Load task self.task_name = task # Load pipelines - # print(config.tgi) if config.wtype == "tgi": self.infer_pipeline = TGIWrapper( - generation_config=GenerationConfig[self.task_name], - template=LLM_TEMPLATE[config.ptemplate], + generation_config=generation_config[self.task_name], + template=llm_template[config.ptemplate], ) elif config.wtype == "hf": self.infer_pipeline = HFWrapper( config=config, - generation_config=GenerationConfig[self.task_name], - template=LLM_TEMPLATE[config.ptemplate], + generation_config=generation_config[self.task_name], + template=llm_template[config.ptemplate], ) elif config.wtype == "vllm": self.infer_pipeline = VLLMWrapper( config=config, - generation_config=GenerationConfig[self.task_name], - template=LLM_TEMPLATE[config.ptemplate], + generation_config=generation_config[self.task_name], + template=llm_template[config.ptemplate], ) elif config.wtype == "openai": self.infer_pipeline = OpenAIWrapper( engine=config.model_name, - generation_config=GenerationConfig[self.task_name], + generation_config=generation_config[self.task_name], ) elif config.wtype == "gemini": self.infer_pipeline = GeminiWrapper( model_name=config.model_name, - generation_config=GenerationConfig[self.task_name], + generation_config=generation_config[self.task_name], ) else: raise ValueError("Invalid wrapper type") self.config = config self.config.task = self.task_name - self.config.metric_config = METRIC_CONFIG + self.config.metric_config = metric_config self.few_shot = False self.continue_infer_data = None - # Metric pipeline configuration self.metric_pipeline = MetricPipeline() self.config.filepath = None + self.generation_results_file = None # Initialize in __init__ def __call__(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - task = self.task_name - - if task == "question-answering": - return self.__question_answering( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "summarization": - return self.__summarization( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif "translation" in task: - return self.__translation( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif "language-modeling" in task: - return self.__language_modeling( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif "text-classification" in task: - return self.__multiple_choice_text_classification( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "sentiment-analysis": - return self.__multiple_choice_sentiment( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "toxicity-detection": - return 
self.__multiple_choice_toxicity( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "knowledge-mtpchoice": - return self.__multiple_choice( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "knowledge-openended": - return self.__question_answering_without_context( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "information-retrieval": - return self.__information_retrieval( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "reasoning": - return self.__reasoning( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "math": - return self.__math(ds_wrapper, ds_loader, saving_fn, start_idx) - else: - raise NotImplementedError - - def __question_answering( - self, ds_wrapper, ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - references = [] - generation_probs = [] - original_few_shot = [] - selected_sample = [] - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - idx = 0 - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.context], - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer]["text"][0], - ] - - selected_sample_idx = list( - random.sample( - range(len(ds_wrapper.dataset_training)), self.config.num_fs - ) - ) - selected_sample = [ - preprocessing_a_record(ds_wrapper.dataset_training[s]) - for s in selected_sample_idx - ] - - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - c, - q, - ), - }, - ] - for c, q in zip( - batch[ds_wrapper.dataset_info.context], - batch[ds_wrapper.dataset_info.query], - ) - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - predictions.extend(results) - references.extend( - [x[0] for x in batch[ds_wrapper.dataset_info.answer]["text"]] - ) - generation_probs.extend(logprobs) - - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, + task_mapping = { + "question-answering": __question_answering, + "summarization": __summarization, + "translation": __translation, + "language-modeling": __language_modeling, + "text-classification": __multiple_choice_text_classification, + "sentiment-analysis": __multiple_choice_sentiment, + "toxicity-detection": __multiple_choice_toxicity, + "knowledge-mtpchoice": __multiple_choice, + "knowledge-openended": __question_answering_without_context, + "information-retrieval": __information_retrieval, + "reasoning": __reasoning, + "math": __math, } 
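Note on the dispatch table above: Python mangles any identifier with two leading underscores that appears inside a class body, so a reference to __question_answering written inside EvalPipeline.__call__ compiles to _EvalPipeline__question_answering and raises NameError when the method runs. The snippet below is a minimal, illustrative sketch (not part of this patch) of a mangling-free variant; it assumes the module layout introduced above, assumes every task function takes the pipeline instance as its first argument, and the run_task helper name is hypothetical.

# Minimal, illustrative sketch: building the table at module scope avoids the
# name mangling that applies inside a class body.
from melt.tools.pipelines.__question_answering import __question_answering
from melt.tools.pipelines.__summarization import __summarization

_TASK_DISPATCH = {
    "question-answering": __question_answering,
    "summarization": __summarization,
    # remaining tasks would be registered the same way
}

def run_task(pipeline, task_name, ds_wrapper, ds_loader, saving_fn, start_idx=0):
    """Hypothetical helper: look up the handler and pass the pipeline as 'self'."""
    try:
        handler = _TASK_DISPATCH[task_name]
    except KeyError as exc:
        raise NotImplementedError(f"Unsupported task: {task_name}") from exc
    return handler(pipeline, ds_wrapper, ds_loader, saving_fn, start_idx)

Alternatively, renaming the functions with a single leading underscore (for example _question_answering) would sidestep mangling entirely and allow the table to stay inside __call__.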
- mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __question_answering_without_context( - self, ds_wrapper, ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - references = [] - generation_probs = [] - calib_probs = [] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - calib_probs.extend(self.continue_infer_data["calibration_probs"]) - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer], - ] - - selected_sample_idx = list( - random.sample( - range(len(ds_wrapper.dataset_training)), self.config.num_fs - ) - ) - selected_sample = [ - preprocessing_a_record(ds_wrapper.dataset_training[s]) - for s in selected_sample_idx - ] - - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - q, - ), - }, - ] - for q in batch[ds_wrapper.dataset_info.query] - ] - - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format( - q, - ), - }, - ] - for q in batch[ds_wrapper.dataset_info.query] - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - calibprob_batch, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts, batch[ds_wrapper.dataset_info.answer] - ) - ) - predictions.extend(results) - references.extend( - [x for x in batch[ds_wrapper.dataset_info.answer]] - ) - generation_probs.extend(logprobs) - calib_probs.extend(calibprob_batch) - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "calibration_probs": calib_probs, - "fewshot": selected_sample, - } - - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "calibration_probs": calib_probs, - "fewshot": selected_sample, - } - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - 
ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - original_documents = [] - predictions = [] - original_few_shot = [] - selected_sample = [] - references = [] - generation_probs = [] - if self.continue_infer_data is not None: - original_documents.extend( - self.continue_infer_data["original_documents"] - ) - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - idx = 0 - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.source], - rec[ds_wrapper.dataset_info.target], - ] - - selected_sample_idx = list( - random.sample( - range(len(ds_wrapper.dataset_training)), self.config.num_fs - ) - ) - selected_sample = [ - preprocessing_a_record(ds_wrapper.dataset_training[s]) - for s in selected_sample_idx - ] - - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - document, - ), - }, - ] - for document in batch[ds_wrapper.dataset_info.source] - ] - original_documents.extend( - [x for x in batch[ds_wrapper.dataset_info.source]] - ) - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - predictions.extend(results) - references.extend( - [x for x in batch[ds_wrapper.dataset_info.target]] - ) - generation_probs.extend(logprobs) - - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "original_documents": original_documents, - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "original_documents": original_documents, - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __multiple_choice_sentiment( - self, ds_wrapper, ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - references = [] - generation_probs = [] - option_probs = [] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - num_choice = len(ds_wrapper.dataset_info.label) - if 
self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - option_probs.extend(self.continue_infer_data["option_probs"]) - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer], - ] - - classes = unique( - ds_wrapper.dataset_training[ds_wrapper.dataset_info.answer] - ) - selected_sample = [] - for cl in classes: - cl_samples = ds_wrapper.dataset_training.filter( - lambda r: r[ds_wrapper.dataset_info.answer] == cl - ) - selected_sample.append( - preprocessing_a_record( - cl_samples[random.randint(0, len(cl_samples))] - ) - ) - - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.query] - ] - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.query] - ] - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - - option_logprobs, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts * num_choice, - [ - ds_wrapper.dataset_info.label[choice] - for choice in range(num_choice) - for _ in range(len(prompts)) - ], - ) - ) - predictions.extend(results) - references.extend( - [x.item() for x in batch[ds_wrapper.dataset_info.answer]] - ) - generation_probs.extend(logprobs) - option_probs.extend( - [ - [ - option_logprobs[i + opt * len(prompts)] - for opt in range(num_choice) - ] - for i in range(len(prompts)) - ] - ) - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "fewshot": selected_sample, - } - - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __multiple_choice_text_classification( - self, ds_wrapper, 
ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - references = [] - generation_probs = [] - option_probs = [] - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - option_probs.extend(self.continue_infer_data["option_probs"]) - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - num_choice = len(ds_wrapper.dataset_info.label) - - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer], - ] - - classes = unique( - ds_wrapper.dataset_training[ds_wrapper.dataset_info.answer] - ) - - selected_sample = [] - for cl in classes: - cl_samples = ds_wrapper.dataset_training.filter( - lambda r: (r[ds_wrapper.dataset_info.answer] == cl) - ) - selected_sample.append( - cl_samples[random.randint(0, len(cl_samples) - 1)] - ) - - selected_sample = [ - preprocessing_a_record(x) for x in selected_sample - ] - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.query] - ] - - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.query] - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - - option_logprobs, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts * num_choice, - [ - ds_wrapper.dataset_info.label[choice] - for choice in range(num_choice) - for _ in range(len(prompts)) - ], - ) - ) - predictions.extend(results) - references.extend( - [ - eval(x) if type(x) is str else x.item() - for x in batch[ds_wrapper.dataset_info.answer] - ] - ) - generation_probs.extend(logprobs) - option_probs.extend( - [ - [ - option_logprobs[i + opt * len(prompts)] - for opt in range(num_choice) - ] - for i in range(len(prompts)) - ] - ) - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "fewshot": selected_sample, - } - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - 
ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __multiple_choice_toxicity( - self, ds_wrapper, ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - references = [] - generation_probs = [] - option_probs = [] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - num_choice = len(ds_wrapper.dataset_info.label) - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - option_probs.extend(self.continue_infer_data["option_probs"]) - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer], - ] - - classes = unique( - ds_wrapper.dataset_training[ds_wrapper.dataset_info.answer] - ) - selected_sample = [] - for cl in classes: - cl_samples = ds_wrapper.dataset_training.filter( - lambda r: r[ds_wrapper.dataset_info.answer] == cl - ) - selected_sample.append( - preprocessing_a_record( - cl_samples[random.randint(0, len(cl_samples))] - ) - ) - - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.query] - ] - - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.query] - ] - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - - option_logprobs, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts * num_choice, - [ - ds_wrapper.dataset_info.label[choice] - for choice in range(num_choice) - for _ in range(len(prompts)) - ], - ) - ) - predictions.extend(results) - references.extend( - [x.item() for x in batch[ds_wrapper.dataset_info.answer]] - ) - generation_probs.extend(logprobs) - option_probs.extend( - [ - [ - option_logprobs[i + opt * len(prompts)] - for opt in range(num_choice) - ] - for i in range(len(prompts)) - ] - ) - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - 
"references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "fewshot": selected_sample, - } - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __multiple_choice(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - def format_list_ans(ans_list): - return "\n".join( - list( - map( - lambda ans: - f"{ds_wrapper.dataset_info.label[ans[0]]}: \ - ''' {ans[1]} '''", - enumerate(ans_list), - ) - ) - ) - - predictions = [] - references = [] - generation_probs = [] - option_probs = [] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - option_order_all = [] - selected_sample = [] - # alphabet2idx = {chr(i + 65): i for i in range(26)} - num_choice = len(ds_wrapper.dataset_info.label) - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - option_probs.extend(self.continue_infer_data["option_probs"]) - option_order_all.extend(self.continue_infer_data["option_orders"]) - - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.context], - rec[ds_wrapper.dataset_info.query], - format_list_ans( - ast.literal_eval(rec[ds_wrapper.dataset_info.options]) - ), - rec[ds_wrapper.dataset_info.answer], - ] - - selected_sample_idx = list( - random.sample( - range(len(ds_wrapper.dataset_training)), self.config.num_fs - ) - ) - selected_sample = [ - preprocessing_a_record(ds_wrapper.dataset_training[s]) - for s in selected_sample_idx - ] - - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [] - calib_prompts = [] - remap_order_batch = [] - for cq in zip( - batch[ds_wrapper.dataset_info.context], - batch[ds_wrapper.dataset_info.query], - batch[ds_wrapper.dataset_info.options], - ): - - c = cq[0] - q = cq[1] - opts = ast.literal_eval(cq[2]) - order_shuffle = list(range(len(opts))) - if ds_wrapper.dataset_info.random: - random.shuffle(order_shuffle) - remap_order_batch.append(order_shuffle) - new_opts = [opts[i] for i in order_shuffle] - prompts.append( - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - c, - q, - format_list_ans(new_opts), - ), - }, - ] - ) - calib_prompts.append( - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format( - c, - q, - format_list_ans(new_opts), - ), - }, - ] - ) - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - option_logprobs, _ = ( - 
self.infer_pipeline.compute_logprob_and_length( - calib_prompts * num_choice, - [ - ds_wrapper.dataset_info.label[choice] - for choice in range(num_choice) - for _ in range(len(prompts)) - ], - ) - ) - opt_calib_out = [ - [ - option_logprobs[i + opt * len(prompts)] - for opt in range(num_choice) - ] - for i in range(len(prompts)) - ] - - # REsort answer of calib - option_order_all.extend(remap_order_batch) - predictions.extend(results) - # In case order of options is changed - # Map the reference to the new order - references.extend( - [ - ds_wrapper.dataset_info.label[ - remap.index(ds_wrapper.dataset_info.label.index(x)) - ] - for x, remap in zip( - batch[ds_wrapper.dataset_info.answer], - remap_order_batch, - ) - ] - ) - - generation_probs.extend(logprobs) - option_probs.extend(opt_calib_out) - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, # new order - "generation_probs": generation_probs, - "option_probs": option_probs, # new order - "option_orders": option_order_all, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "option_probs": option_probs, - "option_orders": option_order_all, - "fewshot": selected_sample, - } - - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __language_modeling( - self, ds_wrapper, ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - references = [] - generation_probs = [] - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - idx = 0 - original_few_shot = [] - selected_sample = [] - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.source], - rec[ds_wrapper.dataset_info.target], - ] - - selected_sample_idx = list( - random.sample( - range(len(ds_wrapper.dataset_training)), self.config.num_fs - ) - ) - selected_sample = [ - preprocessing_a_record(ds_wrapper.dataset_training[s]) - for s in selected_sample_idx - ] - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - # Create few-shot strings - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - c, - ), - }, - ] - for c in batch[ds_wrapper.dataset_info.source] - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - predictions.extend(results) - references.extend( - [x for x in 
batch[ds_wrapper.dataset_info.target]] - ) - generation_probs.extend(logprobs) - - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __information_retrieval( - self, ds_wrapper, ds_loader, saving_fn, start_idx=0 - ): - predictions = [] - # sub_task = self.task.split("_")[1] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.passages], - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer], - ] - - random_sample = list( - random.sample(list(ds_wrapper.dataset_training), 1) - )[0] - first_sample = { - "passages": random_sample["positive"], - "query": random_sample[ds_wrapper.dataset_info.query], - "references": ds_wrapper.dataset_info.label[0], - } - second_sample = { - "passages": random_sample["negative"], - "query": random_sample[ds_wrapper.dataset_info.query], - "references": ds_wrapper.dataset_info.label[1], - } - - selected_sample = [ - preprocessing_a_record(s) - for s in [first_sample, second_sample] - ] - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - BATCH_PASSAGE_SIZE = 10 - # Create few-shot strings - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - for query_with_a_batch_passages in range( - len(batch[ds_wrapper.dataset_info.type_id]) - ): - query_id = batch[ds_wrapper.dataset_info.type_id][ - query_with_a_batch_passages - ] - query = batch[ds_wrapper.dataset_info.query][ - query_with_a_batch_passages - ] - try: - ref_passage_id = batch[ds_wrapper.dataset_info.answer][0][ - query_with_a_batch_passages - ] - except Exception: - if len(list(batch[ds_wrapper.dataset_info.answer])) < 1: - continue - ref_passage_id = list( - batch[ds_wrapper.dataset_info.answer][0] - )[query_with_a_batch_passages] - batch_passages = batch[ds_wrapper.dataset_info.passages] - - top30_passage_ids = column( - batch_passages["id"], query_with_a_batch_passages - ) - top30_passages = column( - batch_passages["passage"], query_with_a_batch_passages - ) - for psg in range( - 0, len(top30_passage_ids), BATCH_PASSAGE_SIZE - ): - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": 
ds_wrapper.prompt["prompt"].format( - p, - query, - ), - }, - ] - for p in top30_passages[psg:psg + BATCH_PASSAGE_SIZE] - ] - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format( - p, - query, - ), - }, - ] - for p in top30_passages[psg:psg + BATCH_PASSAGE_SIZE] - ] - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - - option_logprobs, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts * len(ds_wrapper.dataset_info.label), - [ - choice - for choice in ds_wrapper.dataset_info.label - for _ in range(len(prompts)) - ], - ) - ) - save_each_prompt = list( - map( - lambda x, y, z, t, q: { - "query_id": ( - query_id.item() - if type(query_id) is not str - else query_id - ), - "query": query, - "passage_id": ( - z.item() if type(z) is not str else z - ), - "passage": t, - "label": int( - z.item() == ref_passage_id - if type(z) is not str - else z == ref_passage_id - ), - "prediction": x, - "generation_probs": y, - "calib_probs": [ - option_logprobs[q + opt * len(prompts)] - for opt in range( - len(ds_wrapper.dataset_info.label) - ) - ], - }, - results, - logprobs, - top30_passage_ids[psg:psg + BATCH_PASSAGE_SIZE], - top30_passages[psg:psg + BATCH_PASSAGE_SIZE], - range(len(prompts)), - ) - ) - predictions.extend(save_each_prompt) - - idx += 1 - - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "fewshot": selected_sample, - "predictions": predictions, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ref_dataset=ds_wrapper.dataset_testing, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = {"fewshot": selected_sample, "predictions": predictions} - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ref_dataset=ds_wrapper.dataset_testing, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ref_dataset=ds_wrapper.dataset_testing, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - predictions = [] - references = [] - generation_probs = [] - calib_probs = [] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - calib_probs.extend(self.continue_infer_data["calibration_probs"]) - - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.query], - rec[ds_wrapper.dataset_info.answer], - ] - - selected_sample = [ - preprocessing_a_record(s) - for s in list( - random.sample( - list(ds_wrapper.dataset_training), self.config.num_fs - ) - ) - ] - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - 
selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format(rule), - }, - ] - for rule in batch[ds_wrapper.dataset_info.query] - ] - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format(rule), - }, - ] - for rule in batch[ds_wrapper.dataset_info.query] - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - calibprob_batch, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts, batch[ds_wrapper.dataset_info.answer] - ) - ) - predictions.extend(results) - references.extend( - [x for x in batch[ds_wrapper.dataset_info.answer]] - ) - generation_probs.extend(logprobs) - calib_probs.extend(calibprob_batch) - - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "calibration_probs": calib_probs, - "fewshot": selected_sample, - } - - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "calibration_probs": calib_probs, - "fewshot": selected_sample, - } - - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __math(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - predictions = [] - references = [] - generation_probs = [] - calib_probs = [] - math_problem_type = [] - idx = 0 - original_few_shot = [] - calib_few_shot = [] - selected_sample = [] - # res_list = pattern.findall(text) - # return res_list[0] if res_list else None - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - calib_probs.extend(self.continue_infer_data["calibration_probs"]) - math_problem_type.extend( - self.continue_infer_data.get("math_problem_type", []) - ) - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rf"{rec[ds_wrapper.dataset_info.query]}", - rf"{rec[ds_wrapper.dataset_info.answer]}", - ] - - selected_sample = [ - preprocessing_a_record(s) - for s in list( - random.sample( - list(ds_wrapper.dataset_training), self.config.num_fs - ) - ) - ] - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - calib_few_shot = format_fewshot( - 
selected_sample, - query_format=ds_wrapper.calibration_prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - rf"{rule}" - ), - }, - ] - for rule in batch[ds_wrapper.dataset_info.query] - ] - calib_prompts = [ - [ - { - "role": "system", - "content": ds_wrapper.calibration_prompt[ - "system_prompt" - ], - }, - *calib_few_shot, - { - "role": "user", - "content": ds_wrapper.calibration_prompt[ - "prompt" - ].format(rf"{rule}"), - }, - ] - for rule in batch[ds_wrapper.dataset_info.query] - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - calibprob_batch, _ = ( - self.infer_pipeline.compute_logprob_and_length( - calib_prompts, batch[ds_wrapper.dataset_info.answer] - ) - ) - predictions.extend(results) - references.extend( - [x for x in batch[ds_wrapper.dataset_info.answer]] - ) - generation_probs.extend(logprobs) - calib_probs.extend(calibprob_batch) - math_problem_type.extend( - [x for x in batch[ds_wrapper.dataset_info.type_id]] - ) - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "calibration_probs": calib_probs, - "fewshot": selected_sample, - "math_problem_type": math_problem_type, - } - - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "calibration_probs": calib_probs, - "fewshot": selected_sample, - "math_problem_type": math_problem_type, - } - - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) - - def __translation(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - predictions = [] - references = [] - generation_probs = [] - idx = 0 - original_few_shot = [] - selected_sample = [] - if self.continue_infer_data is not None: - predictions.extend(self.continue_infer_data["predictions"]) - references.extend(self.continue_infer_data["references"]) - generation_probs.extend( - self.continue_infer_data["generation_probs"] - ) - if self.few_shot: - - def preprocessing_a_record(rec): - return [ - rec[ds_wrapper.dataset_info.source], - rec[ds_wrapper.dataset_info.target], - ] - - selected_sample = [ - preprocessing_a_record(s) - for s in list( - random.sample( - list(ds_wrapper.dataset_training), self.config.num_fs - ) - ) - ] - original_few_shot = format_fewshot( - selected_sample, - query_format=ds_wrapper.prompt["prompt"], - answer_format=ds_wrapper.prompt["answer_format"], - ) - - # Create few-shot strings - for batch in tqdm(ds_loader): - if idx < start_idx: - idx += 1 - continue - - prompts = [ - [ - { - "role": "system", - "content": 
ds_wrapper.prompt["system_prompt"], - }, - *original_few_shot, - { - "role": "user", - "content": ds_wrapper.prompt["prompt"].format( - document, - ), - }, - ] - for document in batch[ds_wrapper.dataset_info.source] - ] - - results, logprobs, _ = self.infer_pipeline( - prompts, return_probs=True - ) - predictions.extend(results) - references.extend( - [x for x in batch[ds_wrapper.dataset_info.target]] - ) - generation_probs.extend(logprobs) - - idx += 1 - if idx % 100 == 0: - print(f"Saving results of {idx} batches") - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - saving_fn(generations) - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - print(f"Results of {idx} batches: ", mean_result) - - generations = { - "predictions": predictions, - "references": references, - "generation_probs": generation_probs, - "fewshot": selected_sample, - } - mean_result = self.metric_pipeline.run_mean( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - std_result = self.metric_pipeline.run_std( - generations, - self.task_name, - ds_wrapper.prompt["answer_key"], - ds_wrapper.dataset_info.label, - self.config, - ) - final_result = {"mean": mean_result, "std": std_result} - saving_fn(generations, final_result) + if self.task_name in task_mapping: + return task_mapping[self.task_name](ds_wrapper, ds_loader, saving_fn, start_idx) + raise NotImplementedError # Removed unnecessary "else" def run( self, ds_wrapper, @@ -1976,6 +119,7 @@ def run( few_shot=False, continue_infer=None, ): + "run" self.generation_results_file = generation_results_file self.config.filepath = generation_results_file self.continue_infer_data = continue_infer diff --git a/tests/test_execution.py b/tests/test_execution.py index a408cb6..621ce3c 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -40,12 +40,12 @@ def run_melt_command(self, dataset_name): "--seed", str(self.seed), "--smoke_test", str(self.smoke_test) ] - - result = subprocess.run(command, capture_output=True, text=True) + result = subprocess.run(command, capture_output=True, text=True, check=False) # Provide detailed error information if the command fails if result.returncode != 0: - self.fail(f"Command failed for dataset '{dataset_name}' with exit code {result.returncode}\n" + self.fail(f"Command failed for dataset '{dataset_name}' " + f"with exit code {result.returncode}\n" f"stdout: {result.stdout}\n" f"stderr: {result.stderr}") @@ -105,4 +105,4 @@ def test_information_retrieval(self): self.run_melt_command(dataset_name) if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index cb1a679..20849e7 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -5,7 +5,6 @@ import subprocess import unittest - class TestWrapper(unittest.TestCase): """ Test cases for various wrapper types using the 'melt' command. @@ -31,21 +30,9 @@ def run_melt_command(self, dataset_name, wrapper_type): wrapper_type (str): The type of wrapper to use. 
""" command = [ - "melt", - "--wtype", - wrapper_type, - "--model_name", - self.model_name, - "--dataset_name", - dataset_name, - "--ptemplate", - self.ptemplate, - "--lang", - self.lang, - "--seed", - str(self.seed), - "--smoke_test", - str(self.smoke_test), + "melt", "--wtype", wrapper_type, "--model_name", self.model_name, + "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, + "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test) ] result = subprocess.run(command, capture_output=True, text=True, check=True) self.assertEqual(result.returncode, 0) @@ -85,5 +72,5 @@ def test_wrapper_vllm(self): dataset_name = "zalo_e2eqa" self.run_melt_command(dataset_name, "vllm") -if __name__ == "__main__": - unittest.main() +if __name__ == '__main__': + unittest.main() \ No newline at end of file