Merge pull request #519 from microsoft/PreRelease
Classification module debugging
zhmiao committed Jul 19, 2024
2 parents dbd39d4 + 02f2a4d commit 2b16ad1
Showing 9 changed files with 45 additions and 166 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -10,4 +10,10 @@ PytorchWildlife.egg-info/
 *setup/*
 *dist/*
 *VSCodeCounter*
-*build/*
+*build/*
+*bee*
+*Bee*
+*debug*
+*annotations.csv
+*cropped_resized*
+*log*
3 changes: 2 additions & 1 deletion PW_FT_classification/README.md
@@ -88,8 +88,9 @@ Before training your model, you need to configure the training and data parameters
 - `num_workers`: Number of subprocesses to use for data loading.
 
 - **Model Parameters:**
+  - `num_classes`: The number of classes.
   - `model_name`: The name of the model architecture to use. The current version only supports PlainResNetClassifier.
-  - `num_layers`: Number of layers in the model. Currently only supports 18 and 50.
+  - `num_layers`: Number of layers in the resnet model. Currently only supports 18 and 50.
   - `weights_init`: Initial weights setting for the model. Currently only supports "ImageNet".
 
 - **Optimization Parameters:**
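The README excerpt above documents the flat training parameters (`num_classes`, `model_name`, `num_layers`, `weights_init`) that the configuration file below supplies. As a point of reference, here is a minimal sketch of loading such a flat YAML config into an attribute-style object; the function name `load_conf` and the use of PyYAML with Munch are assumptions for illustration (munch is listed in the updated requirements), not the repository's actual loading code.

```python
# Minimal sketch (assumption): load a flat training config into an
# attribute-style object so fields read as conf.num_classes, conf.num_layers.
# The project's real loader is not shown in this diff.
import yaml
from munch import Munch

def load_conf(path: str) -> Munch:
    with open(path) as f:
        return Munch(yaml.safe_load(f))

# Example usage with a hypothetical path:
# conf = load_conf('./configs/Raw/Crop_res50_plain_082723.yaml')
# print(conf.model_name, conf.num_layers, conf.num_classes)
```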
(Changes to a training configuration YAML; the file path is not shown in this view.)
@@ -1,8 +1,8 @@
 # training
-conf_id: Crop_Res50_plain_082723
+conf_id: Crop_Res18_plain_071824
 algorithm: Plain
 log_dir: Crop
-num_epochs: 60
+num_epochs: 30
 log_interval: 10
 parallel: 0
 
@@ -18,11 +18,12 @@ val_size: 0.2
 split_data: True
 split_type: location # options are: random, location, sequence
 # data loading
-batch_size: 256
-num_workers: 0 #40
+batch_size: 32
+num_workers: 4 #40
 # model
+num_classes: 2
 model_name: PlainResNetClassifier
-num_layers: 50
+num_layers: 18
 weights_init: ImageNet
 
 # optim
@@ -35,6 +36,6 @@ lr_classifier: 0.01
 momentum_classifier: 0.9
 weight_decay_classifier: 0.0005
 ## lr_scheduler
-step_size: 20
+step_size: 10
 gamma: 0.1
 
27 changes: 12 additions & 15 deletions PW_FT_classification/main.py
@@ -16,9 +16,10 @@
 # %%
 from src.utils import batch_detection_cropping
 from src.utils import data_splitting
-#app = typer.Typer()
+
+app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_show_locals=False)
 # %%
-#@app.command()
+@app.command()
 def main(
     config:str='./configs/Raw/Crop_res50_plain_082723.yaml',
     project:str='Custom-classification',
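The hunk above activates the Typer CLI that was previously commented out. A minimal, self-contained sketch of the same pattern follows; the option names and defaults here are illustrative, not the project's full argument list.

```python
# Sketch of the Typer pattern enabled above (illustrative options only).
import typer

app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_show_locals=False)

@app.command()
def main(config: str = './configs/example.yaml', dev: bool = False):
    """Each keyword argument becomes a CLI option, e.g. --config and --dev."""
    typer.echo(f"config={config} dev={dev}")

if __name__ == '__main__':
    app()
```

With `pretty_exceptions_show_locals=False`, tracebacks omit local variables, which keeps large tensors and data paths out of error output.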
@@ -34,7 +35,7 @@ def main(
     predict_root:str=""
 ):
     """
-    Main function for training or evaluating a ResNet-50 model using PyTorch Lightning.
+    Main function for training or evaluating a ResNet model (50 or 18) using PyTorch Lightning.
     It loads configurations, initializes the model, logger, and other components based on provided arguments.
 
     Args:
@@ -56,7 +57,7 @@ def main(
     gpus = gpus if torch.cuda.is_available() else None
     gpus = [int(i) for i in gpus.split(',')]
 
-    # Environment variable setup for numpy multi-threading
+    # Environment variable setup for numpy multi-threading. It is important to avoid cpu and ram issues.
     os.environ["OMP_NUM_THREADS"] = str(np_threads)
     os.environ["OPENBLAS_NUM_THREADS"] = str(np_threads)
    os.environ["MKL_NUM_THREADS"] = str(np_threads)
@@ -73,7 +74,6 @@ def main(
     # Set a global seed for reproducibility
     pl.seed_everything(seed)
 
-
     # If the annotation directory does not have a data split, split the data first
     if conf.split_data:
         # Replace annotation dir from config with the directory containing the split files
@@ -92,19 +92,18 @@ def main(
         train_annotations = os.path.join(conf.dataset_root, 'train_annotations.csv')
         test_annotations = os.path.join(conf.dataset_root, 'test_annotations.csv')
         val_annotations = os.path.join(conf.dataset_root, 'val_annotations.csv')
-        # Split training data
-
+        # Crop training data
         batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), train_annotations)
-        # Split validation and test data
+        # Crop validation data
        batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), val_annotations)
+        # Crop test data (most likely we don't need this)
         batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), test_annotations)
 
-
     # Dataset and algorithm loading based on the configuration
     dataset = datasets.__dict__[conf.dataset_name](conf=conf)
     learner = algorithms.__dict__[conf.algorithm](conf=conf,
-                                                train_class_counts=dataset.train_class_counts,
-                                                id_to_labels=dataset.id_to_labels)
+                                                  train_class_counts=dataset.train_class_counts,
+                                                  id_to_labels=dataset.id_to_labels)
 
     # Logger setup based on the specified logger type
     log_folder = 'log_dev' if dev else 'log'
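The `datasets.__dict__[conf.dataset_name]` and `algorithms.__dict__[conf.algorithm]` lines above select classes by the strings in the config. A generic sketch of that registry-style lookup is below; the module path and class name are placeholders, not the project's actual package layout.

```python
# Sketch of name-based class lookup driven by a config string.
# getattr(module, name) is equivalent to module.__dict__[name].
import importlib

def build_component(module_path: str, class_name: str, **kwargs):
    module = importlib.import_module(module_path)   # e.g. 'src.datasets' (placeholder)
    cls = getattr(module, class_name)               # e.g. conf.dataset_name
    return cls(**kwargs)
```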
@@ -155,7 +154,7 @@ def main(
         devices=gpus,
         logger=None if evaluate is not None else logger,
         callbacks=[lr_monitor, checkpoint_callback],
-        strategy='ddp',
+        strategy='auto',
         num_sanity_val_steps=0,
         profiler=None
     )
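The only change in this hunk is `strategy='ddp'` to `strategy='auto'`. A minimal sketch of the resulting Trainer construction follows, assuming Lightning >= 2.0 (the updated requirements pin the `lightning` package); with `'auto'`, the Trainer picks a single-device or DDP strategy from the requested devices instead of always forcing DDP.

```python
# Sketch (assumption: lightning >= 2.0). Other Trainer arguments are omitted.
import lightning.pytorch as pl

trainer = pl.Trainer(
    devices=1,              # e.g. the parsed `gpus` list from the CLI
    strategy='auto',        # was 'ddp'; a single GPU now avoids DDP overhead
    num_sanity_val_steps=0,
    logger=False,
)
```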
@@ -171,8 +170,6 @@ def main(
         trainer.fit(learner, datamodule=dataset)
 # %%
 if __name__ == '__main__':
-    main()
-
-
+    app()
 
 # %%
130 changes: 5 additions & 125 deletions PW_FT_classification/requirements.txt
@@ -1,125 +1,5 @@
-absl-py==2.1.0
-aiofiles==23.2.1
-aiohttp==3.9.3
-aiosignal==1.3.1
-altair==5.2.0
-annotated-types==0.6.0
-anyio==4.2.0
-asttokens==2.4.1
-async-timeout==4.0.3
-attrs==23.2.0
-backcall==0.2.0
-cachetools==5.3.2
-certifi==2023.11.17
-charset-normalizer==3.3.2
-click==8.1.7
-colorama==0.4.6
-contourpy==1.1.1
-cycler==0.12.1
-decorator==5.1.1
-exceptiongroup==1.2.0
-executing==2.0.1
-fastapi==0.109.0
-ffmpy==0.3.1
-filelock==3.13.1
-fire==0.5.0
-fonttools==4.47.2
-frozenlist==1.4.1
-fsspec==2023.12.2
-google-auth==2.27.0
-google-auth-oauthlib==1.0.0
-gradio
-grpcio==1.60.0
-h11==0.14.0
-httpcore==1.0.2
-httpx==0.26.0
-huggingface-hub==0.20.3
-idna==3.6
-importlib-metadata==7.0.1
-importlib-resources==6.1.1
-ipython==8.12.3
-jedi==0.19.1
-jinja2==3.1.3
-joblib==1.3.2
-jsonschema==4.21.1
-jsonschema-specifications==2023.12.1
-kiwisolver==1.4.5
-lightning-utilities==0.10.1
-markdown==3.5.2
-markdown-it-py==3.0.0
-markupsafe==2.1.4
-matplotlib==3.7.4
-matplotlib-inline==0.1.6
-mdurl==0.1.2
-multidict==6.0.4
-munch==2.5.0
-numpy==1.24.4
-oauthlib==3.2.2
-opencv-python==4.9.0.80
-opencv-python-headless==4.9.0.80
-orjson==3.9.12
-packaging==23.2
-pandas==2.0.3
-parso==0.8.3
-pexpect==4.9.0
-pickleshare==0.7.5
-pillow==10.1.0
-pkgutil-resolve-name==1.3.10
-prompt-toolkit==3.0.43
-protobuf==3.20.1
-psutil==5.9.8
-ptyprocess==0.7.0
-pure-eval==0.2.2
-pyasn1==0.5.1
-pyasn1-modules==0.3.0
-pydantic==2.6.0
-pydantic-core==2.16.1
-pydub==0.25.1
-pygments==2.17.2
-pyparsing==3.1.1
-python-dateutil==2.8.2
-python-multipart==0.0.6
-pytorch-lightning==1.9.0
-pytorchwildlife
-pytz==2023.4
-pyyaml==6.0.1
-referencing==0.33.0
-requests==2.31.0
-requests-oauthlib==1.3.1
-rich==13.7.0
-rpds-py==0.17.1
-rsa==4.9
-scikit-learn==1.2.0
-scipy==1.10.1
-seaborn==0.13.2
-semantic-version==2.10.0
-shellingham==1.5.4
-six==1.16.0
-sniffio==1.3.0
-stack-data==0.6.3
-starlette==0.35.1
-supervision==0.16.0
-tensorboard==2.14.0
-tensorboard-data-server==0.7.2
-termcolor==2.4.0
-thop==0.1.1-2209072238
-threadpoolctl==3.2.0
-tomlkit==0.12.0
-toolz==0.12.1
-torch==1.10.1
-torchaudio==0.10.1
-torchmetrics==1.3.0.post0
-torchvision==0.11.2
-tqdm==4.66.1
-traitlets==5.14.1
-typer==0.9.0
-typing-extensions==4.9.0
-tzdata==2023.4
-ultralytics-yolov5==0.1.1
-urllib3==2.2.0
-uvicorn==0.27.0.post1
-wcwidth==0.2.13
-websockets==11.0.3
-werkzeug==3.0.1
-yarl==1.9.4
-zipp==3.17.0
+PytorchWildlife
+scikit_learn
+lightning
+munch
+typer
17 changes: 8 additions & 9 deletions PW_FT_classification/src/algorithms/plain.py
@@ -41,7 +41,7 @@ def __init__(self, conf, train_class_counts, id_to_labels, **kwargs):
         self.save_hyperparameters(ignore=['conf', 'train_class_counts'])
         self.train_class_counts = train_class_counts
         self.id_to_labels = id_to_labels
-        self.net = models.__dict__[self.hparams.model_name](num_cls=1,
+        self.net = models.__dict__[self.hparams.model_name](num_cls=self.hparams.num_classes,
                                                             num_layers=self.hparams.num_layers)
 
     def configure_optimizers(self):
@@ -93,9 +93,8 @@ def training_step(self, batch, batch_idx):
         # Forward pass
         feats = self.net.feature(data)
         logits = self.net.classifier(feats)
-        logits = logits.squeeze(1)
         # Calculate loss
-        loss = self.net.criterion_cls(logits, label_ids.float())
+        loss = self.net.criterion_cls(logits, label_ids)
         self.log("train_loss", loss)
 
         return loss
@@ -117,8 +116,8 @@ def validation_step(self, batch, batch_idx):
         data, label_ids = batch[0], batch[1]
         # Forward pass
         feats = self.net.feature(data)
-        logits = self.net.classifier(feats).squeeze(1)
-        preds = logits>0.5
+        logits = self.net.classifier(feats)
+        preds = logits.argmax(dim=1)
 
         self.val_st_outs.append((preds.detach().cpu().numpy(),
                                  label_ids.detach().cpu().numpy()))
@@ -148,8 +147,8 @@ def test_step(self, batch, batch_idx):
         data, label_ids, labels, file_ids = batch
         # Forward pass
         feats = self.net.feature(data)
-        logits = torch.sigmoid(self.net.classifier(feats))
-        preds = logits>0.5
+        logits = self.net.classifier(feats)
+        preds = logits.argmax(dim=1)
 
         self.te_st_outs.append((preds.detach().cpu().numpy(),
                                 label_ids.detach().cpu().numpy(),
@@ -198,8 +197,8 @@ def predict_step(self, batch, batch_idx):
         data, file_ids = batch
         # Forward pass
         feats = self.net.feature(data)
-        logits = torch.sigmoid(self.net.classifier(feats))
-        preds = logits>0.5
+        logits = self.net.classifier(feats)
+        preds = logits.argmax(dim=1)
 
         self.pr_st_outs.append((preds.detach().cpu().numpy(),
                                 feats.detach().cpu().numpy(),
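Taken together, the plain.py hunks switch the classifier from a single-logit binary head (sigmoid plus 0.5 threshold, `BCEWithLogitsLoss`) to a multi-class head sized by `num_classes` (`argmax` predictions, `CrossEntropyLoss`). A self-contained sketch of the new shapes and calls follows; the batch size, class count, and feature dimension are placeholder values.

```python
# Sketch of the multi-class path introduced in this file (placeholder sizes).
import torch
import torch.nn as nn

N, C, feat_dim = 4, 2, 512                    # batch, classes, feature width (assumed)
feats = torch.randn(N, feat_dim)              # stands in for self.net.feature(data)
classifier = nn.Linear(feat_dim, C)           # head is now sized by conf.num_classes
criterion = nn.CrossEntropyLoss()             # replaces nn.BCEWithLogitsLoss()

logits = classifier(feats)                    # (N, C); no squeeze(1) needed any more
labels = torch.randint(0, C, (N,))            # integer class ids, not float targets
loss = criterion(logits, labels)
preds = logits.argmax(dim=1)                  # replaces torch.sigmoid(logits) > 0.5
```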
3 changes: 3 additions & 0 deletions PW_FT_classification/src/datasets/custom.py
@@ -122,6 +122,7 @@ def __getitem__(self, index):
 
         return sample, label_id, label, file_dir
 
+
 class Custom_Crop_DS(Custom_Base_DS):
     """
     Dataset class for handling custom cropped datasets.
@@ -146,6 +147,7 @@ def __init__(self, rootdir, dset='train', transform=None):
                         .format('test' if dset == 'test' else dset)))
         self.load_data()
 
+
 class Custom_Base(pl.LightningDataModule):
     """
     Base data module for handling custom datasets in PyTorch Lightning.
@@ -213,6 +215,7 @@ def predict_dataloader(self):
         """
         return DataLoader(self.dset_pr, batch_size=64, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False)
 
+
 class Custom_Crop(Custom_Base):
     """
     Custom data module specifically for cropped datasets in PyTorch Lightning.
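The custom.py hunks only add spacing, but the surrounding context shows the LightningDataModule pattern these classes follow. A stripped-down sketch of that shape is below, mirroring the `predict_dataloader` call visible above; the class and argument names are illustrative, not the repository's.

```python
# Sketch of the DataModule shape (assumption: lightning >= 2.0; names illustrative).
import lightning.pytorch as pl
from torch.utils.data import DataLoader, Dataset

class CropDataModule(pl.LightningDataModule):
    def __init__(self, dset_pr: Dataset, num_workers: int = 4):
        super().__init__()
        self.dset_pr = dset_pr
        self.num_workers = num_workers

    def predict_dataloader(self):
        # Mirrors the DataLoader call shown in the hunk above.
        return DataLoader(self.dset_pr, batch_size=64, shuffle=False,
                          pin_memory=True, num_workers=self.num_workers,
                          drop_last=False)
```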
3 changes: 1 addition & 2 deletions PW_FT_classification/src/models/plain_resnet.py
@@ -144,7 +144,7 @@ def setup_criteria(self):
         Set up the criterion for the classifier.
         """
         # Criterion for binary classification
-        self.criterion_cls = nn.BCEWithLogitsLoss()
+        self.criterion_cls = nn.CrossEntropyLoss()
 
     def feat_init(self):
         """
@@ -167,4 +167,3 @@ def feat_init(self):
         unused_keys = load_keys - self_keys
         print('missing keys: {}'.format(sorted(list(missing_keys))))
         print('unused_keys: {}'.format(sorted(list(unused_keys))))
-        pass
7 changes: 0 additions & 7 deletions PW_FT_classification/src/utils/data_splitting.py
@@ -50,9 +50,6 @@ def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1):
     # Return the dataframes
     return train_set, val_set, test_set
 
-import pandas as pd
-from sklearn.model_selection import train_test_split
-
 def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None):
     """
     Splits the dataset into train, validation, and test sets based on location, ensuring that:
@@ -95,10 +92,6 @@ def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None):
     return train_data, val_data, test_data
 
 
-
-import pandas as pd
-from sklearn.model_selection import train_test_split
-
 def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None):
     """
     Splits the dataset into train, validation, and test sets based on sequence ID, ensuring that:
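The deleted lines above are duplicate imports; the surrounding docstrings describe grouped splits, where every image from a given location (or sequence) lands in exactly one of train, validation, or test. A hedged sketch of a location-grouped split is below; the `location` column name and the two-stage `train_test_split` approach are assumptions, not the repository's exact implementation.

```python
# Sketch of a location-grouped split (assumed 'location' column; split
# proportions are approximate because whole locations are assigned at once).
import pandas as pd
from sklearn.model_selection import train_test_split

def split_by_location_sketch(df: pd.DataFrame, val_size=0.15, test_size=0.15,
                             random_state=0):
    locations = df['location'].unique()
    train_locs, holdout_locs = train_test_split(
        locations, test_size=val_size + test_size, random_state=random_state)
    val_locs, test_locs = train_test_split(
        holdout_locs, test_size=test_size / (val_size + test_size),
        random_state=random_state)
    return (df[df['location'].isin(train_locs)],
            df[df['location'].isin(val_locs)],
            df[df['location'].isin(test_locs)])
```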
