diff --git a/projects/GLEE/glee/models/glee_model.py b/projects/GLEE/glee/models/glee_model.py
index 1d69d36..bf5059b 100644
--- a/projects/GLEE/glee/models/glee_model.py
+++ b/projects/GLEE/glee/models/glee_model.py
@@ -618,3 +618,92 @@ def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
         return ref_feats, ref_masks
+
+
+    def get_tracking_contrastive_lossv3(self, video_outputs, video_targets, task):  # IDOL track loss
+        if task in self.no_mask_tasks:
+            # no mask annotations for this task: match on class and box cost only
+            indices_all = self.matcher(video_outputs, video_targets, task, cost=["cls", "box"])
+        else:
+            indices_all = self.matcher(video_outputs, video_targets, task)
+
+        video_len = self.video_info['len']
+        track_loss = 0
+        num_inst = 0
+
+        batch_similarity = []
+        batch_label = []
+        for i in range(self.video_info['bz']):  # slice out the frames of the i-th video in the batch
+            indices = indices_all[i*video_len:(i+1)*video_len]
+            bz_embedding = video_outputs['pred_track_embed'][i*video_len:(i+1)*video_len]
+            bz_target = video_targets[i*video_len:(i+1)*video_len]
+            zero = torch.tensor(0).to(bz_embedding.device)
+            video_contras = {}
+            for f, (findice, fembed, ftarget) in enumerate(zip(indices, bz_embedding, bz_target)):
+                vf_embed_k = fembed[findice[0]]  # embeddings of the matched (positive) queries
+                if len(vf_embed_k.shape) == 1:
+                    vf_embed_k = vf_embed_k.unsqueeze(0)
+                vf_gt_id_k = ftarget['inst_id'][findice[1]]  # GT instance ids of those matches
+
+                # negative samples: randomly drawn queries (20 of the 300) minus the matched ones
+                sampled_index = set(random.sample(range(300), 20))
+                neg_index = list(sampled_index - set(findice[0].tolist()))
+                vf_embed_neg = fembed[neg_index]
+                vf_embed = torch.cat([vf_embed_k, vf_embed_neg], dim=0)
+                # negatives get the dummy id -2, which never equals a real instance id
+                vf_gt_id = torch.cat([vf_gt_id_k, zero.repeat(len(neg_index)) - 2], dim=0)
+
+                video_contras[f] = (vf_embed, vf_gt_id)
+
+                if f > 0:
+                    num_inst = num_inst + len(ftarget['inst_id'])
+                    # similarity between all reference-frame samples and the key embeddings
+                    similarity_matrix = torch.einsum("ac,bc->ab", video_contras[f-1][0], vf_embed_k)  # [num_ref, num_key]
+
+                    v0_gt_id_m = video_contras[f-1][1].unsqueeze(-1).repeat(1, len(vf_gt_id_k))
+                    v1_gt_id_m = vf_gt_id_k.unsqueeze(0).repeat(len(video_contras[f-1][1]), 1)
+                    similarity_label = (v0_gt_id_m == v1_gt_id_m).float()  # can be treated as a one-hot label
+
+                    batch_similarity.append(similarity_matrix.flatten())
+                    batch_label.append(similarity_label.flatten())
+
+        if len(batch_similarity) == 0 or torch.cat(batch_similarity).shape[0] == 0:
+            # no valid pairs in this batch: return a zero loss that keeps the graph connected
+            track_loss = (video_outputs['pred_track_embed'] * 0).sum()
+        else:
+            contras_loss = 0
+            for pred, label in zip(batch_similarity, batch_label):
+                if len(pred) == 0:
+                    continue
+                # pred: [N] similarities between reference and key samples;
+                # label marks which of them belong to the same instance
+                pred = pred.unsqueeze(0)
+                label = label.unsqueeze(0)
+
+                pos_inds = (label == 1)
+                neg_inds = (label == 0)
+                pred_pos = pred * pos_inds.float()
+                pred_neg = pred * neg_inds.float()
+                # use +/-inf to mask out unwanted elements
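+                # pred_pos entries at negative positions become +inf, so (neg - pos)
+                # evaluates to -inf and vanishes from the logsumexp below; pred_neg is
+                # masked symmetrically with -inf. Combined with the zero padding, this
+                # computes the multi-positive contrastive loss
+                #     log(1 + sum_{n,p} exp(s_n - s_p))
+                # over all (negative n, positive p) pairs, as in IDOL's tracking loss.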
+                pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')
+                pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')
+
+                _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)
+                _neg_expand = pred_neg.repeat(1, pred.shape[1])
+                # pad a constant 0 so the logsumexp includes the "+1" term
+                x = torch.nn.functional.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0)
+                contras_loss += torch.logsumexp(x, dim=1)
+
+            # normalize by the number of key-frame instances; batch-size normalization
+            # (self.video_info['bz']) is intentionally left out
+            track_loss = contras_loss / max(num_inst, 1)
+
+        return track_loss
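+
+    # Worked example (hypothetical values, not part of the model): with one positive
+    # score s_p = 2.0 and one negative score s_n = 0.5, the masked, padded logsumexp
+    # above reduces to softplus(s_n - s_p) = log(1 + exp(0.5 - 2.0)) ≈ 0.2014.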