-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsampling_cfoldingdiff.py
84 lines (63 loc) · 3.05 KB
/
sampling_cfoldingdiff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import os.path as osp
from params import create_parser
import warnings
warnings.filterwarnings('ignore')
import torch.backends.cudnn as cudnn
import random
import numpy as np
from main import Exp
from tqdm import tqdm
def set_seed(seed):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (in that
    order), then forces cuDNN into deterministic mode with benchmark mode off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = False


# Seed at import time so everything after this point (including the
# project imports below) runs under a fixed seed.
set_seed(111)
# from constants import method_maps
from utils import Recorder
from utils.nerf import TorchNERFBuilder
from utils import *
from utils.angles_and_coords import (
canonical_distances_and_dihedrals,
EXHAUSTIVE_ANGLES,
EXHAUSTIVE_DISTS,
extract_backbone_coords,
create_new_chain_nerf
)
import pandas as pd
import json
# Absolute locations from the original experiment machine; adjust as needed.
SAMPLING_OUTPUT_DIR = "/gaozhangyang/experiments/DiffSDS/results/CFoldingDiff_sampling"
CHECKPOINT_PATH = "/gaozhangyang/experiments/DiffSDS/model_zoom/CFoldingDiff/checkpoint.pth"
# Value passed as the timestep argument of exp.method.sampling for every batch
# (presumably the number of diffusion steps to run -- confirm in the method).
SAMPLING_TIMESTEPS = 1000


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Only the leading prefix is stripped (presumably added by a
    DataParallel/DDP wrapper when the checkpoint was saved), so names that
    merely contain ``module.`` elsewhere -- e.g. ``submodule.bias`` -- stay
    intact. Keys without the prefix are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def squeeze_keep_batch(mask):
    """Drop every size-1 dimension of *mask* except the leading batch dim.

    For batches larger than one this matches ``mask.squeeze()``; unlike it, a
    batch holding a single sample keeps its batch dimension, so per-sample
    indexing ``mask[i]`` still selects a sample rather than a residue.
    """
    kept = [size for size in mask.shape[1:] if size != 1]
    return mask.reshape(mask.shape[0], *kept)


if __name__ == '__main__':
    args = create_parser()
    args.method = 'CFoldingDiff'
    # config aliases args.__dict__, so the updates below also set attributes on args.
    config = args.__dict__
    config_name = args.method + '.py' if args.config_file is None else args.config_file
    default_params = load_config(osp.join('./configs', config_name))
    config.update(default_params)
    config['batch_size'] = 1000
    config['strict_test'] = True
    config["sampling"] = True

    svpath = SAMPLING_OUTPUT_DIR
    check_dir(svpath)
    exp = Exp(args, distributed=False)

    # Load to CPU: load_state_dict copies each tensor onto the device of the
    # matching model parameter, so no specific GPU index is required here.
    params = torch.load(CHECKPOINT_PATH, map_location='cpu')
    exp.method.model.load_state_dict(strip_module_prefix(params))

    print('>>>>>>>>>>>>>>>>>>>>>>>>>> sampling <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
    # The "## only for sampling_foldingdiff, match duaspace" section of
    # load_data.py must be modified before running this script.
    # NOTE(review): consider torch.no_grad()/model.eval() here if
    # exp.method.sampling does not already handle them -- confirm first.
    device = exp.method.model.device
    for batch in tqdm(exp.test_loader):
        (angles, coords, attn_mask, position_ids, timestamps, seqs,
         unknown_mask, start_idx, end_idx) = cuda(
            [batch["angles"], batch['coords'], batch["attn_mask"],
             batch["position_ids"], batch["t"], batch["seqs"],
             batch["unknown_mask"], batch["start_idx"], batch["end_idx"]],
            device=device)
        raw_coords = coords.clone()
        # The loader's per-batch "t" is discarded; every batch is sampled with
        # the same fixed timestep value.
        timestamps = SAMPLING_TIMESTEPS
        angles, _ = exp.method.sampling(
            angles, coords, attn_mask, position_ids, timestamps, seqs,
            unknown_mask, start_idx, end_idx, exp.train_loader.dataset)
        unknown_mask = squeeze_keep_batch(unknown_mask)
        phi, psi, omega, tau, CA_C_1N, C_1N_1CA = torch.split(angles, 1, dim=-1)
        # Zero out the coordinates at unknown positions before rebuilding them.
        pred_coords = coords * (~unknown_mask[..., None, None])
        pred_coords2 = exp.method.model.pred_coord(
            pred_coords.clone(), start_idx, end_idx,
            phi, psi, omega, C_1N_1CA, tau, CA_C_1N)

        # Copy batch results to the host once instead of once per sample.
        pred_cpu = pred_coords2.cpu()
        raw_cpu = raw_coords.cpu()
        masked = attn_mask * unknown_mask
        for i in range(angles.shape[0]):
            # Only write samples that have at least one attended unknown position.
            if masked[i].sum() > 0:
                key = batch['key'][i]
                TorchNERFBuilder.sv2pdb(f"{svpath}/pred_{key}.pdb", pred_cpu[i].reshape(-1, 3), unknown_mask[i], attn_mask[i])
                TorchNERFBuilder.sv2pdb(f"{svpath}/raw_{key}.pdb", raw_cpu[i].reshape(-1, 3), unknown_mask[i], attn_mask[i])