diff --git a/args.py b/args.py index de7a279..63684ea 100644 --- a/args.py +++ b/args.py @@ -18,6 +18,7 @@ def convert_to_byte(size): parser.add_argument('--output_dir', type=str) parser.add_argument('--output_webcam', type=str) parser.add_argument('--output_size', type=str, default='256x256') +parser.add_argument('--model', type=str, default='standard_float') parser.add_argument('--debug_input', action='store_true') parser.add_argument('--mouse_input', type=str) parser.add_argument('--perf', type=str) diff --git a/main.py b/main.py index 07cb023..d85b4c4 100644 --- a/main.py +++ b/main.py @@ -230,8 +230,9 @@ def run(self): model = model print("Pretrained Model Loaded") - mouth_eye_vector = torch.empty(1, 27) - pose_vector = torch.empty(1, 6) + + mouth_eye_vector = torch.empty(1, 27,dtype=torch.half if args.model.endswith('half') else torch.float) + pose_vector = torch.empty(1, 6,dtype=torch.half if args.model.endswith('half') else torch.float) input_image = self.input_image.to(device) mouth_eye_vector = mouth_eye_vector.to(device) @@ -417,7 +418,12 @@ def main(): y = i // IMG_WIDTH x = i % IMG_WIDTH img.putpixel((x, y), (0, 0, 0, 0)) - input_image = preprocessing_image(img.crop((0, 0, IMG_WIDTH, IMG_WIDTH))).unsqueeze(0) + input_image = preprocessing_image(img.crop((0, 0, IMG_WIDTH, IMG_WIDTH))) + if args.model.endswith('half'): + input_image = torch.from_numpy(input_image).half() * 2.0 - 1 + else: + input_image = torch.from_numpy(input_image).float() * 2.0 - 1 + input_image=input_image.unsqueeze(0) extra_image = None if img.size[1] > IMG_WIDTH: extra_image = np.array(img.crop((0, IMG_WIDTH, img.size[0], img.size[1]))) diff --git a/models.py b/models.py index c2e0f32..2f24a4c 100644 --- a/models.py +++ b/models.py @@ -4,6 +4,9 @@ import tha2.poser.modes.mode_20 import tha3.poser.modes.standard_float +import tha3.poser.modes.separable_float +import tha3.poser.modes.standard_half +import tha3.poser.modes.separable_half from torch.nn.functional import interpolate from args import args @@ -55,9 +58,24 @@ def forward(self, image, mouth_eye_vector, pose_vector, mouth_eye_vector_c, rati class TalkingAnime3(nn.Module): def __init__(self): super(TalkingAnime3, self).__init__() - self.face_morpher = tha3.poser.modes.standard_float.load_face_morpher('data/models/standard_float/face_morpher.pt') - self.two_algo_face_body_rotator = tha3.poser.modes.standard_float.load_two_algo_generator('data/models/standard_float/two_algo_face_body_rotator.pt') - self.editor = tha3.poser.modes.standard_float.load_editor('data/models/standard_float/editor.pt') + if args.model == "standard_float": + self.face_morpher = tha3.poser.modes.standard_float.load_face_morpher('data/models/standard_float/face_morpher.pt') + self.two_algo_face_body_rotator = tha3.poser.modes.standard_float.load_two_algo_generator('data/models/standard_float/two_algo_face_body_rotator.pt') + self.editor = tha3.poser.modes.standard_float.load_editor('data/models/standard_float/editor.pt') + elif args.model == "standard_half": + self.face_morpher = tha3.poser.modes.standard_half.load_face_morpher('data/models/standard_half/face_morpher.pt') + self.two_algo_face_body_rotator = tha3.poser.modes.standard_half.load_two_algo_generator('data/models/standard_half/two_algo_face_body_rotator.pt') + self.editor = tha3.poser.modes.standard_half.load_editor('data/models/standard_half/editor.pt') + elif args.model == "separable_float": + self.face_morpher = tha3.poser.modes.separable_float.load_face_morpher('data/models/separable_float/face_morpher.pt') + self.two_algo_face_body_rotator = tha3.poser.modes.separable_float.load_two_algo_generator('data/models/separable_float/two_algo_face_body_rotator.pt') + self.editor = tha3.poser.modes.separable_float.load_editor('data/models/separable_float/editor.pt') + elif args.model == "separable_half": + self.face_morpher = tha3.poser.modes.separable_half.load_face_morpher('data/models/separable_half/face_morpher.pt') + self.two_algo_face_body_rotator = tha3.poser.modes.separable_half.load_two_algo_generator('data/models/separable_half/two_algo_face_body_rotator.pt') + self.editor = tha3.poser.modes.separable_half.load_editor('data/models/separable_half/editor.pt') + else: + raise RuntimeError("Invalid model: '%s'" % args.model) self.face_cache = OrderedDict() self.tot = 0 self.hit = 0 diff --git a/utils.py b/utils.py index c9c214c..30bf9fe 100644 --- a/utils.py +++ b/utils.py @@ -43,8 +43,7 @@ def preprocessing_image(image): if pixel[3] == 0.0: pixel[0:3] = 0.0 reshaped_image = linear_image.transpose().reshape(c, h, w) - torch_image = torch.from_numpy(reshaped_image).float() * 2.0 - 1 - return torch_image + return reshaped_image def postprocessing_image(tensor):