From 06781bc1e2cc3a1471e9c8242cd75e22fc1b8023 Mon Sep 17 00:00:00 2001 From: AG Date: Thu, 26 Sep 2024 02:17:06 -0500 Subject: [PATCH] Improved resumption --- backup/processing_progress.txt | Bin 20 -> 61 bytes train.py | 204 +++++++++++++++++++-------------- 2 files changed, 115 insertions(+), 89 deletions(-) diff --git a/backup/processing_progress.txt b/backup/processing_progress.txt index 8ba42500dd4a9284100062c6f90b15f749dce20a..433979efbd36dfcbbb9559920a13a85e06b6b2d1 100644 GIT binary patch literal 61 zcmZo*nQFuU0ku;!dN_+S5{pveGgB(2cr)|}LU^eaiMa(isZi0r diff --git a/train.py b/train.py index f6fded0..1743679 100644 --- a/train.py +++ b/train.py @@ -1,22 +1,22 @@ -# Import necessary modules +import pickle import os -import sys +import gc import shutil from tqdm import tqdm import signal import datasets import logging +import sys +import asyncio from lib.process_predictive_words import main as process_predictive_words from lib.process_context_words import main as process_context_words from lib.finish_filing import main as finish_filing from lib.create_dictionary import create_batch from lib.merge_batches import main as merge_batches -import asyncio -import gc from lib.constants import PRUNE_FREQUENCY, TARGET_DICTIONARY_COUNT, TOTAL_WORD_COUNT import argparse # Import argparse for command-line parsing -# Define a flag to indicate when an interrupt has been caught +# Global flag for graceful exit interrupted = False def signal_handler(sig, frame): @@ -24,96 +24,122 @@ def signal_handler(sig, frame): interrupted = True print("Graceful exit request received.") -# Register the signal handler +# Signal handler for graceful exit signal.signal(signal.SIGINT, signal_handler) -DEFAULT_TREE_STORE ={} - -async def save_position(progress_file, current_position, word_count, tree_store): - # Every now and then save our progress. - print(f"Saving the current position of %s" % current_position) - - # Save the current progress (file position) - with open(progress_file, 'w') as f: - f.write(f"{str(current_position)},{str(word_count)}") - - print(f"Passed %s positions. Time to optimize before continuing..." % PRUNE_FREQUENCY) - # TODO This was causing too many problems. - # await create_batch(tree_store, TARGET_DICTIONARY_COUNT) - return DEFAULT_TREE_STORE +DEFAULT_TREE_STORE = {} + +async def load_progress(progress_file): + """Try to load progress using the new method (state_dict) or fall back to the old method.""" + if os.path.exists(progress_file): + try: + # Try to load the progress as a pickle object (new method) + with open(progress_file, 'rb') as f: + state_dict, word_count = pickle.load(f) + print(f"Loaded progress using state_dict with word count {word_count}") + return state_dict, word_count + except (pickle.UnpicklingError, EOFError): + # Fallback to the old method if pickle loading fails (old method) + with open(progress_file, 'r') as f: + start_position_str, word_length_str = f.read().strip().split(',') + start_position = int(start_position_str) + word_count = int(word_length_str) + print(f"Loaded progress using old method from position {start_position} with word count {word_count}") + return start_position, word_count + return None, 0 # No progress file found + +async def save_position(progress_file, dataset, word_count, tree_store): + """Always save progress using the new state_dict method.""" + # Always create the state_dict, even if resuming from an old format + state_dict = dataset.state_dict() + with open(progress_file, 'wb') as f: + pickle.dump((state_dict, word_count), f) + print(f"Saved state_dict and word count {word_count}") + await create_batch(tree_store, TARGET_DICTIONARY_COUNT) + + return DEFAULT_TREE_STORE async def main(retain=False): - tree_store = DEFAULT_TREE_STORE - if not retain and os.path.exists('training'): - shutil.rmtree('training') - print("Previous training data cleared.") - - training_path = 'training' - os.makedirs(training_path, exist_ok=True) - - # Load dataset from Hugging Face datasets - datasets.logging.set_verbosity(datasets.logging.WARNING) - logging.getLogger('fsspec').setLevel(logging.WARNING) - logging.getLogger('urllib3').setLevel(logging.WARNING) - dataset = datasets.load_dataset('oscar-corpus/OSCAR-2201', language='en', split='train', streaming=True, trust_remote_code=True) - - # Initialize start_position and word_length to 0 - start_position = 0 - word_count = 0 - - # Check if the --retain flag is used and if the progress file exists - if retain and os.path.exists('training/processing_progress.txt'): - with open('training/processing_progress.txt', 'r') as f: - # Read the line and split it by the comma to get both values - start_position_str, word_length_str = f.read().strip().split(',') - # Convert the string values to integers - start_position = int(start_position_str) - word_count = int(word_length_str) - print(f"Resuming from position {start_position} with {word_count} total words processed.") - - pbar = tqdm(total=TOTAL_WORD_COUNT, unit='word', desc="Processing dataset", position=1) - pbar.update(word_count) - for i, entry in enumerate(dataset.skip(start_position)): - if i + start_position < start_position: - pbar.display(f"Skipping ahead from {i + start_position} to {start_position}", 1) - continue # Skip to the saved position - text = entry['text'] # Extract text from dataset entry - words = text.split() - - pbar.update(len(words)) - - # Replace reserved characters as before - words = [word.replace("score", "\sscore") for word in words] - words = [word.replace("prediction", "\sprediction") for word in words] - - # Process words three at a time with shifting window - for j in range(len(words) - 2): - word_count += 1 - if interrupted: - print("Script will terminate when done.") - sys.exit(0) - - context_words = process_context_words(words, j) - predictive_words = process_predictive_words(words, j) - - if not predictive_words: - continue - - tree_store = finish_filing(tree_store, context_words, predictive_words) - - if (word_count + 1) % PRUNE_FREQUENCY == 0: - # Save position and prune every PRUNE_FREQUENCY entries - tree_store = await save_position('training/processing_progress.txt', i + start_position + 1, word_count, tree_store) - gc.collect() - - if (word_count + 1) % (PRUNE_FREQUENCY * 100) == 0: - await merge_batches() - - await create_batch(tree_store, TARGET_DICTIONARY_COUNT) + tree_store = DEFAULT_TREE_STORE + training_path = 'training' + + # Clear previous training data if not retaining + if not retain and os.path.exists(training_path): + shutil.rmtree(training_path) + print("Previous training data cleared.") + + os.makedirs(training_path, exist_ok=True) + + # Load dataset from Hugging Face datasets + datasets.logging.set_verbosity(datasets.logging.WARNING) + logging.getLogger('fsspec').setLevel(logging.WARNING) + logging.getLogger('urllib3').setLevel(logging.WARNING) + dataset = datasets.load_dataset('oscar-corpus/OSCAR-2201', language='en', split='train', streaming=True, trust_remote_code=True) + + word_count = 0 + start_position = 0 + state_dict = None + + # Load previous progress (either old or new format) + if retain: + state_dict, word_count = await load_progress('training/processing_progress.txt') + if isinstance(state_dict, dict): + dataset.load_state_dict(state_dict) + else: + # Resume using old method; we still skip to start position but will save with state_dict + start_position = state_dict if isinstance(state_dict, int) else 0 + + # Initialize progress bar + pbar = tqdm(total=TOTAL_WORD_COUNT, unit='word', desc="Processing dataset", position=1) + pbar.update(word_count) + + # Processing dataset + for i, entry in enumerate(dataset.skip(start_position)): + if interrupted: + print("Script will terminate when done.") + sys.exit(0) + + # Extract text and process words + text = entry['text'] + words = text.split() + + # Update the progress bar with the number of words processed + pbar.update(len(words)) + + # Replace reserved characters + words = [word.replace("score", "\sscore") for word in words] + words = [word.replace("prediction", "\sprediction") for word in words] + + # Process words three at a time with a shifting window + for j in range(len(words) - 2): + word_count += 1 + + # Get context and predictive words + context_words = process_context_words(words, j) + predictive_words = process_predictive_words(words, j) + + if not predictive_words: + continue + + # File the words + tree_store = finish_filing(tree_store, context_words, predictive_words) + + # Save position and prune periodically + if (word_count + 1) % PRUNE_FREQUENCY == 0: + tree_store = await save_position('training/processing_progress.txt', dataset, word_count, tree_store) + gc.collect() + + # Silencing for now. Creating too many problems. + # Merge batches periodically + # if (word_count + 1) % (PRUNE_FREQUENCY * 100) == 0: + # await merge_batches() + + # Final batch creation after processing is complete + await create_batch(tree_store, TARGET_DICTIONARY_COUNT) if __name__ == "__main__": parser = argparse.ArgumentParser(description='Training script with position retain functionality.') parser.add_argument('--retain', action='store_true', help='Retain and resume from last saved position.') args = parser.parse_args() - - asyncio.run(main(retain=args.retain)) + + asyncio.run(main(retain=args.retain)) \ No newline at end of file