Skip to content

Commit

Permalink
model arg
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyuyzl committed Jul 27, 2022
1 parent a1d43e7 commit 7026da2
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
1 change: 1 addition & 0 deletions args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def convert_to_byte(size):
parser.add_argument('--output_dir', type=str)
parser.add_argument('--output_webcam', type=str)
parser.add_argument('--output_size', type=str, default='256x256')
parser.add_argument('--model', type=str, default='standard_float')
parser.add_argument('--debug_input', action='store_true')
parser.add_argument('--mouse_input', type=str)
parser.add_argument('--perf', type=str)
Expand Down
12 changes: 9 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,9 @@ def run(self):
model = model
print("Pretrained Model Loaded")

mouth_eye_vector = torch.empty(1, 27)
pose_vector = torch.empty(1, 6)

mouth_eye_vector = torch.empty(1, 27,dtype=torch.half if args.model.endswith('half') else torch.float)
pose_vector = torch.empty(1, 6,dtype=torch.half if args.model.endswith('half') else torch.float)

input_image = self.input_image.to(device)
mouth_eye_vector = mouth_eye_vector.to(device)
Expand Down Expand Up @@ -417,7 +418,12 @@ def main():
y = i // IMG_WIDTH
x = i % IMG_WIDTH
img.putpixel((x, y), (0, 0, 0, 0))
input_image = preprocessing_image(img.crop((0, 0, IMG_WIDTH, IMG_WIDTH))).unsqueeze(0)
input_image = preprocessing_image(img.crop((0, 0, IMG_WIDTH, IMG_WIDTH)))
if args.model.endswith('half'):
input_image = torch.from_numpy(input_image).half() * 2.0 - 1
else:
input_image = torch.from_numpy(input_image).float() * 2.0 - 1
input_image=input_image.unsqueeze(0)
extra_image = None
if img.size[1] > IMG_WIDTH:
extra_image = np.array(img.crop((0, IMG_WIDTH, img.size[0], img.size[1])))
Expand Down
24 changes: 21 additions & 3 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import tha2.poser.modes.mode_20
import tha3.poser.modes.standard_float
import tha3.poser.modes.separable_float
import tha3.poser.modes.standard_half
import tha3.poser.modes.separable_half
from torch.nn.functional import interpolate

from args import args
Expand Down Expand Up @@ -55,9 +58,24 @@ def forward(self, image, mouth_eye_vector, pose_vector, mouth_eye_vector_c, rati
class TalkingAnime3(nn.Module):
def __init__(self):
super(TalkingAnime3, self).__init__()
self.face_morpher = tha3.poser.modes.standard_float.load_face_morpher('data/models/standard_float/face_morpher.pt')
self.two_algo_face_body_rotator = tha3.poser.modes.standard_float.load_two_algo_generator('data/models/standard_float/two_algo_face_body_rotator.pt')
self.editor = tha3.poser.modes.standard_float.load_editor('data/models/standard_float/editor.pt')
if args.model == "standard_float":
self.face_morpher = tha3.poser.modes.standard_float.load_face_morpher('data/models/standard_float/face_morpher.pt')
self.two_algo_face_body_rotator = tha3.poser.modes.standard_float.load_two_algo_generator('data/models/standard_float/two_algo_face_body_rotator.pt')
self.editor = tha3.poser.modes.standard_float.load_editor('data/models/standard_float/editor.pt')
elif args.model == "standard_half":
self.face_morpher = tha3.poser.modes.standard_half.load_face_morpher('data/models/standard_half/face_morpher.pt')
self.two_algo_face_body_rotator = tha3.poser.modes.standard_half.load_two_algo_generator('data/models/standard_half/two_algo_face_body_rotator.pt')
self.editor = tha3.poser.modes.standard_half.load_editor('data/models/standard_half/editor.pt')
elif args.model == "separable_float":
self.face_morpher = tha3.poser.modes.separable_float.load_face_morpher('data/models/separable_float/face_morpher.pt')
self.two_algo_face_body_rotator = tha3.poser.modes.separable_float.load_two_algo_generator('data/models/separable_float/two_algo_face_body_rotator.pt')
self.editor = tha3.poser.modes.separable_float.load_editor('data/models/separable_float/editor.pt')
elif args.model == "separable_half":
self.face_morpher = tha3.poser.modes.separable_half.load_face_morpher('data/models/separable_half/face_morpher.pt')
self.two_algo_face_body_rotator = tha3.poser.modes.separable_half.load_two_algo_generator('data/models/separable_half/two_algo_face_body_rotator.pt')
self.editor = tha3.poser.modes.separable_half.load_editor('data/models/separable_half/editor.pt')
else:
raise RuntimeError("Invalid model: '%s'" % args.model)
self.face_cache = OrderedDict()
self.tot = 0
self.hit = 0
Expand Down
3 changes: 1 addition & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def preprocessing_image(image):
if pixel[3] == 0.0:
pixel[0:3] = 0.0
reshaped_image = linear_image.transpose().reshape(c, h, w)
torch_image = torch.from_numpy(reshaped_image).float() * 2.0 - 1
return torch_image
return reshaped_image


def postprocessing_image(tensor):
Expand Down

0 comments on commit 7026da2

Please sign in to comment.