diff --git a/.gitignore b/.gitignore
index 689efbbb4..ac0b3c15f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,10 @@ PytorchWildlife.egg-info/
 *setup/*
 *dist/*
 *VSCodeCounter*
-*build/*
\ No newline at end of file
+*build/*
+*bee*
+*Bee*
+*debug*
+*annotations.csv
+*cropped_resized*
+*log*
\ No newline at end of file
diff --git a/PW_FT_classification/README.md b/PW_FT_classification/README.md
index fc9d79267..9090c6939 100644
--- a/PW_FT_classification/README.md
+++ b/PW_FT_classification/README.md
@@ -88,8 +88,9 @@ Before training your model, you need to configure the training and data paramete
   - `num_workers`: Number of subprocesses to use for data loading.
 
 - **Model Parameters:**
+  - `num_classes`: The number of classes in the dataset.
   - `model_name`: The name of the model architecture to use. The current version only supports PlainResNetClassifier.
-  - `num_layers`: Number of layers in the model. Currently only supports 18 and 50.
+  - `num_layers`: Number of layers in the ResNet model. Currently only supports 18 and 50.
   - `weights_init`: Initial weights setting for the model. Currently only supports "ImageNet".
 
 - **Optimization Parameters:**
diff --git a/PW_FT_classification/configs/Raw/Crop_res50_plain_082723.yaml b/PW_FT_classification/configs/Raw/Crop_res18_plain_071824.yaml
similarity index 85%
rename from PW_FT_classification/configs/Raw/Crop_res50_plain_082723.yaml
rename to PW_FT_classification/configs/Raw/Crop_res18_plain_071824.yaml
index 8722cd88a..e0c4bb19b 100644
--- a/PW_FT_classification/configs/Raw/Crop_res50_plain_082723.yaml
+++ b/PW_FT_classification/configs/Raw/Crop_res18_plain_071824.yaml
@@ -1,8 +1,8 @@
 # training
-conf_id: Crop_Res50_plain_082723
+conf_id: Crop_Res18_plain_071824
 algorithm: Plain
 log_dir: Crop
-num_epochs: 60
+num_epochs: 30
 log_interval: 10
 parallel: 0
 
@@ -18,11 +18,12 @@ val_size: 0.2
 split_data: True
 split_type: location # options are: random, location, sequence
 # data loading
-batch_size: 256
-num_workers: 0 #40
+batch_size: 32
+num_workers: 4 #40
 
 # model
+num_classes: 2
 model_name: PlainResNetClassifier
-num_layers: 50
+num_layers: 18
 weights_init: ImageNet
 # optim
@@ -35,6 +36,6 @@
 lr_classifier: 0.01
 momentum_classifier: 0.9
 weight_decay_classifier: 0.0005
 ## lr_scheduler
-step_size: 20
+step_size: 10
 gamma: 0.1
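For context, the new `num_classes` key flows from this YAML into the classifier head (see the `num_cls=self.hparams.num_classes` change in plain.py below). A minimal sketch of how such a config can be loaded, assuming the code reads the YAML into an attribute-style object with munch (munch appears in the trimmed requirements.txt, but this exact loading code is an illustration, not the repo's):

```python
# Illustrative only -- not part of this patch.
import yaml
from munch import munchify

with open("configs/Raw/Crop_res18_plain_071824.yaml") as f:
    conf = munchify(yaml.safe_load(f))  # attribute-style access, e.g. conf.num_classes

# The new key drives the classifier head size:
print(conf.num_classes, conf.num_layers)  # e.g. 2 18
```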
os.environ["OMP_NUM_THREADS"] = str(np_threads) os.environ["OPENBLAS_NUM_THREADS"] = str(np_threads) os.environ["MKL_NUM_THREADS"] = str(np_threads) @@ -73,7 +74,6 @@ def main( # Set a global seed for reproducibility pl.seed_everything(seed) - # If the annotation directory does not have a data split, split the data first if conf.split_data: # Replace annotation dir from config with the directory containing the split files @@ -92,19 +92,18 @@ def main( train_annotations = os.path.join(conf.dataset_root, 'train_annotations.csv') test_annotations = os.path.join(conf.dataset_root, 'test_annotations.csv') val_annotations = os.path.join(conf.dataset_root, 'val_annotations.csv') - # Split training data - + # Crop training data batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), train_annotations) - # Split validation and test data + # Crop validation data batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), val_annotations) + # Crop test data (most likely we don't need this) batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), test_annotations) - # Dataset and algorithm loading based on the configuration dataset = datasets.__dict__[conf.dataset_name](conf=conf) learner = algorithms.__dict__[conf.algorithm](conf=conf, - train_class_counts=dataset.train_class_counts, - id_to_labels=dataset.id_to_labels) + train_class_counts=dataset.train_class_counts, + id_to_labels=dataset.id_to_labels) # Logger setup based on the specified logger type log_folder = 'log_dev' if dev else 'log' @@ -155,7 +154,7 @@ def main( devices=gpus, logger=None if evaluate is not None else logger, callbacks=[lr_monitor, checkpoint_callback], - strategy='ddp', + strategy='auto', num_sanity_val_steps=0, profiler=None ) @@ -171,8 +170,6 @@ def main( trainer.fit(learner, datamodule=dataset) # %% if __name__ == '__main__': - main() - - + app() # %% diff --git a/PW_FT_classification/requirements.txt b/PW_FT_classification/requirements.txt index 7d8fa85d9..c20d78a7c 100644 --- a/PW_FT_classification/requirements.txt +++ b/PW_FT_classification/requirements.txt @@ -1,125 +1,5 @@ -absl-py==2.1.0 -aiofiles==23.2.1 -aiohttp==3.9.3 -aiosignal==1.3.1 -altair==5.2.0 -annotated-types==0.6.0 -anyio==4.2.0 -asttokens==2.4.1 -async-timeout==4.0.3 -attrs==23.2.0 -backcall==0.2.0 -cachetools==5.3.2 -certifi==2023.11.17 -charset-normalizer==3.3.2 -click==8.1.7 -colorama==0.4.6 -contourpy==1.1.1 -cycler==0.12.1 -decorator==5.1.1 -exceptiongroup==1.2.0 -executing==2.0.1 -fastapi==0.109.0 -ffmpy==0.3.1 -filelock==3.13.1 -fire==0.5.0 -fonttools==4.47.2 -frozenlist==1.4.1 -fsspec==2023.12.2 -google-auth==2.27.0 -google-auth-oauthlib==1.0.0 -gradio -grpcio==1.60.0 -h11==0.14.0 -httpcore==1.0.2 -httpx==0.26.0 -huggingface-hub==0.20.3 -idna==3.6 -importlib-metadata==7.0.1 -importlib-resources==6.1.1 -ipython==8.12.3 -jedi==0.19.1 -jinja2==3.1.3 -joblib==1.3.2 -jsonschema==4.21.1 -jsonschema-specifications==2023.12.1 -kiwisolver==1.4.5 -lightning-utilities==0.10.1 -markdown==3.5.2 -markdown-it-py==3.0.0 -markupsafe==2.1.4 -matplotlib==3.7.4 -matplotlib-inline==0.1.6 -mdurl==0.1.2 -multidict==6.0.4 -munch==2.5.0 -numpy==1.24.4 -oauthlib==3.2.2 -opencv-python==4.9.0.80 -opencv-python-headless==4.9.0.80 -orjson==3.9.12 -packaging==23.2 -pandas==2.0.3 -parso==0.8.3 -pexpect==4.9.0 -pickleshare==0.7.5 -pillow==10.1.0 -pkgutil-resolve-name==1.3.10 
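The patch switches the entry point from a plain `main()` call to a Typer app, which turns the function's keyword arguments into CLI options. A stripped-down sketch of the same wiring (parameter set reduced for illustration; not the actual main.py):

```python
# Illustrative only -- minimal version of the Typer pattern this patch enables.
import typer

app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_show_locals=False)

@app.command()
def main(config: str = "./configs/Raw/Crop_res18_plain_071824.yaml", dev: bool = False):
    # Defaulted parameters become CLI options, e.g. --config and --dev
    typer.echo(f"config={config} dev={dev}")

if __name__ == "__main__":
    app()  # invoked as: python main.py --config my.yaml --dev
```

Calling `app()` instead of `main()` is what makes Typer parse `sys.argv`; the two `pretty_exceptions_*` flags shorten tracebacks and hide local variables in error output.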
diff --git a/PW_FT_classification/requirements.txt b/PW_FT_classification/requirements.txt
index 7d8fa85d9..c20d78a7c 100644
--- a/PW_FT_classification/requirements.txt
+++ b/PW_FT_classification/requirements.txt
@@ -1,125 +1,5 @@
-absl-py==2.1.0
-aiofiles==23.2.1
-aiohttp==3.9.3
-aiosignal==1.3.1
-altair==5.2.0
-annotated-types==0.6.0
-anyio==4.2.0
-asttokens==2.4.1
-async-timeout==4.0.3
-attrs==23.2.0
-backcall==0.2.0
-cachetools==5.3.2
-certifi==2023.11.17
-charset-normalizer==3.3.2
-click==8.1.7
-colorama==0.4.6
-contourpy==1.1.1
-cycler==0.12.1
-decorator==5.1.1
-exceptiongroup==1.2.0
-executing==2.0.1
-fastapi==0.109.0
-ffmpy==0.3.1
-filelock==3.13.1
-fire==0.5.0
-fonttools==4.47.2
-frozenlist==1.4.1
-fsspec==2023.12.2
-google-auth==2.27.0
-google-auth-oauthlib==1.0.0
-gradio
-grpcio==1.60.0
-h11==0.14.0
-httpcore==1.0.2
-httpx==0.26.0
-huggingface-hub==0.20.3
-idna==3.6
-importlib-metadata==7.0.1
-importlib-resources==6.1.1
-ipython==8.12.3
-jedi==0.19.1
-jinja2==3.1.3
-joblib==1.3.2
-jsonschema==4.21.1
-jsonschema-specifications==2023.12.1
-kiwisolver==1.4.5
-lightning-utilities==0.10.1
-markdown==3.5.2
-markdown-it-py==3.0.0
-markupsafe==2.1.4
-matplotlib==3.7.4
-matplotlib-inline==0.1.6
-mdurl==0.1.2
-multidict==6.0.4
-munch==2.5.0
-numpy==1.24.4
-oauthlib==3.2.2
-opencv-python==4.9.0.80
-opencv-python-headless==4.9.0.80
-orjson==3.9.12
-packaging==23.2
-pandas==2.0.3
-parso==0.8.3
-pexpect==4.9.0
-pickleshare==0.7.5
-pillow==10.1.0
-pkgutil-resolve-name==1.3.10
-prompt-toolkit==3.0.43
-protobuf==3.20.1
-psutil==5.9.8
-ptyprocess==0.7.0
-pure-eval==0.2.2
-pyasn1==0.5.1
-pyasn1-modules==0.3.0
-pydantic==2.6.0
-pydantic-core==2.16.1
-pydub==0.25.1
-pygments==2.17.2
-pyparsing==3.1.1
-python-dateutil==2.8.2
-python-multipart==0.0.6
-pytorch-lightning==1.9.0
-pytorchwildlife
-pytz==2023.4
-pyyaml==6.0.1
-referencing==0.33.0
-requests==2.31.0
-requests-oauthlib==1.3.1
-rich==13.7.0
-rpds-py==0.17.1
-rsa==4.9
-scikit-learn==1.2.0
-scipy==1.10.1
-seaborn==0.13.2
-semantic-version==2.10.0
-shellingham==1.5.4
-six==1.16.0
-sniffio==1.3.0
-stack-data==0.6.3
-starlette==0.35.1
-supervision==0.16.0
-tensorboard==2.14.0
-tensorboard-data-server==0.7.2
-termcolor==2.4.0
-thop==0.1.1-2209072238
-threadpoolctl==3.2.0
-tomlkit==0.12.0
-toolz==0.12.1
-torch==1.10.1
-torchaudio==0.10.1
-torchmetrics==1.3.0.post0
-torchvision==0.11.2
-tqdm==4.66.1
-traitlets==5.14.1
-typer==0.9.0
-typing-extensions==4.9.0
-tzdata==2023.4
-ultralytics-yolov5==0.1.1
-urllib3==2.2.0
-uvicorn==0.27.0.post1
-wcwidth==0.2.13
-websockets==11.0.3
-werkzeug==3.0.1
-yarl==1.9.4
-zipp==3.17.0
\ No newline at end of file
+PytorchWildlife
+scikit_learn
+lightning
+munch
+typer
\ No newline at end of file
diff --git a/PW_FT_classification/src/algorithms/plain.py b/PW_FT_classification/src/algorithms/plain.py
index e6bf18abf..79442e3b5 100644
--- a/PW_FT_classification/src/algorithms/plain.py
+++ b/PW_FT_classification/src/algorithms/plain.py
@@ -41,7 +41,7 @@ def __init__(self, conf, train_class_counts, id_to_labels, **kwargs):
         self.save_hyperparameters(ignore=['conf', 'train_class_counts'])
         self.train_class_counts = train_class_counts
         self.id_to_labels = id_to_labels
-        self.net = models.__dict__[self.hparams.model_name](num_cls=1,
+        self.net = models.__dict__[self.hparams.model_name](num_cls=self.hparams.num_classes,
                                                             num_layers=self.hparams.num_layers)
 
     def configure_optimizers(self):
@@ -93,9 +93,8 @@ def training_step(self, batch, batch_idx):
         # Forward pass
         feats = self.net.feature(data)
         logits = self.net.classifier(feats)
-        logits = logits.squeeze(1)
         # Calculate loss
-        loss = self.net.criterion_cls(logits, label_ids.float())
+        loss = self.net.criterion_cls(logits, label_ids)
         self.log("train_loss", loss)
         return loss
 
@@ -117,8 +116,8 @@
         data, label_ids = batch[0], batch[1]
         # Forward pass
         feats = self.net.feature(data)
-        logits = self.net.classifier(feats).squeeze(1)
-        preds = logits>0.5
+        logits = self.net.classifier(feats)
+        preds = logits.argmax(dim=1)
         self.val_st_outs.append((preds.detach().cpu().numpy(),
                                  label_ids.detach().cpu().numpy()))
 
@@ -148,8 +147,8 @@
         data, label_ids, labels, file_ids = batch
         # Forward pass
         feats = self.net.feature(data)
-        logits = torch.sigmoid(self.net.classifier(feats))
-        preds = logits>0.5
+        logits = self.net.classifier(feats)
+        preds = logits.argmax(dim=1)
         self.te_st_outs.append((preds.detach().cpu().numpy(),
                                 label_ids.detach().cpu().numpy(),
@@ -198,8 +197,8 @@
         data, file_ids = batch
         # Forward pass
         feats = self.net.feature(data)
-        logits = torch.sigmoid(self.net.classifier(feats))
-        preds = logits>0.5
+        logits = self.net.classifier(feats)
+        preds = logits.argmax(dim=1)
         self.pr_st_outs.append((preds.detach().cpu().numpy(),
                                 feats.detach().cpu().numpy(),
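The plain.py changes move the model from a single-logit binary head (sigmoid plus a 0.5 threshold) to an N-way classification head (CrossEntropyLoss plus argmax). A minimal sketch of the convention the patched code now follows:

```python
# Illustrative only -- the shape and loss convention after this patch.
import torch
import torch.nn as nn

batch, num_classes = 4, 2
logits = torch.randn(batch, num_classes)          # classifier output: [B, num_classes]
labels = torch.randint(0, num_classes, (batch,))  # integer class ids, not floats

# CrossEntropyLoss applies log-softmax internally, so raw logits go in directly;
# no squeeze(1), no .float() labels, no torch.sigmoid needed.
loss = nn.CrossEntropyLoss()(logits, labels)

# Predictions are an argmax over classes instead of sigmoid(logit) > 0.5
preds = logits.argmax(dim=1)
```

This is also why `num_cls=1` becomes `num_cls=self.hparams.num_classes`: the head now emits one logit per class, and the binary case is simply `num_classes: 2`.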
diff --git a/PW_FT_classification/src/datasets/custom.py b/PW_FT_classification/src/datasets/custom.py
index 75ada44b3..ddcb660f6 100644
--- a/PW_FT_classification/src/datasets/custom.py
+++ b/PW_FT_classification/src/datasets/custom.py
@@ -122,6 +122,7 @@ def __getitem__(self, index):
         return sample, label_id, label, file_dir
 
+
 class Custom_Crop_DS(Custom_Base_DS):
     """
     Dataset class for handling custom cropped datasets.
@@ -146,6 +147,7 @@ def __init__(self, rootdir, dset='train', transform=None):
                                     .format('test' if dset == 'test' else dset)))
         self.load_data()
 
+
 class Custom_Base(pl.LightningDataModule):
     """
     Base data module for handling custom datasets in PyTorch Lightning.
@@ -213,6 +215,7 @@ def predict_dataloader(self):
         """
         return DataLoader(self.dset_pr, batch_size=64, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False)
 
+
 class Custom_Crop(Custom_Base):
     """
     Custom data module specifically for cropped datasets in PyTorch Lightning.
diff --git a/PW_FT_classification/src/models/plain_resnet.py b/PW_FT_classification/src/models/plain_resnet.py
index 381295810..d664c91b7 100644
--- a/PW_FT_classification/src/models/plain_resnet.py
+++ b/PW_FT_classification/src/models/plain_resnet.py
@@ -144,7 +144,7 @@ def setup_criteria(self):
         Set up the criterion for the classifier.
         """
-        # Criterion for binary classification
-        self.criterion_cls = nn.BCEWithLogitsLoss()
+        # Criterion for multi-class classification
+        self.criterion_cls = nn.CrossEntropyLoss()
 
     def feat_init(self):
         """
@@ -167,4 +167,3 @@
         unused_keys = load_keys - self_keys
         print('missing keys: {}'.format(sorted(list(missing_keys))))
         print('unused_keys: {}'.format(sorted(list(unused_keys))))
-        pass
diff --git a/PW_FT_classification/src/utils/data_splitting.py b/PW_FT_classification/src/utils/data_splitting.py
index 3890e23df..f667c6247 100644
--- a/PW_FT_classification/src/utils/data_splitting.py
+++ b/PW_FT_classification/src/utils/data_splitting.py
@@ -50,9 +50,6 @@ def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1):
     # Return the dataframes
     return train_set, val_set, test_set
 
-import pandas as pd
-from sklearn.model_selection import train_test_split
-
 def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None):
     """
     Splits the dataset into train, validation, and test sets based on location, ensuring that:
@@ -95,10 +92,6 @@ def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, ra
 
     return train_data, val_data, test_data
 
-
-import pandas as pd
-from sklearn.model_selection import train_test_split
-
 def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None):
     """
     Splits the dataset into train, validation, and test sets based on sequence ID, ensuring that:
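The data_splitting.py cleanup only removes duplicated mid-file imports; the group-wise splitting itself is unchanged. For reference, the idea behind `split_by_location` (no location appears in more than one subset, so animals from one camera site cannot leak between train and test) can be sketched with scikit-learn's GroupShuffleSplit. This is illustrative only: the `location` column name is an assumption, and the repo's actual implementation may differ.

```python
# Illustrative only -- group-aware splitting in the spirit of split_by_location.
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

def location_split(df: pd.DataFrame, test_size: float = 0.15, random_state: int = 0):
    """Split rows so that no location appears in both subsets."""
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    # Indices are chosen per group, keeping each location intact
    train_idx, test_idx = next(splitter.split(df, groups=df["location"]))
    return df.iloc[train_idx], df.iloc[test_idx]
```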