forked from hila-chefer/TargetCLIP
Commit
1 parent d765033 · commit 38f636a
Showing 86 changed files with 7,955 additions and 2,249 deletions.
@@ -0,0 +1,131 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

.idea/*

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
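One entry worth a note is *.py[cod]: the bracket is a character class that covers the compiled-Python suffixes .pyc, .pyo and .pyd with a single pattern. Git's ignore matching is close to, though not identical to, Python's fnmatch globbing, so a quick sketch of what the class catches (the file names below are made up purely for illustration):

from fnmatch import fnmatch

# Hypothetical file names, used only to show which suffixes '*.py[cod]' covers.
for name in ['module.py', 'module.pyc', 'module.pyo', 'module.pyd', 'module.so']:
    print(f"{name}: {'matched' if fnmatch(name, '*.py[cod]') else 'not matched'}")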
@@ -0,0 +1,118 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result

@torch.no_grad()
def evaluate(model, data_loader, device, config):
    # evaluate
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluation:'
    print_freq = 10

    result = []
    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):

        image = image.to(device)

        captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
                                  min_length=config['min_length'], repetition_penalty=1.1)

        for caption, img_id in zip(captions, image_id):
            result.append({"image_id": img_id.item(), "caption": caption})

    return result
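# evaluate() returns COCO-style caption results: one {"image_id": <int>, "caption": <str>}
# dict per image, produced by beam search (sample=False) with the beam width and length
# bounds taken from the config.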

def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    val_dataset, test_dataset = create_dataset('nocaps', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([val_dataset, test_dataset], [False, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    val_loader, test_loader = create_loader([val_dataset, test_dataset], samplers,
                                            batch_size=[config['batch_size']] * 2, num_workers=[4, 4],
                                            is_trains=[False, False], collate_fns=[None, None])

    #### Model ####
    print("Creating model")
    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                         prompt=config['prompt'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    val_result = evaluate(model_without_ddp, val_loader, device, config)
    val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
    test_result = evaluate(model_without_ddp, test_loader, device, config)
    test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')

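# The CLI below defaults to --distributed True with dist_url='env://', so
# utils.init_distributed_mode() presumably reads the rank/world-size environment
# variables set by a torch.distributed launcher such as torchrun. A hypothetical
# launch command (the script's file name is not shown in this excerpt):
#   python -m torch.distributed.run --nproc_per_node=1 <this_script>.py --config ./configs/nocaps.yaml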
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nocaps.yaml')
    parser.add_argument('--output_dir', default='output/NoCaps')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
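The script reads eight keys from the YAML file passed via --config: pretrained, vit, image_size, prompt, num_beams, max_length, min_length and batch_size. The sketch below shows one way to generate a compatible config, mirroring the yaml.dump call the script itself uses; the key names are confirmed by the config[...] lookups above, but every value (checkpoint path included) is an assumption for illustration, since the actual configs/nocaps.yaml is not part of the excerpt shown here.

import ruamel_yaml as yaml  # same YAML module the script imports

# All values below are illustrative assumptions; only the key names come from the script.
config = {
    'pretrained': 'checkpoints/blip_caption_base.pth',  # assumed checkpoint path
    'vit': 'base',                                       # assumed ViT backbone variant
    'image_size': 384,                                   # assumed input resolution
    'prompt': 'a picture of ',                           # assumed captioning prompt
    'num_beams': 3,                                      # beam width for model.generate
    'max_length': 20,                                    # maximum caption length
    'min_length': 5,                                     # minimum caption length
    'batch_size': 32,                                    # per-process evaluation batch size
}

with open('./configs/nocaps.yaml', 'w') as f:            # path matches the --config default
    yaml.dump(config, f)

One caveat on the argument parsing: --distributed is declared with type=bool, and argparse applies bool() to the raw string, so any non-empty value (including the string 'False') parses as True; the flag effectively keeps its default unless the code itself is changed.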