-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit d3467c9
Showing
48 changed files
with
278,952 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
**/__pycache__ | ||
checkpoints |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# [Hierarchical Prompt Learning Using CLIP for Multi-label Classification with Single Positive Labels](https://dl.acm.org/doi/pdf/10.1145/3581783.3611988) | ||
|
||
Official PyTorch Implementation of **HSPNet**, from the following paper: | ||
|
||
[Hierarchical Prompt Learning Using CLIP for Multi-label Classification with Single Positive Labels](https://dl.acm.org/doi/pdf/10.1145/3581783.3611988). ACMMM 2023. | ||
|
||
> Ao Wang, Hui Chen, Zijia Lin, Zixuan Ding, Pengzhang Liu, Yongjun Bao, Weipeng Yan, and Guiguang Ding | ||
**Abstract** | ||
|
||
Collecting full annotations to construct multi-label datasets is difficult and labor-consuming. As an effective solution to relieve the annotation burden, single positive multi-label learning (SPML) draws increasing attention from both academia and industry. It only annotates each image with one positive label, leaving other labels unobserved. Therefore, existing methods strive to explore the cue of unobserved labels to compensate for the insufficiency of label supervision. Though achieving promising performance, they generally consider labels independently, leaving out the inherent hierarchical semantic relationship among labels which reveals that labels can be clustered into groups. In this paper, we propose a hierarchical prompt learning method with a novel Hierarchical Semantic Prompt Network (HSPNet) to harness such hierarchical semantic relationships using a large-scale pretrained vision and language model, i.e., CLIP, for SPML. We first introduce a Hierarchical Conditional Prompt (HCP) strategy to grasp the hierarchical label-group dependency. Then we equip a Hierarchical Graph Convolutional Network (HGCN) to capture the high-order inter-label and inter-group dependencies. Comprehensive experiments and analyses on several benchmark datasets show that our method significantly outperforms the state-of-the-art methods, well demonstrating its superiority and effectiveness. | ||
|
||
## Credit to previous work | ||
This repository is built upon the code base of [ASL](https://github.com/Alibaba-MIIL/ASL) and [SPLC](https://github.com/xinyu1205/robust-loss-mlml), thanks very much! | ||
|
||
## Performance | ||
|
||
| Dataset | mAP | Ckpt | Log | | ||
|:---: | :---: | :---: | :---: | | ||
| COCO | 75.7 | [hspnet+coco.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+coco.ckpt) | [hspnet+coco.txt](logs/hspnet+coco.txt) | | ||
| VOC | 90.4 | [hspnet+voc.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+voc.ckpt) | [hspnet+voc.txt](logs/hspnet+voc.txt) | | ||
| NUSWIDE | 61.8 | [hspnet+nuswide.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+nuswide.ckpt) | [hspnet+nuswide.txt](logs/hspnet+nuswide.txt) | | ||
| CUB | 24.3 | [hspnet+cub.ckpt]() | [hspnet+cub.txt](logs/hspnet+cub.txt) | | ||
|
||
## Training | ||
|
||
### COCO | ||
```python | ||
python train.py -c configs/hspnet+coco.yaml | ||
``` | ||
|
||
### VOC | ||
```python | ||
python train.py -c configs/hspnet+voc.yaml | ||
``` | ||
|
||
### NUSWIDE | ||
```python | ||
python train.py -c configs/hspnet+nuswide.yaml | ||
``` | ||
|
||
### CUB | ||
```python | ||
python train.py -c configs/hspnet+cub.yaml | ||
``` | ||
|
||
## Inference | ||
|
||
> Note: Please place the pretrained checkpoint to checkpoints/hspnet+coco/round1/model-highest.ckpt | ||
#### COCO | ||
```python | ||
python train.py -c configs/hspnet+coco.yaml -t -r 1 | ||
``` | ||
|
||
#### VOC | ||
```python | ||
python train.py -c configs/hspnet+voc.yaml -t -r 1 | ||
``` | ||
|
||
#### NUSWIDE | ||
```python | ||
python train.py -c configs/hspnet+nuswide.yaml -t -r 1 | ||
``` | ||
|
||
#### CUB | ||
```python | ||
python train.py -c configs/hspnet+cub.yaml -t -r 1 | ||
``` | ||
|
||
## Citation | ||
``` | ||
@inproceedings{wang2023hierarchical, | ||
title={Hierarchical prompt learning using clip for multi-label classification with single positive labels}, | ||
author={Wang, Ao and Chen, Hui and Lin, Zijia and Ding, Zixuan and Liu, Pengzhang and Bao, Yongjun and Yan, Weipeng and Ding, Guiguang}, | ||
booktitle={Proceedings of the 31st ACM International Conference on Multimedia}, | ||
pages={5594--5604}, | ||
year={2023} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training') | ||
parser.add_argument('-c', | ||
'--config-file', | ||
help='config file', | ||
default='configs/base.yaml', | ||
type=str) | ||
parser.add_argument('-t', | ||
'--test', | ||
help='run test', | ||
default=False, | ||
action="store_true") | ||
parser.add_argument('-r', '--round', help='round', default=1, type=int) | ||
parser.add_argument('--resume', default=False, action='store_true') | ||
args = parser.parse_args() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .clip import * |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
import hashlib | ||
import os | ||
import urllib | ||
import warnings | ||
from typing import List, Union | ||
|
||
import torch | ||
from PIL import Image | ||
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, | ||
ToTensor) | ||
from tqdm import tqdm | ||
|
||
from .model import build_model | ||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer | ||
|
||
try: | ||
from torchvision.transforms import InterpolationMode | ||
BICUBIC = InterpolationMode.BICUBIC | ||
except ImportError: | ||
BICUBIC = Image.BICUBIC | ||
|
||
if torch.__version__.split(".") < ["1", "7", "1"]: | ||
warnings.warn("PyTorch version 1.7.1 or higher is recommended") | ||
|
||
__all__ = ["available_models", "load", "tokenize"] | ||
_tokenizer = _Tokenizer() | ||
|
||
_MODELS = { | ||
"RN50": | ||
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", | ||
"RN101": | ||
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", | ||
"RN50x4": | ||
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", | ||
"RN50x16": | ||
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", | ||
"ViT-B/32": | ||
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", | ||
"ViT-B/16": | ||
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", | ||
} | ||
|
||
|
||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): | ||
os.makedirs(root, exist_ok=True) | ||
filename = os.path.basename(url) | ||
|
||
expected_sha256 = url.split("/")[-2] | ||
download_target = os.path.join(root, filename) | ||
|
||
if os.path.exists(download_target) and not os.path.isfile(download_target): | ||
raise RuntimeError( | ||
f"{download_target} exists and is not a regular file") | ||
|
||
if os.path.isfile(download_target): | ||
if hashlib.sha256(open(download_target, | ||
"rb").read()).hexdigest() == expected_sha256: | ||
return download_target | ||
else: | ||
warnings.warn( | ||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" | ||
) | ||
|
||
with urllib.request.urlopen(url) as source, open(download_target, | ||
"wb") as output: | ||
with tqdm(total=int(source.info().get("Content-Length")), | ||
ncols=80, | ||
unit='iB', | ||
unit_scale=True) as loop: | ||
while True: | ||
buffer = source.read(8192) | ||
if not buffer: | ||
break | ||
|
||
output.write(buffer) | ||
loop.update(len(buffer)) | ||
|
||
if hashlib.sha256(open(download_target, | ||
"rb").read()).hexdigest() != expected_sha256: | ||
raise RuntimeError( | ||
"Model has been downloaded but the SHA256 checksum does not not match" | ||
) | ||
|
||
return download_target | ||
|
||
|
||
def _transform(n_px): | ||
return Compose([ | ||
Resize(n_px, interpolation=BICUBIC), | ||
CenterCrop(n_px), | ||
lambda image: image.convert("RGB"), | ||
ToTensor(), | ||
Normalize((0.48145466, 0.4578275, 0.40821073), | ||
(0.26862954, 0.26130258, 0.27577711)), | ||
]) | ||
|
||
|
||
def available_models() -> List[str]: | ||
"""Returns the names of available CLIP models""" | ||
return list(_MODELS.keys()) | ||
|
||
|
||
def load(name: str, | ||
device: Union[str, torch.device] = "cuda" | ||
if torch.cuda.is_available() else "cpu", | ||
jit=False): | ||
"""Load a CLIP model | ||
Parameters | ||
---------- | ||
name : str | ||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict | ||
device : Union[str, torch.device] | ||
The device to put the loaded model | ||
jit : bool | ||
Whether to load the optimized JIT model or more hackable non-JIT model (default). | ||
Returns | ||
------- | ||
model : torch.nn.Module | ||
The CLIP model | ||
preprocess : Callable[[PIL.Image], torch.Tensor] | ||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input | ||
""" | ||
if name in _MODELS: | ||
model_path = _download(_MODELS[name]) | ||
elif os.path.isfile(name): | ||
model_path = name | ||
else: | ||
raise RuntimeError( | ||
f"Model {name} not found; available models = {available_models()}") | ||
|
||
try: | ||
# loading JIT archive | ||
model = torch.jit.load(model_path, | ||
map_location=device if jit else "cpu").eval() | ||
state_dict = None | ||
except RuntimeError: | ||
# loading saved state dict | ||
if jit: | ||
warnings.warn( | ||
f"File {model_path} is not a JIT archive. Loading as a state dict instead" | ||
) | ||
jit = False | ||
state_dict = torch.load(model_path, map_location="cpu") | ||
|
||
if not jit: | ||
model = build_model(state_dict or model.state_dict()).to(device) | ||
if str(device) == "cpu": | ||
model.float() | ||
return model, _transform(model.visual.input_resolution) | ||
|
||
# patch the device names | ||
device_holder = torch.jit.trace( | ||
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) | ||
device_node = [ | ||
n for n in device_holder.graph.findAllNodes("prim::Constant") | ||
if "Device" in repr(n) | ||
][-1] | ||
|
||
def patch_device(module): | ||
try: | ||
graphs = [module.graph] if hasattr(module, "graph") else [] | ||
except RuntimeError: | ||
graphs = [] | ||
|
||
if hasattr(module, "forward1"): | ||
graphs.append(module.forward1.graph) | ||
|
||
for graph in graphs: | ||
for node in graph.findAllNodes("prim::Constant"): | ||
if "value" in node.attributeNames() and str( | ||
node["value"]).startswith("cuda"): | ||
node.copyAttributes(device_node) | ||
|
||
model.apply(patch_device) | ||
patch_device(model.encode_image) | ||
patch_device(model.encode_text) | ||
|
||
# patch dtype to float32 on CPU | ||
if str(device) == "cpu": | ||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), | ||
example_inputs=[]) | ||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] | ||
float_node = float_input.node() | ||
|
||
def patch_float(module): | ||
try: | ||
graphs = [module.graph] if hasattr(module, "graph") else [] | ||
except RuntimeError: | ||
graphs = [] | ||
|
||
if hasattr(module, "forward1"): | ||
graphs.append(module.forward1.graph) | ||
|
||
for graph in graphs: | ||
for node in graph.findAllNodes("aten::to"): | ||
inputs = list(node.inputs()) | ||
for i in [ | ||
1, 2 | ||
]: # dtype can be the second or third argument to aten::to() | ||
if inputs[i].node()["value"] == 5: | ||
inputs[i].node().copyAttributes(float_node) | ||
|
||
model.apply(patch_float) | ||
patch_float(model.encode_image) | ||
patch_float(model.encode_text) | ||
|
||
model.float() | ||
|
||
return model, _transform(model.input_resolution.item()) | ||
|
||
|
||
def tokenize(texts: Union[str, List[str]], | ||
context_length: int = 77, | ||
truncate: bool = False) -> torch.LongTensor: | ||
""" | ||
Returns the tokenized representation of given input string(s) | ||
Parameters | ||
---------- | ||
texts : Union[str, List[str]] | ||
An input string or a list of input strings to tokenize | ||
context_length : int | ||
The context length to use; all CLIP models use 77 as the context length | ||
truncate: bool | ||
Whether to truncate the text in case its encoding is longer than the context length | ||
Returns | ||
------- | ||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | ||
""" | ||
if isinstance(texts, str): | ||
texts = [texts] | ||
|
||
sot_token = _tokenizer.encoder["<|startoftext|>"] | ||
eot_token = _tokenizer.encoder["<|endoftext|>"] | ||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] | ||
for text in texts] | ||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | ||
|
||
for i, tokens in enumerate(all_tokens): | ||
if len(tokens) > context_length: | ||
if truncate: | ||
tokens = tokens[:context_length] | ||
tokens[-1] = eot_token | ||
else: | ||
raise RuntimeError( | ||
f"Input {texts[i]} is too long for context length {context_length}" | ||
) | ||
result[i, :len(tokens)] = torch.tensor(tokens) | ||
|
||
return result |
Oops, something went wrong.