Commit: update
jameslahm committed Jul 18, 2024 (initial commit d3467c9)
Showing 48 changed files with 278,952 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
**/__pycache__
checkpoints
80 changes: 80 additions & 0 deletions README.md
@@ -0,0 +1,80 @@
# [Hierarchical Prompt Learning Using CLIP for Multi-label Classification with Single Positive Labels](https://dl.acm.org/doi/pdf/10.1145/3581783.3611988)

Official PyTorch Implementation of **HSPNet**, from the following paper:

[Hierarchical Prompt Learning Using CLIP for Multi-label Classification with Single Positive Labels](https://dl.acm.org/doi/pdf/10.1145/3581783.3611988). ACMMM 2023.

> Ao Wang, Hui Chen, Zijia Lin, Zixuan Ding, Pengzhang Liu, Yongjun Bao, Weipeng Yan, and Guiguang Ding

**Abstract**

Collecting full annotations to construct multi-label datasets is difficult and labor-intensive. As an effective way to relieve the annotation burden, single positive multi-label learning (SPML) has drawn increasing attention from both academia and industry. It annotates each image with only one positive label, leaving the other labels unobserved. Existing methods therefore strive to exploit cues about the unobserved labels to compensate for the insufficient label supervision. Although they achieve promising performance, they generally treat labels independently, overlooking the inherent hierarchical semantic relationship among labels, namely that labels can be clustered into groups. In this paper, we propose a hierarchical prompt learning method with a novel Hierarchical Semantic Prompt Network (HSPNet) to harness such hierarchical semantic relationships using a large-scale pretrained vision-language model, i.e., CLIP, for SPML. We first introduce a Hierarchical Conditional Prompt (HCP) strategy to grasp the hierarchical label-group dependency. Then we equip it with a Hierarchical Graph Convolutional Network (HGCN) to capture the high-order inter-label and inter-group dependencies. Comprehensive experiments and analyses on several benchmark datasets show that our method significantly outperforms state-of-the-art methods, demonstrating its superiority and effectiveness.

## Credit to previous work
This repository is built upon the code bases of [ASL](https://github.com/Alibaba-MIIL/ASL) and [SPLC](https://github.com/xinyu1205/robust-loss-mlml). Many thanks to their authors!

## Performance

| Dataset | mAP (%) | Checkpoint | Log |
|:---: | :---: | :---: | :---: |
| COCO | 75.7 | [hspnet+coco.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+coco.ckpt) | [hspnet+coco.txt](logs/hspnet+coco.txt) |
| VOC | 90.4 | [hspnet+voc.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+voc.ckpt) | [hspnet+voc.txt](logs/hspnet+voc.txt) |
| NUSWIDE | 61.8 | [hspnet+nuswide.ckpt](https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+nuswide.ckpt) | [hspnet+nuswide.txt](logs/hspnet+nuswide.txt) |
| CUB | 24.3 | [hspnet+cub.ckpt]() | [hspnet+cub.txt](logs/hspnet+cub.txt) |

## Training

### COCO
```bash
python train.py -c configs/hspnet+coco.yaml
```

### VOC
```bash
python train.py -c configs/hspnet+voc.yaml
```

### NUSWIDE
```bash
python train.py -c configs/hspnet+nuswide.yaml
```

### CUB
```bash
python train.py -c configs/hspnet+cub.yaml
```
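
In addition to `-c/--config-file`, `train.py` accepts a `-r/--round` index and a `--resume` flag (defined in `args.py` below). For example, a sketch of resuming an interrupted COCO run, assuming round 1:

```bash
python train.py -c configs/hspnet+coco.yaml -r 1 --resume
```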

## Inference

> Note: place the pretrained checkpoint at `checkpoints/hspnet+coco/round1/model-highest.ckpt` (and analogously for the other datasets).
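
For example, the released COCO checkpoint from the table above can be downloaded and placed at the expected path:

```bash
mkdir -p checkpoints/hspnet+coco/round1
wget https://github.com/jameslahm/HSPNet/releases/download/v1.0/hspnet+coco.ckpt \
    -O checkpoints/hspnet+coco/round1/model-highest.ckpt
```
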
### COCO
```bash
python train.py -c configs/hspnet+coco.yaml -t -r 1
```

### VOC
```bash
python train.py -c configs/hspnet+voc.yaml -t -r 1
```

### NUSWIDE
```bash
python train.py -c configs/hspnet+nuswide.yaml -t -r 1
```

### CUB
```bash
python train.py -c configs/hspnet+cub.yaml -t -r 1
```

## Citation
```bibtex
@inproceedings{wang2023hierarchical,
  title={Hierarchical prompt learning using clip for multi-label classification with single positive labels},
  author={Wang, Ao and Chen, Hui and Lin, Zijia and Ding, Zixuan and Liu, Pengzhang and Bao, Yongjun and Yan, Weipeng and Ding, Guiguang},
  booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
  pages={5594--5604},
  year={2023}
}
```
16 changes: 16 additions & 0 deletions args.py
@@ -0,0 +1,16 @@
import argparse

parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
parser.add_argument('-c',
                    '--config-file',
                    help='config file',
                    default='configs/base.yaml',
                    type=str)
parser.add_argument('-t',
                    '--test',
                    help='run test',
                    default=False,
                    action="store_true")
parser.add_argument('-r', '--round', help='round', default=1, type=int)
parser.add_argument('--resume', default=False, action='store_true')
args = parser.parse_args()
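
For reference, the short flags used in the README map to the long forms above, so the COCO inference command can equivalently be written as:

```bash
python train.py --config-file configs/hspnet+coco.yaml --test --round 1
```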
1 change: 1 addition & 0 deletions clip/__init__.py
@@ -0,0 +1 @@
from .clip import *
Binary file added clip/bpe_simple_vocab_16e6.txt.gz
258 changes: 258 additions & 0 deletions clip/clip.py
@@ -0,0 +1,258 @@
import hashlib
import os
import urllib.request
import warnings
from typing import List, Union

import torch
from PIL import Image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# Compare version components numerically; comparing the dot-separated strings
# lexicographically would misorder releases such as "1.10".
if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (1, 7):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50":
    "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101":
    "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4":
    "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16":
    "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32":
    "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16":
    "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
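
# Note: the second-to-last path component of each URL above is the file's
# SHA256 checksum; _download() below verifies downloads against it.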


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target,
                               "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target,
                                                     "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")),
                  ncols=80,
                  unit='iB',
                  unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target,
                           "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())
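
# For instance, with the _MODELS registry above, available_models() returns:
#   ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']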


def load(name: str,
         device: Union[str, torch.device] = "cuda"
         if torch.cuda.is_available() else "cpu",
         jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path,
                               map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [
        n for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(
                        node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
                                       example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    # dtype can be the second or third argument to aten::to()
                    for i in [1, 2]:
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
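
Taken together, `load` and `tokenize` support the usual CLIP workflow. A minimal usage sketch (the image path `example.jpg` is illustrative, not a file shipped with the repository):

```python
import torch
from PIL import Image

import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# Weights are downloaded to ~/.cache/clip and checksum-verified on first use.
model, preprocess = clip.load("ViT-B/16", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a photo of a cat", "a photo of a dog"]).to(device)

with torch.no_grad():
    # Encode both modalities into the shared embedding space.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
```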