"""
Author: Haoran Chen
Date: 2022.08.15
"""
import torch
from clip import clip
from torch import nn
from einops import rearrange


class PromptGenerator(nn.Module):
    """Builds learnable soft prompts: per-class context tokens plus a domain-specific
    (source or target) context, wrapped between the frozen SOT prefix and the
    class-name/EOT suffix token embeddings."""

    def __init__(self, classnames, clip_model, source_name, target_name, args):
        super().__init__()
        n_cls = len(classnames)
        dtype = torch.float32
        embedding_dim = clip_model.ln_final.weight.shape[0]

        # Learnable context vectors: M1 per-class tokens and M2 per-domain tokens.
        ctx_cls_vectors = torch.empty(n_cls, args.M1, embedding_dim, requires_grad=True,
                                      dtype=dtype, device=args.device)
        ctx_source_vectors = torch.empty(1, args.M2, embedding_dim, requires_grad=True,
                                         dtype=dtype, device=args.device)
        ctx_target_vectors = torch.empty(1, args.M2, embedding_dim, requires_grad=True,
                                         dtype=dtype, device=args.device)
        nn.init.normal_(ctx_cls_vectors, std=0.02)
        nn.init.normal_(ctx_source_vectors, std=0.02)
        nn.init.normal_(ctx_target_vectors, std=0.02)

        prompt_prefix = " ".join(["X"] * (args.M1 + args.M2))
        self.ctx_cls = nn.Parameter(ctx_cls_vectors)        # to be optimized
        self.ctx_source = nn.Parameter(ctx_source_vectors)  # to be optimized
        self.ctx_target = nn.Parameter(ctx_target_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(args.device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])                      # SOT
        self.register_buffer("token_suffix", embedding[:, 1 + args.M1 + args.M2:, :])  # CLS, EOT

        self.n_cls = n_cls
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor

    def forward(self):
        ctx_cls = self.ctx_cls
        ctx_source = self.ctx_source
        ctx_target = self.ctx_target
        prefix = self.token_prefix
        suffix = self.token_suffix

        source_prompts = torch.cat(
            [prefix,                               # (n_cls, 1, dim)
             ctx_cls,                              # (n_cls, M1, dim)
             ctx_source.repeat(self.n_cls, 1, 1),  # (n_cls, M2, dim)
             suffix,                               # (n_cls, *, dim)
             ],
            dim=1)
        target_prompts = torch.cat(
            [prefix,                               # (n_cls, 1, dim)
             ctx_cls,                              # (n_cls, M1, dim)
             ctx_target.repeat(self.n_cls, 1, 1),  # (n_cls, M2, dim)
             suffix,                               # (n_cls, *, dim)
             ],
            dim=1)

        prompts = torch.cat([source_prompts, target_prompts], dim=0)
        return prompts


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.visual.conv1.weight.dtype

    def forward(self, prompts, tokenized_prompts):
        prompts = prompts.type(self.dtype)
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x


class Custom_Clip(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def forward(self, image, prompt, tokenized_prompts):
        image_features = self.image_encoder(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = self.text_encoder(prompt, tokenized_prompts)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return image_features, text_features


class AutoEncoder(nn.Module):
    """Three-layer MLP applied token-wise to a batch of prompt embeddings."""

    def __init__(self, dim, decoder_dim, inner_dim):
        super().__init__()
        self.prompt_w1 = nn.Linear(dim, inner_dim)
        self.prompt_w2 = nn.Linear(inner_dim, decoder_dim)
        self.prompt_w3 = nn.Linear(decoder_dim, dim)

    def forward(self, x):
        batch = x.size(0)
        # Flatten the (batch, token) dimensions so the linear layers act token-wise.
        x = rearrange(x, 'b t e -> (b t) e')
        x = x.to(self.prompt_w1.weight.device)
        x = self.prompt_w1(x)
        x = torch.tanh(self.prompt_w2(x))
        x = self.prompt_w3(x)
        x = rearrange(x, '(b t) e -> b t e', b=batch)
        return x
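

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original training
# pipeline). It assumes a hypothetical `args` namespace carrying the M1 / M2
# prompt lengths and the device, and uses the standard `clip.load` API to get
# a CLIP backbone; the backbone name and class names below are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    device = "cuda" if torch.cuda.is_available() else "cpu"
    args = SimpleNamespace(M1=16, M2=16, device=device)  # hypothetical hyper-parameters

    clip_model, _ = clip.load("ViT-B/16", device=device)  # assumed backbone choice
    clip_model = clip_model.float()

    classnames = ["dog", "cat"]  # placeholder class names
    prompt_generator = PromptGenerator(classnames, clip_model, "source", "target", args)
    custom_clip = Custom_Clip(clip_model)

    prompts = prompt_generator()                      # (2 * n_cls, 77, dim): source then target prompts
    tokenized = prompt_generator.tokenized_prompts    # (n_cls, 77)
    tokenized = torch.cat([tokenized, tokenized])     # repeat to match the stacked prompts

    images = torch.randn(2, 3, 224, 224, device=device)
    image_features, text_features = custom_clip(images, prompts, tokenized)
    logits = clip_model.logit_scale.exp() * image_features @ text_features.t()
    print(logits.shape)  # (2, 2 * n_cls)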