From 9c688e8086c353d29f3fdf7778a460a19da8c945 Mon Sep 17 00:00:00 2001 From: agombert Date: Tue, 5 Nov 2024 16:33:40 +0000 Subject: [PATCH 01/15] :bug: fix the dilation conv layer --- doclayout_yolo/nn/modules/g2l_crm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/doclayout_yolo/nn/modules/g2l_crm.py b/doclayout_yolo/nn/modules/g2l_crm.py index 62466b1..c558e91 100644 --- a/doclayout_yolo/nn/modules/g2l_crm.py +++ b/doclayout_yolo/nn/modules/g2l_crm.py @@ -32,11 +32,14 @@ def __init__(self, c, dilation, k, fuse="sum", shortcut=True): self.dcv = Conv(c, c, k=self.k, s=1) def dilated_conv(self, x, dilation): - act = self.dcv.act - bn = self.dcv.bn weight = self.dcv.conv.weight padding = dilation * (self.k//2) - return act(bn(F.conv2d(x, weight, stride=1, padding=padding, dilation=dilation))) + x = F.conv2d(x, weight, stride=1, padding=padding, dilation=dilation) + if hasattr(self.dcv, 'bn'): + x = self.dcv.bn(x) + if hasattr(self.dcv, 'act'): + x = self.dcv.act(x) + return x def forward(self, x): """'forward()' applies the YOLO FPN to input data.""" From 5a380a96137d8fb3238b5d1b2b4cf4b23543f17f Mon Sep 17 00:00:00 2001 From: agombert Date: Tue, 5 Nov 2024 16:33:56 +0000 Subject: [PATCH 02/15] :arrow_up: add the huggingface_hub dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c2ef158..be081ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ dependencies = [ "pandas>=1.1.4", "seaborn>=0.11.0", # plotting "albumentations>=1.4.11", + "huggingface_hub>=0.26.2", ] # Optional dependencies ------------------------------------------------------------------------------------------------ From cf7288f4b60a795892ebd89bcde8ebe4a71d10c1 Mon Sep 17 00:00:00 2001 From: agombert Date: Tue, 5 Nov 2024 16:34:10 +0000 Subject: [PATCH 03/15] :bug: comment the yolov8 model line --- doclayout_yolo/utils/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doclayout_yolo/utils/checks.py b/doclayout_yolo/utils/checks.py index e378281..f8cbac7 100644 --- a/doclayout_yolo/utils/checks.py +++ b/doclayout_yolo/utils/checks.py @@ -650,7 +650,7 @@ def amp_allclose(m, im): try: from doclayout_yolo import YOLO - assert amp_allclose(YOLO("yolov8n.pt"), im) + #assert amp_allclose(YOLO("yolov8n.pt"), im) LOGGER.info(f"{prefix}checks passed ✅") except ConnectionError: LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}") From 27e69760173931fd68a84d5867dd3fb76ae203c3 Mon Sep 17 00:00:00 2001 From: agombert Date: Fri, 20 Dec 2024 14:43:29 +0100 Subject: [PATCH 04/15] :arrow_up: limit albumentations to 1.4.21 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index be081ae..0767c01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dependencies = [ "thop>=0.1.1", # FLOPs computation "pandas>=1.1.4", "seaborn>=0.11.0", # plotting - "albumentations>=1.4.11", + "albumentations<=1.4.21", "huggingface_hub>=0.26.2", ] From da7fa594b56ddc4a3bfe967873a4f49b5b52d5fd Mon Sep 17 00:00:00 2001 From: agombert Date: Mon, 23 Dec 2024 10:34:17 +0100 Subject: [PATCH 05/15] :arrow_up: update to have albumentations 1.4.18 opencv-python 4.10.0.84 opencv-python-headless 4.10.0.84 --- pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0767c01..84c715b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,8 @@ classifiers = [ # Required dependencies ------------------------------------------------------------------------------------------------ dependencies = [ "matplotlib>=3.3.0", - "opencv-python>=4.6.0", + "opencv-python==4.10.0.84", + "opencv-python-headless==4.10.0.84", "pillow>=7.1.2", "pyyaml>=5.3.1", "requests>=2.23.0", @@ -78,7 +79,7 @@ dependencies = [ "thop>=0.1.1", # FLOPs computation "pandas>=1.1.4", "seaborn>=0.11.0", # plotting - "albumentations<=1.4.21", + "albumentations==1.4.18", "huggingface_hub>=0.26.2", ] @@ -121,7 +122,7 @@ logging = [ extra = [ "hub-sdk>=0.0.5", # Ultralytics HUB "ipython", # interactive notebook - "albumentations>=1.0.3", # training augmentations + "albumentations==1.4.18", # training augmentations "pycocotools>=2.0.7", # COCO mAP ] From 60e72494391f71f57b0e8543c510639d70a6be3c Mon Sep 17 00:00:00 2001 From: agombert Date: Tue, 24 Dec 2024 12:04:25 +0100 Subject: [PATCH 06/15] :arrow_up: update dependencies as in root repo --- pyproject.toml | 53 +++++++++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84c715b..e1c1b07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,15 +18,14 @@ # Documentation: # For comprehensive documentation and usage instructions, visit: https://docs.ultralytics.com -[build-system] +[build-system][build-system] requires = ["setuptools>=43.0.0", "wheel"] build-backend = "setuptools.build_meta" -# Project settings ----------------------------------------------------------------------------------------------------- [project] name = "doclayout_yolo" dynamic = ["version"] -description = "DocLayout-YOLO: an effecient and robust document layout analysis method." +description = "DocLayout-YOLO: an efficient and robust document layout analysis method." readme = "README.md" requires-python = ">=3.8" license = { "text" = "AGPL-3.0" } @@ -62,28 +61,26 @@ classifiers = [ "Operating System :: Microsoft :: Windows", ] -# Required dependencies ------------------------------------------------------------------------------------------------ dependencies = [ - "matplotlib>=3.3.0", - "opencv-python==4.10.0.84", + "matplotlib==3.9.0", + "opencv-python==4.9.0.80", "opencv-python-headless==4.10.0.84", - "pillow>=7.1.2", - "pyyaml>=5.3.1", - "requests>=2.23.0", - "scipy>=1.4.1", - "torch>=2.0.1", - "torchvision>=0.15.2", - "tqdm>=4.64.0", # progress bars - "psutil", # system utilization - "py-cpuinfo", # display CPU info - "thop>=0.1.1", # FLOPs computation - "pandas>=1.1.4", - "seaborn>=0.11.0", # plotting - "albumentations==1.4.18", - "huggingface_hub>=0.26.2", + "pillow==10.3.0", + "pyyaml==6.0.1", + "requests==2.28.2", + "scipy==1.13.1", + "torch==2.5.0", + "torchvision==0.20.0", + "tqdm==4.65.2", # progress bars + "psutil==5.9.8", # system utilization + "py-cpuinfo==9.0.0", # display CPU info + "thop==0.1.1-2209072238", # FLOPs computation + "pandas==2.2.2", + "seaborn==0.13.2", # plotting + "albumentations==1.4.11", + "huggingface_hub==0.23.2", ] -# Optional dependencies ------------------------------------------------------------------------------------------------ [project.optional-dependencies] dev = [ "ipython", @@ -99,7 +96,7 @@ dev = [ "mkdocs-ultralytics-plugin>=0.0.44", # for meta descriptions and images, dates and authors ] export = [ - "onnx>=1.12.0", # ONNX export + "onnx==1.14.0", # ONNX export "coremltools>=7.0; platform_system != 'Windows' and python_version <= '3.11'", # CoreML supported on macOS and Linux "openvino>=2024.0.0", # OpenVINO export "tensorflow<=2.13.1; python_version <= '3.11'", # TF bug https://github.com/ultralytics/ultralytics/issues/5161 @@ -110,19 +107,15 @@ explorer = [ "duckdb<=0.9.2", # SQL queries, duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181 "streamlit", # visualizing with GUI ] -# tensorflow>=2.4.1,<=2.13.1 # TF exports (-cpu, -aarch64, -macos) -# tflite-support # for TFLite model metadata -# nvidia-pyindex # TensorRT export -# nvidia-tensorrt # TensorRT export logging = [ "comet", # https://docs.ultralytics.com/integrations/comet/ - "tensorboard>=2.13.0", + "tensorboard==2.17.0", "dvclive>=2.12.0", ] extra = [ "hub-sdk>=0.0.5", # Ultralytics HUB "ipython", # interactive notebook - "albumentations==1.4.18", # training augmentations + "albumentations==1.4.11", # training augmentations "pycocotools>=2.0.7", # COCO mAP ] @@ -130,8 +123,7 @@ extra = [ yolo = "doclayout_yolo.cfg:entrypoint" doclayout_yolo = "doclayout_yolo.cfg:entrypoint" -# Tools settings ------------------------------------------------------------------------------------------------------- -[tool.setuptools] # configuration specific to the `setuptools` build backend. +[tool.setuptools] packages = { find = { where = ["."], include = ["doclayout_yolo", "doclayout_yolo.*"] } } package-data = { "doclayout_yolo" = ["**/*.yaml"], "doclayout_yolo.assets" = ["*.jpg"] } @@ -145,7 +137,6 @@ markers = [ ] norecursedirs = [".git", "dist", "build"] - [tool.coverage.run] source = ["doclayout_yolo/"] data_file = "tests/.coverage" From 476df5cfe1a69934ac68f44fee11b07194e7bbed Mon Sep 17 00:00:00 2001 From: agombert Date: Tue, 24 Dec 2024 16:24:27 +0100 Subject: [PATCH 07/15] :arrow_up: add hugginface_hub and datasets from hf --- pyproject.toml | 55 +++++++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e1c1b07..4c06829 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,14 +18,15 @@ # Documentation: # For comprehensive documentation and usage instructions, visit: https://docs.ultralytics.com -[build-system][build-system] +[build-system] requires = ["setuptools>=43.0.0", "wheel"] build-backend = "setuptools.build_meta" +# Project settings ----------------------------------------------------------------------------------------------------- [project] name = "doclayout_yolo" dynamic = ["version"] -description = "DocLayout-YOLO: an efficient and robust document layout analysis method." +description = "DocLayout-YOLO: an effecient and robust document layout analysis method." readme = "README.md" requires-python = ">=3.8" license = { "text" = "AGPL-3.0" } @@ -61,26 +62,28 @@ classifiers = [ "Operating System :: Microsoft :: Windows", ] +# Required dependencies ------------------------------------------------------------------------------------------------ dependencies = [ - "matplotlib==3.9.0", - "opencv-python==4.9.0.80", - "opencv-python-headless==4.10.0.84", - "pillow==10.3.0", - "pyyaml==6.0.1", - "requests==2.28.2", - "scipy==1.13.1", - "torch==2.5.0", - "torchvision==0.20.0", - "tqdm==4.65.2", # progress bars - "psutil==5.9.8", # system utilization - "py-cpuinfo==9.0.0", # display CPU info - "thop==0.1.1-2209072238", # FLOPs computation - "pandas==2.2.2", - "seaborn==0.13.2", # plotting - "albumentations==1.4.11", - "huggingface_hub==0.23.2", + "matplotlib>=3.3.0", + "opencv-python>=4.6.0", + "pillow>=7.1.2", + "pyyaml>=5.3.1", + "requests>=2.23.0", + "scipy>=1.4.1", + "torch>=2.0.1", + "torchvision>=0.15.2", + "tqdm>=4.64.0", # progress bars + "psutil", # system utilization + "py-cpuinfo", # display CPU info + "thop>=0.1.1", # FLOPs computation + "pandas>=1.1.4", + "seaborn>=0.11.0", # plotting + "albumentations>=1.4.11", + "huggingface_hub>=0.23.2", + "datasets>=2.14.4", ] +# Optional dependencies ------------------------------------------------------------------------------------------------ [project.optional-dependencies] dev = [ "ipython", @@ -96,7 +99,7 @@ dev = [ "mkdocs-ultralytics-plugin>=0.0.44", # for meta descriptions and images, dates and authors ] export = [ - "onnx==1.14.0", # ONNX export + "onnx>=1.12.0", # ONNX export "coremltools>=7.0; platform_system != 'Windows' and python_version <= '3.11'", # CoreML supported on macOS and Linux "openvino>=2024.0.0", # OpenVINO export "tensorflow<=2.13.1; python_version <= '3.11'", # TF bug https://github.com/ultralytics/ultralytics/issues/5161 @@ -107,15 +110,19 @@ explorer = [ "duckdb<=0.9.2", # SQL queries, duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181 "streamlit", # visualizing with GUI ] +# tensorflow>=2.4.1,<=2.13.1 # TF exports (-cpu, -aarch64, -macos) +# tflite-support # for TFLite model metadata +# nvidia-pyindex # TensorRT export +# nvidia-tensorrt # TensorRT export logging = [ "comet", # https://docs.ultralytics.com/integrations/comet/ - "tensorboard==2.17.0", + "tensorboard>=2.13.0", "dvclive>=2.12.0", ] extra = [ "hub-sdk>=0.0.5", # Ultralytics HUB "ipython", # interactive notebook - "albumentations==1.4.11", # training augmentations + "albumentations>=1.0.3", # training augmentations "pycocotools>=2.0.7", # COCO mAP ] @@ -123,7 +130,8 @@ extra = [ yolo = "doclayout_yolo.cfg:entrypoint" doclayout_yolo = "doclayout_yolo.cfg:entrypoint" -[tool.setuptools] +# Tools settings ------------------------------------------------------------------------------------------------------- +[tool.setuptools] # configuration specific to the `setuptools` build backend. packages = { find = { where = ["."], include = ["doclayout_yolo", "doclayout_yolo.*"] } } package-data = { "doclayout_yolo" = ["**/*.yaml"], "doclayout_yolo.assets" = ["*.jpg"] } @@ -137,6 +145,7 @@ markers = [ ] norecursedirs = [".git", "dist", "build"] + [tool.coverage.run] source = ["doclayout_yolo/"] data_file = "tests/.coverage" From 88eb6582cb68897c9ad2b7d3fb248f17a0a38412 Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:40:39 +0000 Subject: [PATCH 08/15] :wrench: settings for hyper parameters and paths --- settings.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 settings.py diff --git a/settings.py b/settings.py new file mode 100644 index 0000000..8e6098d --- /dev/null +++ b/settings.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from pathlib import Path + +@dataclass +class LayoutParserTrainingSettings: + from_dataset_repo: str = "agomberto/historical-layout" + local_data_dir: str = "/home/ubuntu/datasets/data" + local_model_dir: str = "/home/ubuntu/models" + from_model_repo: str = "juliozhao/DocLayout-YOLO-DocStructBench" + from_model_name: str = "doclayout_yolo_docstructbench_imgsz1024.pt" + pushed_model_name: str = "my_ft_model.pt" + pushed_model_repo: str = "agomberto/historical-layout-ft-test" + local_ft_model_dir: str = "/home/ubuntu/yolo_ft" + + # hyperparameters + batch_size: int = 8 + epochs: int = 5 + image_size: int = 1024 + lr0: float = 0.001 + optimizer: str = "Adam" + base_model: str = "m-doclayout" + patience: int = 5 + + # Optional training parameters (with defaults) + warmup_epochs: float = 3.0 + momentum: float = 0.9 + mosaic: float = 1.0 + workers: int = 4 + device: str = "0" + val_period: int = 1 + save_period: int = 10 + + @property + def local_ft_model_name(self) -> str: + """Get the path to the fine-tuned model""" + name = (f"yolov10{self.base_model}_{self.local_data_dir}_" + f"epoch{self.epochs}_imgsz{self.image_size}_" + f"bs{self.batch_size}_pretrain_docstruct") + return str(Path(self.local_ft_model_dir) / name / "weights/best.pt") From 57996ef8735fd532af557d0e11e48812ec002238 Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:41:07 +0000 Subject: [PATCH 09/15] :rocket: get data, train and push model to HF --- prepare_data_and_train.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100755 prepare_data_and_train.sh diff --git a/prepare_data_and_train.sh b/prepare_data_and_train.sh new file mode 100755 index 0000000..807e3c3 --- /dev/null +++ b/prepare_data_and_train.sh @@ -0,0 +1,10 @@ +# Connect to HuggingFace hub if not already connected +if [ ! -f ~/.huggingface/token ]; then + huggingface-cli login +fi + +# Prepare data for training +python prepare_data_for_training.py + +# Train the model +python training_wrapper.py \ No newline at end of file From 28cce82fced11a28ad51d3c0302f1772013ade0f Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:41:30 +0000 Subject: [PATCH 10/15] :tada: script to get data from HF --- prepare_data_for_training.py | 80 ++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 prepare_data_for_training.py diff --git a/prepare_data_for_training.py b/prepare_data_for_training.py new file mode 100644 index 0000000..b1fca6a --- /dev/null +++ b/prepare_data_for_training.py @@ -0,0 +1,80 @@ +import os +import uuid +from datasets import load_dataset +from huggingface_hub import hf_hub_download +from settings import LayoutParserTrainingSettings + + +def prepare_data(settings: LayoutParserTrainingSettings): + """Prepare data for YOLO training""" + + # Load dataset + dataset = load_dataset(settings.from_dataset_repo) + + # Convert to pandas for splitting + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + # Create directory structure + dirs = [ + os.path.join(settings.local_data_dir, "images", "train"), + os.path.join(settings.local_data_dir, "images", "val"), + os.path.join(settings.local_data_dir, "labels", "train"), + os.path.join(settings.local_data_dir, "labels", "val"), + ] + + # Create directories + for dir_path in dirs: + os.makedirs(dir_path, exist_ok=True) + + # Process each split + for split_name, split_data in zip(["train", "val"], [train_dataset, test_dataset]): + for item in split_data: + # Generate unique filename + filename = str(uuid.uuid4()) + + # Save image + image_path = os.path.join( + settings.local_data_dir, "images", split_name, f"{filename}.jpg" + ) + item["image"].save(image_path) + + # Save labels + label_path = os.path.join( + settings.local_data_dir, "labels", split_name, f"{filename}.txt" + ) + with open(label_path, "w") as f: + for category, bbox in zip( + item["objects"]["categories"], item["objects"]["bbox"] + ): + line = f"{category} {' '.join(map(str, bbox))}\n" + f.write(line) + + # Download YAML config + hf_hub_download( + repo_id=settings.from_dataset_repo, + filename="config.yaml", + repo_type="dataset", + local_dir=settings.local_data_dir, + ) + + # Download pretrained model + hf_hub_download( + repo_id=settings.from_model_repo, + filename=settings.from_model_name, + repo_type="model", + local_dir=settings.local_model_dir, + ) + + print(f"Data prepared in {settings.local_data_dir}") + print( + f"Train images: {len(os.listdir(os.path.join(settings.local_data_dir, 'images', 'train')))}" + ) + print( + f"Val images: {len(os.listdir(os.path.join(settings.local_data_dir, 'images', 'val')))}" + ) + + +if __name__ == "__main__": + settings = LayoutParserTrainingSettings() + prepare_data(settings) From deee3552e8938b940e061742acf71e707ae56a31 Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:41:51 +0000 Subject: [PATCH 11/15] :rocket: wrapper to train and push model --- training_wrapper.py | 120 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 training_wrapper.py diff --git a/training_wrapper.py b/training_wrapper.py new file mode 100644 index 0000000..763f962 --- /dev/null +++ b/training_wrapper.py @@ -0,0 +1,120 @@ +from pathlib import Path +import argparse +from settings import LayoutParserTrainingSettings +from doclayout_yolo import YOLOv10 +from datetime import datetime +import logging +from huggingface_hub import HfApi + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +def train_model(settings: LayoutParserTrainingSettings): + """ + Train YOLOv10 model using settings from LayoutParserTrainingSettings + """ + # Load pretrained model + model_path = Path(settings.local_model_dir) / settings.from_model_name + model = YOLOv10(str(model_path)) + pretrain_name = "docstruct" if "docstruct" in settings.from_model_name else "unknown" + + # Construct run name + name = (f"yolov10{settings.base_model}_{settings.local_data_dir}_" + f"epoch{settings.epochs}_imgsz{settings.image_size}_" + f"bs{settings.batch_size}_pretrain_{pretrain_name}") + + # Train model + results = model.train( + data=f'{settings.local_data_dir}/config.yaml', + epochs=settings.epochs, + warmup_epochs=settings.warmup_epochs, + lr0=settings.lr0, + optimizer=settings.optimizer, + momentum=settings.momentum, + imgsz=settings.image_size, + mosaic=settings.mosaic, + batch=settings.batch_size, + device=settings.device, + workers=settings.workers, + plots=True, + exist_ok=False, + val=True, + val_period=settings.val_period, + resume=False, + save_period=settings.save_period, + patience=settings.patience, + project=settings.local_ft_model_dir, + name=name, + ) + + return results + +def push_to_hub( + settings: LayoutParserTrainingSettings, + commit_message=None, +): + """Push trained model to Hugging Face Hub""" + + # Initialize Hugging Face API + api = HfApi() + + # Create repo if it doesn't exist + try: + api.create_repo(repo_id=settings.pushed_model_repo, exist_ok=True, repo_type="model") + except Exception as e: + print(f"Repository creation failed: {e}") + return + + # Default commit message + if commit_message is None: + commit_message = ( + f"Upload model - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + ) + + # Upload the model file + try: + api.upload_file( + path_or_fileobj=settings.local_ft_model_name, + path_in_repo=settings.pushed_model_name, + repo_id=settings.pushed_model_repo, + commit_message=commit_message, + ) + print(f"Model successfully uploaded to {settings.pushed_model_repo}") + except Exception as e: + print(f"Upload failed: {e}") + + +def main(): + # Parse command line arguments + parser = argparse.ArgumentParser(description='Train and optionally push YOLOv10 model') + parser.add_argument('--push', action='store_true', help='Push model to HuggingFace Hub after training') + parser.add_argument('--commit-message', type=str, + help='Custom commit message for model push (default: timestamp)') + args = parser.parse_args() + + # Initialize settings + settings = LayoutParserTrainingSettings() + + try: + # Train model + logger.info(f"Starting training with batch size {settings.batch_size} and {settings.epochs} epochs") + results = train_model(settings) + logger.info(f"Training completed. Model saved at: {settings.local_ft_model_name}") + + # Push model if requested + if args.push: + logger.info("Pushing model to HuggingFace Hub...") + commit_message = args.commit_message or f"Model trained on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + push_to_hub( + settings=settings, + commit_message=commit_message + ) + logger.info(f"Model successfully pushed to {settings.pushed_model_repo}") + + except Exception as e: + logger.error(f"Error occurred: {str(e)}") + raise + +if __name__ == "__main__": + main() \ No newline at end of file From 36730ee32cda95ac1b6e498bffcfde9e24a38cf0 Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:48:17 +0000 Subject: [PATCH 12/15] :beers: add push option --- prepare_data_and_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prepare_data_and_train.sh b/prepare_data_and_train.sh index 807e3c3..060aa41 100755 --- a/prepare_data_and_train.sh +++ b/prepare_data_and_train.sh @@ -7,4 +7,4 @@ fi python prepare_data_for_training.py # Train the model -python training_wrapper.py \ No newline at end of file +python training_wrapper.py --push \ No newline at end of file From e93f24dc63526b311e144b4d78367ddc513d8988 Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:48:27 +0000 Subject: [PATCH 13/15] :wrench: add plot setting --- settings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/settings.py b/settings.py index 8e6098d..b382426 100644 --- a/settings.py +++ b/settings.py @@ -29,6 +29,7 @@ class LayoutParserTrainingSettings: device: str = "0" val_period: int = 1 save_period: int = 10 + plots: bool = False @property def local_ft_model_name(self) -> str: From da6cc3efc26faeee7a91055e008f27c5c6666ae8 Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:48:37 +0000 Subject: [PATCH 14/15] :bug: fix the way to call function --- training_wrapper.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/training_wrapper.py b/training_wrapper.py index 763f962..3a36176 100644 --- a/training_wrapper.py +++ b/training_wrapper.py @@ -37,7 +37,7 @@ def train_model(settings: LayoutParserTrainingSettings): batch=settings.batch_size, device=settings.device, workers=settings.workers, - plots=True, + plots=settings.plots, exist_ok=False, val=True, val_period=settings.val_period, @@ -85,16 +85,7 @@ def push_to_hub( print(f"Upload failed: {e}") -def main(): - # Parse command line arguments - parser = argparse.ArgumentParser(description='Train and optionally push YOLOv10 model') - parser.add_argument('--push', action='store_true', help='Push model to HuggingFace Hub after training') - parser.add_argument('--commit-message', type=str, - help='Custom commit message for model push (default: timestamp)') - args = parser.parse_args() - - # Initialize settings - settings = LayoutParserTrainingSettings() +def main(settings: LayoutParserTrainingSettings, push: bool = False, commit_message: str = None): try: # Train model @@ -108,7 +99,8 @@ def main(): commit_message = args.commit_message or f"Model trained on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" push_to_hub( settings=settings, - commit_message=commit_message + commit_message=commit_message, + private=True, ) logger.info(f"Model successfully pushed to {settings.pushed_model_repo}") @@ -117,4 +109,12 @@ def main(): raise if __name__ == "__main__": - main() \ No newline at end of file + + parser = argparse.ArgumentParser(description='Train and optionally push YOLOv10 model') + parser.add_argument('--push', action='store_true', help='Push model to HuggingFace Hub after training') + parser.add_argument('--commit-message', type=str, + help='Custom commit message for model push (default: timestamp)') + args = parser.parse_args() + + settings = LayoutParserTrainingSettings() + main(settings, args.push, args.commit_message) From 322276765af7cc1a3ae735886b7a2a4630ccfbfb Mon Sep 17 00:00:00 2001 From: Arnault Gombert Date: Tue, 24 Dec 2024 16:59:28 +0000 Subject: [PATCH 15/15] :bug: fix private setting for repo --- training_wrapper.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/training_wrapper.py b/training_wrapper.py index 3a36176..808b852 100644 --- a/training_wrapper.py +++ b/training_wrapper.py @@ -61,7 +61,10 @@ def push_to_hub( # Create repo if it doesn't exist try: - api.create_repo(repo_id=settings.pushed_model_repo, exist_ok=True, repo_type="model") + api.create_repo(repo_id=settings.pushed_model_repo, + exist_ok=True, + repo_type="model", + private=True) except Exception as e: print(f"Repository creation failed: {e}") return @@ -99,8 +102,7 @@ def main(settings: LayoutParserTrainingSettings, push: bool = False, commit_mess commit_message = args.commit_message or f"Model trained on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" push_to_hub( settings=settings, - commit_message=commit_message, - private=True, + commit_message=commit_message ) logger.info(f"Model successfully pushed to {settings.pushed_model_repo}")