From bac7cf9cfaa8a191fcb134596514d7ad70e6a546 Mon Sep 17 00:00:00 2001 From: zhanghy-sketchzh <1750410339@qq.com> Date: Fri, 7 Jul 2023 22:36:27 +0800 Subject: [PATCH] finetune for qlora --- LICENSE | 222 +------ requirements.txt | 80 +++ src/__init__.py | 0 src/sql_data_process.py | 981 +++++++++++++++++++++++++++++++ src/train/train_qlora.py | 838 ++++++++++++++++++++++++++ src/utils/merge_peft_adapters.py | 48 ++ 6 files changed, 1968 insertions(+), 201 deletions(-) create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/sql_data_process.py create mode 100644 src/train/train_qlora.py create mode 100644 src/utils/merge_peft_adapters.py diff --git a/LICENSE b/LICENSE index 261eeb9..4cb947d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,21 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +MIT License + +Copyright (c) 2023 magic.chen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ff86055 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,80 @@ +torch==2.0.0 +accelerate==0.16.0 +aiohttp==3.8.4 +aiosignal==1.3.1 +async-timeout==4.0.2 +attrs==22.2.0 +bitsandbytes==0.39.0 +cchardet==2.1.7 +chardet==5.1.0 +contourpy==1.0.7 +cycler==0.11.0 +filelock==3.9.0 +fonttools==4.38.0 +frozenlist==1.3.3 +huggingface-hub==0.14.1 +importlib-resources==5.12.0 + +sqlparse==0.4.4 +kiwisolver==1.4.4 +matplotlib==3.7.1 +multidict==6.0.4 +packaging==23.0 +psutil==5.9.4 +pycocotools==2.0.6 +pyparsing==3.0.9 +python-dateutil==2.8.2 +pyyaml==6.0 +tokenizers==0.13.2 +tqdm==4.64.1 +transformers==4.30.0 +timm==0.6.13 +spacy==3.5.3 +webdataset==0.2.48 +yarl==1.8.2 +zipp==3.14.0 +omegaconf==2.3.0 +opencv-python==4.7.0.72 +iopath==0.1.10 +tenacity==8.2.2 +peft +pycocoevalcap +cpm_kernels +umap-learn +notebook +gradio==3.23 +gradio-client==0.0.8 +wandb +llama-index==0.5.27 +pymysql +unstructured==0.6.3 +grpcio==1.47.5 +gpt4all==0.3.0 +diskcache==5.6.1 + +auto-gpt-plugin-template +pymdown-extensions +gTTS==2.3.1 +langchain +nltk +python-dotenv==1.0.0 +pymilvus==2.2.1 +vcrpy +chromadb==0.3.22 +markdown2 +colorama +playsound +distro +pypdf +weaviate-client + +# Testing dependencies +pytest +asynctest +pytest-asyncio +pytest-benchmark +pytest-cov +pytest-integration +pytest-mock +pytest-recording +pytesseract==0.3.10 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/sql_data_process.py b/src/sql_data_process.py new file mode 100644 index 0000000..b1c6a9e --- /dev/null +++ b/src/sql_data_process.py @@ -0,0 +1,981 @@ +import os +import sys +import json +import re +import random +import sqlite3 +import datasets +import traceback +import difflib +import functools +import re +import random +import numpy as np +from tqdm import tqdm +from copy import deepcopy +from rapidfuzz import fuzz +from os.path import isfile, isdir, join, split, exists, splitext +from typing import List, Optional, Tuple, Dict +from typing import Optional, List, Dict, Callable +from dataclasses import dataclass, field +from datasets.dataset_dict import DatasetDict +from datasets.arrow_dataset import Dataset +from transformers.training_args import TrainingArguments +from datasets.arrow_dataset import Dataset +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + + +def convert_fk_index(data): + fk_holder = [] + for fk in data["foreign_keys"]: + tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1] + ref_cid, cid = None, None + try: + tid = data["table_names_original"].index(tn) + ref_tid = data["table_names_original"].index(ref_tn) + + for i, (tab_id, col_org) in enumerate(data["column_names_original"]): + if tab_id == ref_tid and ref_col == col_org: + ref_cid = i + elif tid == tab_id and col == col_org: + cid = i + if ref_cid and cid: + fk_holder.append([cid, ref_cid]) + except: + traceback.print_exc() + print("table_names_original: ", data["table_names_original"]) + print("finding tab name: ", tn, ref_tn) + sys.exit() + return fk_holder + + +def dump_db_json_schema(db, f): + """read table and column info""" + + conn = sqlite3.connect(db) + conn.execute("pragma foreign_keys=ON") + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';") + + data = { + "db_id": f, + "table_names_original": [], + "table_names": [], + "column_names_original": [(-1, "*")], + "column_names": [(-1, "*")], + "column_types": ["text"], + 
"primary_keys": [], + "foreign_keys": [], + } + + fk_holder = [] + for i, item in enumerate(cursor.fetchall()): + table_name = item[0] + data["table_names_original"].append(table_name) + data["table_names"].append(table_name.lower().replace("_", " ")) + fks = conn.execute( + "PRAGMA foreign_key_list('{}') ".format(table_name) + ).fetchall() + # print("db:{} table:{} fks:{}".format(f,table_name,fks)) + fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks]) + cur = conn.execute("PRAGMA table_info('{}') ".format(table_name)) + for j, col in enumerate(cur.fetchall()): + data["column_names_original"].append((i, col[1])) + data["column_names"].append((i, col[1].lower().replace("_", " "))) + # varchar, '' -> text, int, numeric -> integer, + col_type = col[2].lower() + if ( + "char" in col_type + or col_type == "" + or "text" in col_type + or "var" in col_type + ): + data["column_types"].append("text") + elif ( + "int" in col_type + or "numeric" in col_type + or "decimal" in col_type + or "number" in col_type + or "id" in col_type + or "real" in col_type + or "double" in col_type + or "float" in col_type + ): + data["column_types"].append("number") + elif "date" in col_type or "time" in col_type or "year" in col_type: + data["column_types"].append("time") + elif "boolean" in col_type: + data["column_types"].append("boolean") + else: + data["column_types"].append("others") + + if col[5] == 1: + data["primary_keys"].append(len(data["column_names"]) - 1) + + data["foreign_keys"] = fk_holder + data["foreign_keys"] = convert_fk_index(data) + + return data + +logger = datasets.logging.get_logger(__name__) + +_CITATION = """\ +@article{yu2018spider, + title={Spider: A large-scale human-labeled dataset for complex and cross-domain semantic parsing and text-to-sql task}, + author={Yu, Tao and Zhang, Rui and Yang, Kai and Yasunaga, Michihiro and Wang, Dongxu and Li, Zifan and Ma, James and Li, Irene and Yao, Qingning and Roman, Shanelle and others}, + journal={arXiv preprint arXiv:1809.08887}, + year={2018} +} +""" + + +def generate_data(data_filepaths, db_path): + raw_data = [] + schema_cache = dict() + for data_filepath in data_filepaths: + with open(data_filepath, encoding="utf-8") as f: + spider = json.load(f) + for sample in spider: + db_id = sample["db_id"] + if db_id not in schema_cache: + schema_cache[db_id] = dump_db_json_schema(db=os.path.join(db_path, db_id, f"{db_id}.sqlite"), f=db_id) + schema = schema_cache[db_id] + raw_data.append({ + "query": sample["query"], + "question": sample["question"], + "db_id": db_id, + "db_path": db_path, + "db_table_names": schema["table_names_original"], + "db_column_names": [ + {"table_id": table_id, "column_name": column_name} + for table_id, column_name in schema["column_names_original"] + ], + "db_column_types": schema["column_types"], + "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]], + "db_foreign_keys": [ + {"column_id": column_id, "other_column_id": other_column_id} + for column_id, other_column_id in schema["foreign_keys"] + ], + }) + return raw_data + + + +""" + Wrap the raw dataset into the seq2seq one. 
+ And the raw dataset item is formatted as + { + "query": sample["query"], + "question": sample["question"], + "db_id": db_id, + "db_path": db_path, + "db_table_names": schema["table_names_original"], + "db_column_names": [ + {"table_id": table_id, "column_name": column_name} + for table_id, column_name in schema["column_names_original"] + ], + "db_column_types": schema["column_types"], + "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]], + "db_foreign_keys": [ + {"column_id": column_id, "other_column_id": other_column_id} + for column_id, other_column_id in schema["foreign_keys"] + ], + } + """ + + +# fmt: off +_stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again', + 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but', + "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is', + "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being', + 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be', + 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor', + 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because', + 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have', + 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours', + 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his', + "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what', + 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan', + 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here', + 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself', + 'then', 'did', 'just', "aren't"} +# fmt: on + +_commonwords = {"no", "yes", "many"} + + +def is_number(s: str) -> bool: + try: + float(s.replace(",", "")) + return True + except: + return False + + +def is_stopword(s: str) -> bool: + return s.strip() in _stopwords + + +def is_commonword(s: str) -> bool: + return s.strip() in _commonwords + + +def is_common_db_term(s: str) -> bool: + return s.strip() in ["id"] + + +class Match(object): + def __init__(self, start: int, size: int) -> None: + self.start = start + self.size = size + + +def is_span_separator(c: str) -> bool: + return c in "'\"()`,.?! 
" + + +def split(s: str) -> List[str]: + return [c.lower() for c in s.strip()] + + +def prefix_match(s1: str, s2: str) -> bool: + i, j = 0, 0 + for i in range(len(s1)): + if not is_span_separator(s1[i]): + break + for j in range(len(s2)): + if not is_span_separator(s2[j]): + break + if i < len(s1) and j < len(s2): + return s1[i] == s2[j] + elif i >= len(s1) and j >= len(s2): + return True + else: + return False + + +def get_effective_match_source(s: str, start: int, end: int) -> Match: + _start = -1 + + for i in range(start, start - 2, -1): + if i < 0: + _start = i + 1 + break + if is_span_separator(s[i]): + _start = i + break + + if _start < 0: + return None + + _end = -1 + for i in range(end - 1, end + 3): + if i >= len(s): + _end = i - 1 + break + if is_span_separator(s[i]): + _end = i + break + + if _end < 0: + return None + + while _start < len(s) and is_span_separator(s[_start]): + _start += 1 + while _end >= 0 and is_span_separator(s[_end]): + _end -= 1 + + return Match(_start, _end - _start + 1) + + +def get_matched_entries( + s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85 +) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]: + if not field_values: + return None + + if isinstance(s, str): + n_grams = split(s) + else: + n_grams = s + + matched = dict() + for field_value in field_values: + if not isinstance(field_value, str): + continue + fv_tokens = split(field_value) + sm = difflib.SequenceMatcher(None, n_grams, fv_tokens) + match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens)) + if match.size > 0: + source_match = get_effective_match_source( + n_grams, match.a, match.a + match.size + ) + if source_match and source_match.size > 1: + match_str = field_value[match.b : match.b + match.size] + source_match_str = s[ + source_match.start : source_match.start + source_match.size + ] + c_match_str = match_str.lower().strip() + c_source_match_str = source_match_str.lower().strip() + c_field_value = field_value.lower().strip() + if ( + c_match_str + and not is_number(c_match_str) + and not is_common_db_term(c_match_str) + ): + if ( + is_stopword(c_match_str) + or is_stopword(c_source_match_str) + or is_stopword(c_field_value) + ): + continue + if c_source_match_str.endswith(c_match_str + "'s"): + match_score = 1.0 + else: + if prefix_match(c_field_value, c_source_match_str): + match_score = ( + fuzz.ratio(c_field_value, c_source_match_str) / 100 + ) + else: + match_score = 0 + if ( + is_commonword(c_match_str) + or is_commonword(c_source_match_str) + or is_commonword(c_field_value) + ) and match_score < 1: + continue + s_match_score = match_score + if match_score >= m_theta and s_match_score >= s_theta: + if field_value.isupper() and match_score * s_match_score < 1: + continue + matched[match_str] = ( + field_value, + source_match_str, + match_score, + s_match_score, + match.size, + ) + + if not matched: + return None + else: + return sorted( + matched.items(), + key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]), + reverse=True, + ) + + +@functools.lru_cache(maxsize=1000, typed=False) +def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list: + fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name) + try: + conn = sqlite3.connect(db_path) + conn.text_factory = bytes + c = conn.cursor() + c.execute(fetch_sql) + picklist = set() + for x in c.fetchall(): + if isinstance(x[0], str): + picklist.add(x[0].encode("utf-8")) + elif isinstance(x[0], bytes): + try: + 
picklist.add(x[0].decode("utf-8")) + except UnicodeDecodeError: + picklist.add(x[0].decode("latin-1")) + else: + picklist.add(x[0]) + picklist = list(picklist) + finally: + conn.close() + return picklist + + +def get_database_matches( + question: str, + table_name: str, + column_name: str, + db_path: str, + top_k_matches: int = 2, + match_threshold: float = 0.85, +) -> List[str]: + picklist = get_column_picklist( + table_name=table_name, column_name=column_name, db_path=db_path + ) + matches = [] + if picklist and isinstance(picklist[0], str): + matched_entries = get_matched_entries( + s=question, + field_values=picklist, + m_theta=match_threshold, + s_theta=match_threshold, + ) + if matched_entries: + num_values_inserted = 0 + for _match_str, ( + field_value, + _s_match_str, + match_score, + s_match_score, + _match_size, + ) in matched_entries: + if "name" in column_name and match_score * s_match_score < 1: + continue + if table_name != "sqlite_sequence": # Spider database artifact + matches.append(field_value) + num_values_inserted += 1 + if num_values_inserted >= top_k_matches: + break + return matches + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + val_max_time: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum allowed time in seconds for generation of one example. This setting can be used to stop " + "generation whenever the full generation exceeds the specified amount of time." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation or test examples to this " + "value if set." + }, + ) + num_beams: int = field( + default=1, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + num_beam_groups: int = field( + default=1, + metadata={ + "help": "Number of beam groups to use for evaluation. 
This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + diversity_penalty: Optional[float] = field( + default=None, + metadata={ + "help": "Diversity penalty to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + num_return_sequences: Optional[int] = field( + default=None, + metadata={ + "help": "The number of sequences to generate during evaluation. This argument will be passed to " + "``model.generate``, which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, + metadata={"help": "A prefix to add before every source text (useful for T5 models)."}, + ) + schema_serialization_type: str = field( + default="peteshaw", + metadata={"help": "Choose between ``verbose`` and ``peteshaw`` schema serialization."}, + ) + schema_serialization_randomized: bool = field( + default=False, + metadata={"help": "Whether or not to randomize the order of tables."}, + ) + schema_serialization_with_db_id: bool = field( + default=True, + metadata={"help": "Whether or not to add the database id to the context. Needed for Picard."}, + ) + schema_serialization_with_db_content: bool = field( + default=True, + metadata={"help": "Whether or not to use the database content to resolve field matches."}, + ) + normalize_query: bool = field(default=True, metadata={"help": "Whether to normalize the SQL queries."}) + target_with_db_id: bool = field( + default=True, + metadata={"help": "Whether or not to add the database id to the target. Needed for Picard."}, + ) + + def __post_init__(self): + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +@dataclass +class DataArguments: + dataset: str = field( + metadata={"help": "The dataset to be used. Choose between ``spider``, ``cosql``, or ``cosql+spider``, or ``spider_realistic``, or ``spider_syn``, or ``spider_dk``."}, + ) + dataset_paths: Dict[str, str] = field( + default_factory=lambda: { + "spider": "./seq2seq/datasets/spider", + "cosql": "./seq2seq/datasets/cosql", + "spider_realistic": "./seq2seq/datasets/spider_realistic", + "spider_syn": "./seq2seq/datasets/spider_syn", + "spider_dk": "./seq2seq/datasets/spider_dk" + + }, + metadata={"help": "Paths of the dataset modules."}, + ) + metric_config: str = field( + default="both", + metadata={"help": "Choose between ``exact_match``, ``test_suite``, or ``both``."}, + ) + #we are referencing spider_realistic to spider metrics only as both use the main spider dataset as base. 
+ metric_paths: Dict[str, str] = field( + default_factory=lambda: { + "spider": "./seq2seq/metrics/spider", + "spider_realistic" : "./seq2seq/metrics/spider", + "cosql": "./seq2seq/metrics/cosql", + "spider_syn":"./seq2seq/metrics/spider", + "spider_dk":"./seq2seq/metrics/spider" + }, + metadata={"help": "Paths of the metric modules."}, + ) + test_suite_db_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the test-suite databases."}) + data_config_file : Optional[str] = field( + default=None, + metadata={"help": "Path to data configuration file (specifying the database splits)"} + ) + test_sections : Optional[List[str]] = field( + default=None, + metadata={"help": "Sections from the data config to use for testing"} + ) + + +@dataclass +class TrainSplit(object): + dataset: Dataset + schemas: Dict[str, dict] + + +@dataclass +class EvalSplit(object): + dataset: Dataset + examples: Dataset + schemas: Dict[str, dict] + + +@dataclass +class DatasetSplits(object): + train_split: Optional[TrainSplit] + eval_split: Optional[EvalSplit] + test_splits: Optional[Dict[str, EvalSplit]] + schemas: Dict[str, dict] + + +def _get_schemas(examples: Dataset) -> Dict[str, dict]: + schemas: Dict[str, dict] = dict() + for ex in examples: + if ex["db_id"] not in schemas: + schemas[ex["db_id"]] = { + "db_table_names": ex["db_table_names"], + "db_column_names": ex["db_column_names"], + "db_column_types": ex["db_column_types"], + "db_primary_keys": ex["db_primary_keys"], + "db_foreign_keys": ex["db_foreign_keys"], + } + return schemas + + +def _prepare_train_split( + dataset: Dataset, + data_training_args: DataTrainingArguments, + add_serialized_schema: Callable[[dict], dict], + pre_process_function: Callable[[dict, Optional[int], Optional[int]], dict], +) -> TrainSplit: + schemas = _get_schemas(examples=dataset) + dataset = dataset.map( + add_serialized_schema, + batched=False, + num_proc=data_training_args.preprocessing_num_workers, + load_from_cache_file=not data_training_args.overwrite_cache, + ) + if data_training_args.max_train_samples is not None: + dataset = dataset.select(range(data_training_args.max_train_samples)) + column_names = dataset.column_names + dataset = dataset.map( + lambda batch: pre_process_function( + batch=batch, + max_source_length=data_training_args.max_source_length, + max_target_length=data_training_args.max_target_length, + ), + batched=True, + num_proc=data_training_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_training_args.overwrite_cache, + ) + return TrainSplit(dataset=dataset, schemas=schemas) + + +def _prepare_eval_split( + dataset: Dataset, + data_training_args: DataTrainingArguments, + add_serialized_schema: Callable[[dict], dict], + pre_process_function: Callable[[dict, Optional[int], Optional[int]], dict], +) -> EvalSplit: + if (data_training_args.max_val_samples is not None + and data_training_args.max_val_samples < len(dataset)): + eval_examples = dataset.select(range(data_training_args.max_val_samples)) + else: + eval_examples = dataset + schemas = _get_schemas(examples=eval_examples) + eval_dataset = eval_examples.map( + add_serialized_schema, + batched=False, + num_proc=data_training_args.preprocessing_num_workers, + load_from_cache_file=not data_training_args.overwrite_cache, + ) + column_names = eval_dataset.column_names + eval_dataset = eval_dataset.map( + lambda batch: pre_process_function( + batch=batch, + max_source_length=data_training_args.max_source_length, + 
max_target_length=data_training_args.val_max_target_length, + ), + batched=True, + num_proc=data_training_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_training_args.overwrite_cache, + ) + return EvalSplit(dataset=eval_dataset, examples=eval_examples, schemas=schemas) + + +def prepare_splits( + dataset_dict: DatasetDict, + data_args: DataArguments, + training_args: TrainingArguments, + data_training_args: DataTrainingArguments, + add_serialized_schema: Callable[[dict], dict], + pre_process_function: Callable[[dict, Optional[int], Optional[int]], dict], +) -> DatasetSplits: + train_split, eval_split, test_splits = None, None, None + + if training_args.do_train: + train_split = _prepare_train_split( + dataset_dict["train"], + data_training_args=data_training_args, + add_serialized_schema=add_serialized_schema, + pre_process_function=pre_process_function, + ) + + if training_args.do_eval: + eval_split = _prepare_eval_split( + dataset_dict["validation"], + data_training_args=data_training_args, + add_serialized_schema=add_serialized_schema, + pre_process_function=pre_process_function, + ) + + if training_args.do_predict: + test_splits = { + section: _prepare_eval_split( + dataset_dict[section], + data_training_args=data_training_args, + add_serialized_schema=add_serialized_schema, + pre_process_function=pre_process_function, + ) + for section in data_args.test_sections + } + test_split_schemas = {} + for split in test_splits.values(): + test_split_schemas.update(split.schemas) + + schemas = { + **(train_split.schemas if train_split is not None else {}), + **(eval_split.schemas if eval_split is not None else {}), + **(test_split_schemas if test_splits is not None else {}), + } + + return DatasetSplits( + train_split=train_split, + eval_split=eval_split, + test_splits=test_splits, + schemas=schemas + ) + + +def normalize(query: str) -> str: + def comma_fix(s): + # Remove spaces in front of commas + return s.replace(" , ", ", ") + + def white_space_fix(s): + # Remove double and triple spaces + return " ".join(s.split()) + + def lower(s): + # Convert everything except text between (single or double) quotation marks to lower case + return re.sub(r"\b(? str: + if schema_serialization_type == "verbose": + db_id_str = "Database: {db_id}. " + table_sep = ". " + table_str = "Table: {table}. 
Columns: {columns}" + column_sep = ", " + column_str_with_values = "{column} ({values})" + column_str_without_values = "{column}" + value_sep = ", " + elif schema_serialization_type == "peteshaw": + # see https://github.com/google-research/language/blob/master/language/nqg/tasks/spider/append_schema.py#L42 + db_id_str = " | {db_id}" + table_sep = "" + table_str = " | {table} : {columns}" + column_sep = " , " + column_str_with_values = "{column} ( {values} )" + column_str_without_values = "{column}" + value_sep = " , " + else: + raise NotImplementedError + + def get_column_str(table_name: str, column_name: str) -> str: + column_name_str = column_name.lower() if normalize_query else column_name + if schema_serialization_with_db_content: + matches = get_database_matches( + question=question, + table_name=table_name, + column_name=column_name, + db_path=(db_path + "/" + db_id + "/" + db_id + ".sqlite"), + ) + if matches: + return column_str_with_values.format(column=column_name_str, values=value_sep.join(matches)) + else: + return column_str_without_values.format(column=column_name_str) + else: + return column_str_without_values.format(column=column_name_str) + + tables = [ + table_str.format( + table=table_name.lower() if normalize_query else table_name, + columns=column_sep.join( + map( + lambda y: get_column_str(table_name=table_name, column_name=y[1]), + filter( + lambda y: y[0] == table_id, + zip( + db_column_names["table_id"], + db_column_names["column_name"], + ), + ), + ) + ), + ) + for table_id, table_name in enumerate(db_table_names) + ] + if schema_serialization_randomized: + random.shuffle(tables) + if schema_serialization_with_db_id: + serialized_schema = db_id_str.format(db_id=db_id) + table_sep.join(tables) + else: + serialized_schema = table_sep.join(tables) + return serialized_schema + +def serialize_schema_natural_language( + question: str, + db_path: str, + db_id: str, + db_column_names: Dict[str, str], + db_table_names: List[str], + db_primary_keys, + db_foreign_keys, + schema_serialization_with_db_content: bool = False, + normalize_query: bool = True, +) -> str: + overall_description = f'{db_id} contains tables such as ' \ + f'{", ".join([table_name.lower() if normalize_query else table_name for table_name in db_table_names])}.' + table_description_primary_key_template = lambda table_name, primary_key: \ + f'{primary_key} is the primary key.' + table_description = lambda table_name, column_names: \ + f'Table {table_name} has columns such as {", ".join(column_names)}.' + value_description = lambda column_value_pairs: \ + f'{"".join(["The {} contains values such as {}.".format(column, value) for column, value in column_value_pairs])}' + foreign_key_description = lambda table_1, column_1, table_2, column_2: \ + f'The {column_1} of {table_1} is the foreign key of {column_2} of {table_2}.' 
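+ # Rough shape of the text these templates assemble, for a hypothetical schema
+ # (db_id "shop" with tables "department" and "employee"; names are illustrative only):
+ #   shop contains tables such as department, employee.
+ #   Table department has columns such as department_id, name. department_id is the primary key.
+ #   Table employee has columns such as employee_id, name, department_id. employee_id is the primary key.
+ #   The department_id of employee is the foreign key of department_id of department.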
+ + db_primary_keys = [x["column_id"] for x in db_primary_keys] + db_foreign_keys = list(zip([x["column_id"] for x in db_foreign_keys], [x["other_column_id"] for x in db_foreign_keys])) + + + descriptions = [overall_description] + db_table_name_strs = [] + db_column_name_strs = [] + value_sep = ", " + for table_id, table_name in enumerate(db_table_names): + table_name_str = table_name.lower() if normalize_query else table_name + db_table_name_strs.append(table_name_str) + columns = [] + column_value_pairs = [] + primary_keys = [] + for column_id, (x, y) in enumerate(zip([x["table_id"] for x in db_column_names], [x["column_name"] for x in db_column_names])): + if column_id == 0: + continue + column_str = y.lower() if normalize_query else y + db_column_name_strs.append(column_str) + if x == table_id: + columns.append(column_str) + if column_id in db_primary_keys: + primary_keys.append(column_str) + if schema_serialization_with_db_content: + matches = get_database_matches( + question=question, + table_name=table_name, + column_name=y, + db_path=(db_path + "/" + db_id + "/" + db_id + ".sqlite"), + ) + if matches: + column_value_pairs.append((column_str, value_sep.join(matches))) + + table_description_columns_str = table_description(table_name_str, columns) + descriptions.append(table_description_columns_str) + table_description_primary_key_str = table_description_primary_key_template(table_name_str, ", ".join(primary_keys)) + descriptions.append(table_description_primary_key_str) + if len(column_value_pairs) > 0: + value_description_str = value_description(column_value_pairs) + descriptions.append(value_description_str) + + + for x, y in db_foreign_keys: + # get the table and column of x + db_column_names_table_id =[x["table_id"] for x in db_column_names] + x_table_name = db_table_name_strs[db_column_names_table_id[x]] + x_column_name = db_column_name_strs[x] + # get the table and column of y + y_table_name = db_table_name_strs[db_column_names_table_id[y]] + y_column_name = db_column_name_strs[y] + foreign_key_description_str = foreign_key_description(x_table_name, x_column_name, y_table_name, y_column_name) + descriptions.append(foreign_key_description_str) + return " ".join(descriptions) + +def spider_get_input( + question: str, + serialized_schema: str, + prefix: str, +) -> str: + return prefix + question.strip() + " " + serialized_schema.strip() + + +def spider_get_target( + query: str, + db_id: str, + normalize_query: bool, + target_with_db_id: bool, +) -> str: + _normalize = normalize if normalize_query else (lambda x: x) + return f"{db_id} | {_normalize(query)}" if target_with_db_id else _normalize(query) + +def spider_add_serialized_schema(ex: dict) -> dict: + + serialized_schema = serialize_schema_natural_language( + question=ex["question"], + db_path=ex["db_path"], + db_id=ex["db_id"], + db_column_names=ex["db_column_names"], + db_table_names=ex["db_table_names"], + db_primary_keys=ex["db_primary_keys"], + db_foreign_keys=ex["db_foreign_keys"], + schema_serialization_with_db_content=True, + normalize_query=True, + ) + + return {"serialized_schema": serialized_schema} + +def spider_pre_process_one_function(item: dict): + prefix = "" + + seq_out = spider_get_target( + query=item["query"], + db_id=item["db_id"], + normalize_query=True, + target_with_db_id=True, + ) + + return prefix + item["question"].strip(), seq_out + + +if __name__ == "__main__" : + + data_filepaths = ["data/spider/train_spider.json","data/spider/train_others.json"] + raw_datasets = 
generate_data(data_filepaths,"data/spider/database") + sql_fintune_data = [] + fields = ["instruction","input","response"] + for raw_data in tqdm(raw_datasets): + extend_data = deepcopy(raw_data) + extend_data.update(spider_add_serialized_schema(extend_data)) + question, seq_out = spider_pre_process_one_function(extend_data) + extend_data.update({"instruction": extend_data["serialized_schema"].strip(), + "input": question, + "output": seq_out}) + extended_data = {} + extended_data.update({key: value for key, value in extend_data.items() if key in fields}) + sql_fintune_data.append(extended_data) + sql_fintune_data + with open('sql_fintune_data.json', 'w') as f: + json.dump(sql_fintune_data, f) + print("The raw datasets has been generated") \ No newline at end of file diff --git a/src/train/train_qlora.py b/src/train/train_qlora.py new file mode 100644 index 0000000..a25e52e --- /dev/null +++ b/src/train/train_qlora.py @@ -0,0 +1,838 @@ +from collections import defaultdict +import copy +import json +import os +from os.path import exists, join, isdir +from dataclasses import dataclass, field +import sys +from typing import Optional, Dict, Sequence +import numpy as np +from tqdm import tqdm +import logging +import bitsandbytes as bnb +import pandas as pd + +import torch +import transformers +from torch.nn.utils.rnn import pad_sequence +import argparse +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + set_seed, + Seq2SeqTrainer, + BitsAndBytesConfig, + LlamaTokenizer + +) +from datasets import load_dataset, Dataset +import evaluate + +from peft import ( + prepare_model_for_kbit_training, + LoraConfig, + get_peft_model, + PeftModel +) +from peft.tuners.lora import LoraLayer +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) +from pilot.configs.config import Config +from pilot.configs.model_config import LLM_MODEL_CONFIG + +torch.backends.cuda.matmul.allow_tf32 = True + +logger = logging.getLogger(__name__) + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" + +CFG = Config() +model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field( + default=model_path + ) + trust_remote_code: Optional[bool] = field( + default=True, + metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} + ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={"help": "Enables using Huggingface auth token from Git Credentials."} + ) + +@dataclass +class DataArguments: + eval_dataset_size: int = field( + default=1024, metadata={"help": "Size of validation dataset."} + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + source_max_len: int = field( + default=1024, + metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + target_max_len: int = field( + default=256, + metadata={"help": "Maximum target sequence length. 
Sequences will be right padded (and possibly truncated)."}, + ) + dataset: str = field( + default='spider', + metadata={"help": "Which dataset to finetune on. See datamodule for options."} + ) + dataset_format: Optional[str] = field( + default=None, + metadata={"help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]"} + ) + +@dataclass +class TrainingArguments(transformers.Seq2SeqTrainingArguments): + cache_dir: Optional[str] = field( + default=None + ) + train_on_source: Optional[bool] = field( + default=False, + metadata={"help": "Whether to train on the input in addition to the target text."} + ) + mmlu_split: Optional[str] = field( + default='eval', + metadata={"help": "The MMLU split to run on"} + ) + mmlu_dataset: Optional[str] = field( + default='mmlu-fs', + metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."} + ) + do_mmlu_eval: Optional[bool] = field( + default=False, + metadata={"help": "Whether to run the MMLU evaluation."} + ) + max_mmlu_samples: Optional[int] = field( + default=None, + metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."} + ) + mmlu_source_max_len: int = field( + default=2048, + metadata={"help": "Maximum source sequence length for mmlu."} + ) + full_finetune: bool = field( + default=False, + metadata={"help": "Finetune the entire model without adapters."} + ) + adam8bit: bool = field( + default=False, + metadata={"help": "Use 8-bit adam."} + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=4, + metadata={"help": "How many bits to use."} + ) + lora_r: int = field( + default=64, + metadata={"help": "Lora R dimension."} + ) + lora_alpha: float = field( + default=16, + metadata={"help": " Lora alpha."} + ) + lora_dropout: float = field( + default=0.0, + metadata={"help":"Lora dropout."} + ) + max_memory_MB: int = field( + default=80000, + metadata={"help": "Free memory per gpu."} + ) + report_to: str = field( + default='none', + metadata={"help": "To use wandb or something else for reporting."} + ) + output_dir: str = field(default='./train/output', metadata={"help": 'The output dir for logs and checkpoints'}) + optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'}) + per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'}) + gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'}) + max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'}) + weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) # use lora dropout instead for regularization if needed + learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate'}) + remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'}) + max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'}) + gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. 
You want to use this.'}) + do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'}) + lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'}) + warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'}) + logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'}) + group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'}) + save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'}) + save_steps: int = field(default=250, metadata={"help": 'How often to save a model'}) + save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'}) + +@dataclass +class GenerationArguments: + # For more hyperparameters check: + # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + # Length arguments + max_new_tokens: Optional[int] = field( + default=256, + metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops" + "if predict_with_generate is set."} + ) + min_new_tokens : Optional[int] = field( + default=None, + metadata={"help": "Minimum number of new tokens to generate."} + ) + + # Generation strategy + do_sample: Optional[bool] = field(default=False) + num_beams: Optional[int] = field(default=1) + num_beam_groups: Optional[int] = field(default=1) + penalty_alpha: Optional[float] = field(default=None) + use_cache: Optional[bool] = field(default=True) + + # Hyperparameters for logit manipulation + temperature: Optional[float] = field(default=1.0) + top_k: Optional[int] = field(default=50) + top_p: Optional[float] = field(default=1.0) + typical_p: Optional[float] = field(default=1.0) + diversity_penalty: Optional[float] = field(default=0.0) + repetition_penalty: Optional[float] = field(default=1.0) + length_penalty: Optional[float] = field(default=1.0) + no_repeat_ngram_size: Optional[int] = field(default=0) + +def find_all_linear_names(args, model): + cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear) + lora_module_names = set() + for name, module in model.named_modules(): + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +class SavePeftModelCallback(transformers.TrainerCallback): + def save_model(self, args, state, kwargs): + print('Saving PEFT checkpoint...') + if state.best_model_checkpoint is not None: + checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model") + else: + checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") + + peft_model_path = os.path.join(checkpoint_folder, "adapter_model") + kwargs["model"].save_pretrained(peft_model_path) + + pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin") + if os.path.exists(pytorch_model_path): + os.remove(pytorch_model_path) + + def on_save(self, args, state, control, **kwargs): + self.save_model(args, state, kwargs) + return control + + def on_train_end(self, args, 
+        def touch(fname, times=None):
+            with open(fname, 'a'):
+                os.utime(fname, times)
+
+        touch(join(args.output_dir, 'completed'))
+        self.save_model(args, state, kwargs)
+
+def get_accelerate_model(args, checkpoint_dir):
+
+    n_gpus = torch.cuda.device_count()
+    max_memory = f'{args.max_memory_MB}MB'
+    max_memory = {i: max_memory for i in range(n_gpus)}
+    device_map = "auto"
+
+    # if we are in a distributed setting, we need to set the device map and max memory per device
+    if os.environ.get('LOCAL_RANK') is not None:
+        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
+        device_map = {'': local_rank}
+        max_memory = {'': max_memory[local_rank]}
+
+    if args.full_finetune: assert args.bits in [16, 32]
+
+    print(f'loading base model {args.model_name_or_path}...')
+    compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name_or_path,
+        cache_dir=args.cache_dir,
+        load_in_4bit=args.bits == 4,
+        load_in_8bit=args.bits == 8,
+        device_map=device_map,
+        max_memory=max_memory,
+        quantization_config=BitsAndBytesConfig(
+            load_in_4bit=args.bits == 4,
+            load_in_8bit=args.bits == 8,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=compute_dtype,
+            bnb_4bit_use_double_quant=args.double_quant,
+            bnb_4bit_quant_type=args.quant_type,
+        ),
+        torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
+        trust_remote_code=args.trust_remote_code,
+        use_auth_token=args.use_auth_token
+    )
+    if compute_dtype == torch.float16 and args.bits == 4:
+        major, minor = torch.cuda.get_device_capability()
+        if major >= 8:
+            print('='*80)
+            print('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
+            print('='*80)
+
+    setattr(model, 'model_parallel', True)
+    setattr(model, 'is_parallelizable', True)
+
+    model.config.torch_dtype = (torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
+
+    if not args.full_finetune:
+        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
+    if args.gradient_checkpointing:
+        model.gradient_checkpointing_enable()
+
+    if not args.full_finetune:
+        if checkpoint_dir is not None:
+            print("Loading adapters from checkpoint.")
+            model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model'), is_trainable=True)
+        else:
+            print(f'adding LoRA modules...')
+            modules = find_all_linear_names(args, model)
+            config = LoraConfig(
+                r=args.lora_r,
+                lora_alpha=args.lora_alpha,
+                target_modules=modules,
+                lora_dropout=args.lora_dropout,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+            model = get_peft_model(model, config)
+
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            if args.bf16:
+                module = module.to(torch.bfloat16)
+        if 'norm' in name:
+            module = module.to(torch.float32)
+        if 'lm_head' in name or 'embed_tokens' in name:
+            if hasattr(module, 'weight'):
+                if args.bf16 and module.weight.dtype == torch.float32:
+                    module = module.to(torch.bfloat16)
+    return model
+
+def print_trainable_parameters(args, model):
+    """
+    Prints the number of trainable parameters in the model.
+    """
+    trainable_params = 0
+    all_param = 0
+    for _, param in model.named_parameters():
+        all_param += param.numel()
+        if param.requires_grad:
+            trainable_params += param.numel()
+    if args.bits == 4: trainable_params /= 2
+    print(
+        f"trainable params: {trainable_params} || "
+        f"all params: {all_param} || "
+        f"trainable: {100 * trainable_params / all_param}"
+    )
+
+def smart_tokenizer_and_embedding_resize(
+    special_tokens_dict: Dict,
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+):
+    """Resize tokenizer and embedding.
+
+    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+    """
+    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+    model.resize_token_embeddings(len(tokenizer))
+
+    if num_new_tokens > 0:
+        input_embeddings = model.get_input_embeddings().weight.data
+        output_embeddings = model.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+@dataclass
+class DataCollatorForCausalLM(object):
+    tokenizer: transformers.PreTrainedTokenizer
+    source_max_len: int
+    target_max_len: int
+    train_on_source: bool
+    predict_with_generate: bool
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        # Extract elements
+        sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
+        targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
+        # Tokenize
+        tokenized_sources_with_prompt = self.tokenizer(
+            sources,
+            max_length=self.source_max_len,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        tokenized_targets = self.tokenizer(
+            targets,
+            max_length=self.target_max_len,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        # Build the input and labels for causal LM
+        input_ids = []
+        labels = []
+        for tokenized_source, tokenized_target in zip(
+            tokenized_sources_with_prompt['input_ids'],
+            tokenized_targets['input_ids']
+        ):
+            if not self.predict_with_generate:
+                input_ids.append(torch.tensor(tokenized_source + tokenized_target))
+                if not self.train_on_source:
+                    labels.append(
+                        torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
+                    )
+                else:
+                    labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
+            else:
+                input_ids.append(torch.tensor(tokenized_source))
+        # Apply padding
+        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
+        data_dict = {
+            'input_ids': input_ids,
+            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
+        }
+        if labels is not None:
+            data_dict['labels'] = labels
+        return data_dict
+
+def extract_unnatural_instructions_data(examples, extract_reformulations=False):
+    out = {
+        'input': [],
+        'output': [],
+    }
+    for example_instances in examples['instances']:
+        for instance in example_instances:
+            out['input'].append(instance['instruction_with_input'])
+            out['output'].append(instance['output'])
+    if extract_reformulations:
+        for example_reformulations in examples['reformulations']:
+            if example_reformulations is not None:
+                for instance in example_reformulations:
+                    out['input'].append(instance['instruction_with_input'])
+                    out['output'].append(instance['output'])
+    return out
+
+ALPACA_PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response: "
+    ),
+}
+
+SQL_PROMPT_DICT = {
+    "prompt_input": (
+        "I want you to act as a SQL terminal in front of an example database. "
+        "Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"
+        "###Instruction:\n{instruction}\n\n###Input:\n{input}\n\n###Response: "
+    ),
+    "prompt_no_input": (
+        "I want you to act as a SQL terminal in front of an example database. "
+        "Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"
+        "###Instruction:\n{instruction}\n\n### Response: "
+    ),
+}
+
+
+def extract_alpaca_dataset(example):
+    if example.get("input", "") != "":
+        prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
+    return {'input': prompt_format.format(**example)}
+
+def extract_sql_dataset(example):
+    if example.get("input", "") != "":
+        prompt_format = SQL_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = SQL_PROMPT_DICT["prompt_no_input"]
+    return {'input': prompt_format.format(**example)}
+
+def local_dataset(dataset_name):
+    if dataset_name.endswith('.json'):
+        full_dataset = Dataset.from_json(path_or_paths=dataset_name)
+    elif dataset_name.endswith('.jsonl'):
+        full_dataset = Dataset.from_json(path_or_paths=dataset_name)
+    elif dataset_name.endswith('.csv'):
+        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name))
+    elif dataset_name.endswith('.tsv'):
+        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter='\t'))
+    else:
+        raise ValueError(f"Unsupported dataset format: {dataset_name}")
+
+    split_dataset = full_dataset.train_test_split(test_size=0.1)
+    return split_dataset
+
+def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
+    """
+    Make dataset and collator for supervised fine-tuning.
+    Datasets are expected to have the following columns: { `input`, `output` }
+
+    Available datasets to be selected with `dataset` argument:
+        - alpaca, 52002 examples
+        - alpaca cleaned, 51942 examples
+        - chip2 (OIG), 210289 examples
+        - self-instruct, 82612 examples
+        - hh-rlhf (Anthropic), 160800 examples
+        - longform, 23.7k examples
+        - oasst1 (OpenAssistant) primary message tree only, 9,846 examples
+
+    Coming soon:
+        - unnatural instructions core, 66010 examples
+        - unnatural instructions full, 240670 examples
+        - alpaca-gpt4, 52002 examples
+        - unnatural-instructions-gpt4, 9000 examples
+        - supernatural-instructions, 69624 examples (same as paper with 100 examples per task; more can be used)
+        - flan (FLAN v2), up to 20M examples available
+        - vicuna
+    """
+    def load_data(dataset_name):
+        if dataset_name == 'alpaca':
+            return load_dataset("tatsu-lab/alpaca")
+        elif dataset_name == 'alpaca-clean':
+            return load_dataset("yahma/alpaca-cleaned")
+        elif dataset_name == 'chip2':
+            return load_dataset("laion/OIG", data_files='unified_chip2.jsonl')
+        elif dataset_name == 'self-instruct':
+            return load_dataset("yizhongw/self_instruct", name='self_instruct')
+        elif dataset_name == 'hh-rlhf':
+            return load_dataset("Anthropic/hh-rlhf")
+        elif dataset_name == 'longform':
+            return load_dataset("akoksal/LongForm")
+        elif dataset_name == 'oasst1':
+            return load_dataset("timdettmers/openassistant-guanaco")
+        elif dataset_name == 'vicuna':
+            raise NotImplementedError("Vicuna data was not released.")
+        elif dataset_name == 'spider':
+            return load_dataset("json", data_files="train/sql_fintune_data.json")
+        else:
+            if os.path.exists(dataset_name):
+                try:
+                    args.dataset_format = args.dataset_format if args.dataset_format else "input-output"
+                    full_dataset = local_dataset(dataset_name)
+                    return full_dataset
+                except Exception as e:
+                    raise ValueError(f"Error loading dataset from {dataset_name}") from e
+            else:
+                raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.")
+
+    def format_dataset(dataset, dataset_format):
+        if (
+            dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or
+            (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean'])
+        ):
+            dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
+        elif dataset_format == 'spider':
+            dataset = dataset.map(extract_sql_dataset, remove_columns=['instruction'])
+        elif dataset_format == 'chip2' or (dataset_format is None and args.dataset == 'chip2'):
+            dataset = dataset.map(lambda x: {
+                'input': x['text'].split('\n<bot>: ')[0].replace('<human>: ', ''),
+                'output': x['text'].split('\n<bot>: ')[1],
+            })
+        elif dataset_format == 'self-instruct' or (dataset_format is None and args.dataset == 'self-instruct'):
+            for old, new in [["prompt", "input"], ["completion", "output"]]:
+                dataset = dataset.rename_column(old, new)
+        elif dataset_format == 'hh-rlhf' or (dataset_format is None and args.dataset == 'hh-rlhf'):
+            dataset = dataset.map(lambda x: {
+                'input': '',
+                'output': x['chosen']
+            })
+        elif dataset_format == 'oasst1' or (dataset_format is None and args.dataset == 'oasst1'):
+            dataset = dataset.map(lambda x: {
+                'input': '',
+                'output': x['text'],
+            })
+        elif dataset_format == 'input-output':
+            # leave as is
+            pass
+        # Remove unused columns.
+        dataset = dataset.remove_columns(
+            [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
+        )
+        return dataset
+
+    # Load dataset.
+    dataset = load_data(args.dataset)
+    dataset = format_dataset(dataset, args.dataset_format)
+
+    # Split train/eval, reduce size
+    if args.do_eval or args.do_predict:
+        if 'eval' in dataset:
+            eval_dataset = dataset['eval']
+        else:
+            print('Splitting train dataset into train and validation according to `eval_dataset_size`')
+            dataset = dataset["train"].train_test_split(
+                test_size=args.eval_dataset_size, shuffle=True, seed=42
+            )
+            eval_dataset = dataset['test']
+        if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples:
+            eval_dataset = eval_dataset.select(range(args.max_eval_samples))
+        if args.group_by_length:
+            eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
+    if args.do_train:
+        train_dataset = dataset['train']
+        if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
+            train_dataset = train_dataset.select(range(args.max_train_samples))
+        if args.group_by_length:
+            train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
+
+    data_collator = DataCollatorForCausalLM(
+        tokenizer=tokenizer,
+        source_max_len=args.source_max_len,
+        target_max_len=args.target_max_len,
+        train_on_source=args.train_on_source,
+        predict_with_generate=args.predict_with_generate,
+    )
+    return dict(
+        train_dataset=train_dataset if args.do_train else None,
+        eval_dataset=eval_dataset if args.do_eval else None,
+        predict_dataset=eval_dataset if args.do_predict else None,
+        data_collator=data_collator
+    )
+
+def get_last_checkpoint(checkpoint_dir):
+    if isdir(checkpoint_dir):
+        is_completed = exists(join(checkpoint_dir, 'completed'))
+        if is_completed: return None, True  # already finished
+        max_step = 0
+        for filename in os.listdir(checkpoint_dir):
+            if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'):
+                max_step = max(max_step, int(filename.replace('checkpoint-', '')))
+        if max_step == 0: return None, is_completed  # training started, but no checkpoint
+        checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
+        print(f"Found a previous checkpoint at: {checkpoint_dir}")
+        return checkpoint_dir, is_completed  # checkpoint found!
+    return None, False  # first training
+
+
+def train():
+    hfparser = transformers.HfArgumentParser((
+        ModelArguments, DataArguments, TrainingArguments, GenerationArguments
+    ))
+    model_args, data_args, training_args, generation_args, extra_args = \
+        hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
+    training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
+    args = argparse.Namespace(
+        **vars(model_args), **vars(data_args), **vars(training_args)
+    )
+
+    checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
+    if completed_training:
+        print('Detected that training was already completed!')
+
+    model = get_accelerate_model(args, checkpoint_dir)
+
+    model.config.use_cache = False
+    print_trainable_parameters(args, model)
+    print('loaded model')
+    set_seed(args.seed)
+
+    # Tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_name_or_path,
+        cache_dir=args.cache_dir,
+        padding_side="right",
+        use_fast=False,  # Fast tokenizer giving issues.
+        # tokenizer_type='llama' if 'llama' in args.model_name_or_path else None,  # Needed for HF name change
+        use_auth_token=args.use_auth_token,
+    )
+    if tokenizer._pad_token is None:
+        smart_tokenizer_and_embedding_resize(
+            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+            tokenizer=tokenizer,
+            model=model,
+        )
+    if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
+        # LLaMA tokenizer may not have correct special tokens set.
+        # Check and add them if missing to prevent them from being parsed into different tokens.
+        # Note that these are present in the vocabulary.
+        # Note also that `model.config.pad_token_id` is 0 which corresponds to the `<unk>` token.
+        print('Adding special tokens.')
+        tokenizer.add_special_tokens({
+            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
+            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
+            "unk_token": tokenizer.convert_ids_to_tokens(
+                model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
+            ),
+        })
+    data_module = make_data_module(tokenizer=tokenizer, args=args)
+    trainer = Seq2SeqTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        **{k: v for k, v in data_module.items() if k != 'predict_dataset'},
+    )
+
+    # Callbacks
+    if not args.full_finetune:
+        trainer.add_callback(SavePeftModelCallback)
+    if args.do_mmlu_eval:
+        if args.mmlu_dataset == 'mmlu-zs':
+            mmlu_dataset = load_dataset("json", data_files={
+                'eval': 'data/mmlu/zero_shot_mmlu_val.json',
+                'test': 'data/mmlu/zero_shot_mmlu_test.json',
+            })
+            mmlu_dataset = mmlu_dataset.remove_columns('subject')
+        # MMLU Five-shot (Eval/Test only)
+        elif args.mmlu_dataset == 'mmlu' or args.mmlu_dataset == 'mmlu-fs':
+            mmlu_dataset = load_dataset("json", data_files={
+                'eval': 'data/mmlu/five_shot_mmlu_val.json',
+                'test': 'data/mmlu/five_shot_mmlu_test.json',
+            })
+            # mmlu_dataset = mmlu_dataset.remove_columns('subject')
+        mmlu_dataset = mmlu_dataset[args.mmlu_split]
+        if args.max_mmlu_samples is not None:
+            mmlu_dataset = mmlu_dataset.select(range(args.max_mmlu_samples))
+        abcd_idx = [
+            tokenizer("A", add_special_tokens=False).input_ids[0],
+            tokenizer("B", add_special_tokens=False).input_ids[0],
+            tokenizer("C", add_special_tokens=False).input_ids[0],
+            tokenizer("D", add_special_tokens=False).input_ids[0],
+        ]
+        accuracy = evaluate.load("accuracy")
+        class MMLUEvalCallback(transformers.TrainerCallback):
+            def on_evaluate(self, args, state, control, model, **kwargs):
+                data_loader = trainer.get_eval_dataloader(mmlu_dataset)
+                source_max_len = trainer.data_collator.source_max_len
+                trainer.data_collator.source_max_len = args.mmlu_source_max_len
+                trainer.model.eval()
+                preds, refs = [], []
+                loss_mmlu = 0
+                for batch in tqdm(data_loader, total=len(data_loader)):
+                    (loss, logits, labels) = trainer.prediction_step(trainer.model, batch, prediction_loss_only=False)
+                    # There are two tokens, the output, and eos token.
+                    for i, logit in enumerate(logits):
+                        label_non_zero_id = (batch['labels'][i] != -100).nonzero()[0][0]
+                        logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
+                        preds.append(torch.argmax(logit_abcd).item())
+                    labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
+                    refs += [abcd_idx.index(label) for label in labels.tolist()]
+                    loss_mmlu += loss.item()
+                # Extract results by subject.
+                results = {'mmlu_loss': loss_mmlu / len(data_loader)}
+                subject = mmlu_dataset['subject']
+                subjects = {s: {'refs': [], 'preds': []} for s in set(subject)}
+                for s, p, r in zip(subject, preds, refs):
+                    subjects[s]['preds'].append(p)
+                    subjects[s]['refs'].append(r)
+                subject_scores = []
+                for subject in subjects:
+                    subject_score = accuracy.compute(
+                        references=subjects[subject]['refs'],
+                        predictions=subjects[subject]['preds']
+                    )['accuracy']
+                    results[f'mmlu_{args.mmlu_split}_accuracy_{subject}'] = subject_score
+                    subject_scores.append(subject_score)
+                results[f'mmlu_{args.mmlu_split}_accuracy'] = np.mean(subject_scores)
+                trainer.log(results)
+                trainer.data_collator.source_max_len = source_max_len
+
+        trainer.add_callback(MMLUEvalCallback)
+
+    # Verifying the datatypes.
+    dtypes = {}
+    for _, p in model.named_parameters():
+        dtype = p.dtype
+        if dtype not in dtypes: dtypes[dtype] = 0
+        dtypes[dtype] += p.numel()
+    total = 0
+    for k, v in dtypes.items(): total += v
+    for k, v in dtypes.items():
+        print(k, v, v / total)
+
+    all_metrics = {"run_name": args.run_name}
+    # Training
+    if args.do_train:
+        logger.info("*** Train ***")
+        # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF.
+        # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
+        train_result = trainer.train()
+        metrics = train_result.metrics
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
+        all_metrics.update(metrics)
+    # Evaluation
+    if args.do_eval:
+        logger.info("*** Evaluate ***")
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+        all_metrics.update(metrics)
+    # Prediction
+    if args.do_predict:
+        logger.info("*** Predict ***")
+        prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'], metric_key_prefix="predict")
+        prediction_metrics = prediction_output.metrics
+        predictions = prediction_output.predictions
+        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
+        predictions = tokenizer.batch_decode(
+            predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w') as fout:
+            for i, example in enumerate(data_module['predict_dataset']):
+                example['prediction_with_input'] = predictions[i].strip()
+                example['prediction'] = predictions[i].replace(example['input'], '').strip()
+                fout.write(json.dumps(example) + '\n')
+        print(prediction_metrics)
+        trainer.log_metrics("predict", prediction_metrics)
+        trainer.save_metrics("predict", prediction_metrics)
+        all_metrics.update(prediction_metrics)
+
+    if (args.do_train or args.do_eval or args.do_predict):
+        with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
+            fout.write(json.dumps(all_metrics))
+
+if __name__ == "__main__":
+    train()
\ No newline at end of file
diff --git a/src/utils/merge_peft_adapters.py b/src/utils/merge_peft_adapters.py
new file mode 100644
index 0000000..ed3585a
--- /dev/null
+++ b/src/utils/merge_peft_adapters.py
@@ -0,0 +1,48 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+import torch
+import sys
+import os
+import argparse
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+from pilot.configs.config import Config
+from pilot.configs.model_config import LLM_MODEL_CONFIG
+CFG = Config()
+model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
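+# The base model path is resolved from the project's LLM_MODEL_CONFIG mapping; --peft_model_path should point at an 'adapter_model' directory saved by SavePeftModelCallback during QLoRA training.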
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_model_name_or_path", type=str, default=model_path)
+    parser.add_argument("--peft_model_path", type=str, default="train/output/checkpoint-10/adapter_model")
+    parser.add_argument("--output_dir", type=str, default="train/output/merged_models/")
+    parser.add_argument("--device", type=str, default="cpu")
+
+    return parser.parse_args()
+
+def main():
+    args = get_args()
+
+    print(f"Loading base model: {args.base_model_name_or_path}")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        args.base_model_name_or_path,
+        return_dict=True,
+        torch_dtype=torch.float16,
+        trust_remote_code=True
+    )
+
+    print(f"Loading PEFT: {args.peft_model_path}")
+    model = PeftModel.from_pretrained(base_model, args.peft_model_path)
+    model.to(args.device)
+    print("Running merge_and_unload")
+    model = model.merge_and_unload()  # https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#L382
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)
+
+    model.save_pretrained(args.output_dir)
+    tokenizer.save_pretrained(args.output_dir)
+    print(f"Model saved to {args.output_dir}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
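
For reference, a minimal sketch of how the merged checkpoint written by merge_peft_adapters.py might be loaded for inference. This is not part of the patch: the output directory mirrors the --output_dir default above, the instruction/input text is an illustrative assumption, and device_map="auto" assumes accelerate is installed.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

merged_dir = "train/output/merged_models/"  # default --output_dir used by merge_peft_adapters.py

tokenizer = AutoTokenizer.from_pretrained(merged_dir)
model = AutoModelForCausalLM.from_pretrained(
    merged_dir,
    torch_dtype=torch.float16,
    device_map="auto",  # requires `accelerate`; use .to("cuda") or .to("cpu") otherwise
    trust_remote_code=True,
)

# Prompt format mirrors SQL_PROMPT_DICT["prompt_input"] from train_qlora.py (example values are hypothetical).
prompt = (
    "I want you to act as a SQL terminal in front of an example database. "
    "Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"
    "###Instruction:\nShow the names of all singers.\n\n###Input:\nCREATE TABLE singer (name text)\n\n###Response: "
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))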