update track loss for video finetune
wjf5203 committed Jul 26, 2024
1 parent 5ad764c commit 5f1832d
Showing 1 changed file with 89 additions and 0 deletions.

projects/GLEE/glee/models/glee_model.py
@@ -618,3 +618,92 @@ def get_template(self, imgs, pad_masks, prompt_mode='scribble'):

    return ref_feats, ref_masks




def get_tracking_contrastive_lossv3(self, video_outputs, video_targets, task):  # IDOL-style track loss
    if task in self.no_mask_tasks:
        indices_all = self.matcher(video_outputs, video_targets, 'task', cost=["cls", "box"])
    else:
        indices_all = self.matcher(video_outputs, video_targets, 'task')
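    # indices_all holds one (pred_idx, gt_idx) matching pair per frame, flattened
    # over the batch: frames of clip i sit at indices_all[i*video_len:(i+1)*video_len]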

    video_len = self.video_info['len']
    track_loss = 0
    num_inst = 0

    batch_similarity = []
    batch_label = []
    for i in range(self.video_info['bz']):  # slice the current clip's frames out of the batch
        indices = indices_all[i*video_len:(i+1)*video_len]
        bz_embedding = video_outputs['pred_track_embed'][i*video_len:(i+1)*video_len]
        bz_target = video_targets[i*video_len:(i+1)*video_len]
        zero = torch.tensor(0).to(bz_embedding.device)
        video_contras = {}
        for f, (findice, fembed, ftarget) in enumerate(zip(indices, bz_embedding, bz_target)):
            vf_embed_k = fembed[findice[0]]  # embeddings of the matched (positive) queries
            if len(vf_embed_k.shape) == 1:
                vf_embed_k = vf_embed_k.unsqueeze(0)  # unsqueeze returns a new tensor; keep the [1, C] result
            vf_gt_id_k = ftarget['inst_id'][findice[1]]
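            # vf_gt_id_k: ground-truth instance ids of those matched queries; equal
            # ids across frames mark the positive pairs below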


            # negative samples: draw 20 random indices from the 300 object queries
            # and drop any that were matched to a ground-truth instance
            sampled_index = set(random.sample(range(300), 20))
            neg_index = list(sampled_index - set(findice[0].tolist()))
            vf_embed_neg = fembed[neg_index]
            vf_embed = torch.cat([vf_embed_k, vf_embed_neg], dim=0)
            # negatives get the sentinel id -2, which can never equal a real instance id
            vf_gt_id = torch.cat([vf_gt_id_k, zero.repeat(len(neg_index)) - 2], dim=0)

            video_contras[f] = (vf_embed, vf_gt_id)
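            # from the second frame on, compare this frame's positive embeddings
            # against everything (positives + negatives) kept from the previous frame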

            if f > 0:
                num_inst = num_inst + len(ftarget['inst_id'])
                similarity_matrix = torch.einsum("ac,bc->ab", video_contras[f-1][0], vf_embed_k)  # [num_prev, num_gt]

                v0_gt_id_m = video_contras[f-1][1].unsqueeze(-1).repeat(1, len(vf_gt_id_k))
                v1_gt_id_m = vf_gt_id_k.unsqueeze(0).repeat(len(video_contras[f-1][1]), 1)
                similarity_label = (v0_gt_id_m == v1_gt_id_m).float()  # same instance id -> 1; can be treated as a one-hot label
                # use focal loss instead of contrastive
                # aux cosine:
                # aux_contrastive_embed = nn.functional.normalize(video_contras[f-1][0].float(), dim=1)
                # key_embed_i = nn.functional.normalize(vf_embed_k.float(), dim=1)
                # cosine = torch.einsum('nc,kc->nk', [aux_contrastive_embed, key_embed_i])
                # batch_similarity_aux.append(cosine.flatten())
                batch_similarity.append(similarity_matrix.flatten())
                batch_label.append(similarity_label.flatten())
    if len(batch_similarity) == 0 or torch.cat(batch_similarity).shape[0] == 0:
        # no valid pairs in this batch: emit a zero loss that still depends on
        # the embeddings so the autograd graph stays connected
        track_loss = (video_outputs['pred_track_embed'] * 0).sum()
    else:
        contras_loss = 0
        for pred, label in zip(batch_similarity, batch_label):
            if len(pred) == 0:
                continue
            pred = pred.unsqueeze(0)
            label = label.unsqueeze(0)

            pos_inds = (label == 1)
            neg_inds = (label == 0)
            pred_pos = pred * pos_inds.float()
            pred_neg = pred * neg_inds.float()
            # use +/-inf to mask out unwanted elements
            pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')
            pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')
            # enumerate every (positive, negative) similarity pair
            _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)
            _neg_expand = pred_neg.repeat(1, pred.shape[1])
            # [bz, N]: N runs over all pos/neg samples on the reference frame; the
            # padded 0 contributes the "+1" inside the log
            x = torch.nn.functional.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0)
            contras_loss += torch.logsumexp(x, dim=1)
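            # logsumexp over [s_neg - s_pos, ..., 0] equals
            # log(1 + sum over (pos, neg) pairs of exp(s_neg - s_pos)):
            # the multi-positive cross-entropy used by QDTrack/IDOL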


        # normalize by the number of matched instances
        track_loss = contras_loss / max(num_inst, 1)

    return track_loss
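For reference, the inner loop above computes the multi-positive contrastive loss popularized by QDTrack and used in IDOL. Below is a minimal, self-contained sketch of that core computation; it is not part of this commit, and the masked_fill reformulation plus the toy tensors are our own, but it yields the same quantity as the pred_pos / pred_neg block above.

import torch
import torch.nn.functional as F

def multi_pos_cross_entropy(pred, label):
    # pred:  [1, N] similarity logits against the previous frame
    # label: [1, N] with 1 = same instance id, 0 = negative
    pos_inds = label == 1
    neg_inds = label == 0
    # +inf / -inf masking, equivalent to the pred_pos / pred_neg block above
    pred_pos = pred.masked_fill(neg_inds, float('inf'))
    pred_neg = pred.masked_fill(pos_inds, float('-inf'))
    n = pred.shape[1]
    # enumerate every (positive, negative) pair: s_neg - s_pos
    diff = pred_neg.repeat(1, n) - torch.repeat_interleave(pred_pos, n, dim=1)
    # the padded 0 contributes the "+1" inside the log
    x = F.pad(diff, (0, 1), "constant", 0)
    return torch.logsumexp(x, dim=1)  # log(1 + sum exp(s_neg - s_pos))

# toy check: one positive and two negatives
sim = torch.tensor([[5.0, 1.0, 0.5]])
lab = torch.tensor([[1.0, 0.0, 0.0]])
print(multi_pos_cross_entropy(sim, lab))  # ~0.029 = log(1 + e^-4 + e^-4.5)

Minimizing this pushes every positive similarity above every negative one, and the loss decays toward zero as the gap s_pos - s_neg grows.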

