-
Notifications
You must be signed in to change notification settings - Fork 18
/
semifake.py
108 lines (72 loc) · 3.77 KB
/
semifake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from importlib import import_module
import numpy as np
import DeepFried2 as df
from scipy.spatial.distance import cdist
from lbtoolbox.util import batched
import lib
from lib.models import add_defaults
from fakenews import FakeNeuralNewsNetwork
# Distance threshold between embeddings for a "same person" match.
# NOTE(review): currently only referenced from commented-out code in
# SemiFakeNews.search_person, so it has no runtime effect as-is.
DIST_THRESH = 7
class SemiFakeNews:
    """Half-real person re-id backend.

    The embedding side is a real neural network (loaded from
    ``lib.models.<model>``); the detection side ("personness") is delegated
    to a ``FakeNeuralNewsNetwork`` fed from a detection file, when one is
    configured via ``fake_dets``.
    """

    def __init__(self, model, weights, input_scale_factor, fake_shape, fake_dets, debug_skip_full_image=False):
        """Build both nets and, optionally, the fake detection stand-in.

        model: module name under ``lib.models`` providing ``mknet`` and
            ``hires_shared_twin``.
        weights: path to the weight file loaded into the net.
        input_scale_factor: scaling applied to input shapes before feeding.
        fake_shape: output-map shape handed to the fake detector.
        fake_dets: detection source for the fake part; ``None`` disables it.
        debug_skip_full_image: only relevant to the (currently commented-out)
            hi-res precompilation step; kept for API compatibility.
        """
        self.input_scale_factor = input_scale_factor

        model_module = import_module('lib.models.' + model)
        self.net = model_module.mknet()
        add_defaults(self.net)
        try:
            self.net.load(weights)
        except ValueError:
            print("!!!!!!!THE WEIGHTS YOU LOADED DON'T BELONG TO THE MODEL YOU'RE USING!!!!!!")
            raise

        # Shares the weights, just replaces the avg-pooling layer.
        self.net_hires = model_module.hires_shared_twin(self.net)
        add_defaults(self.net_hires)

        self.net.evaluate()
        self.net_hires.evaluate()

        # Precompilation is disabled for now; the prints are kept so the
        # startup log looks the same either way.
        print("Precompiling network... 1/2", end='', flush=True)
        #self.net.forward(np.zeros((1,3) + self.net.in_shape, df.floatX))
        print("\rPrecompiling network... 2/2", end='', flush=True)
        #if not (debug_skip_full_image and fake_dets is None):
        #out = self.net_hires.forward(np.zeros((1,3,1080//2,1920//2), df.floatX))
        print(" Done", flush=True)

        #fake_shape = out.shape[2:] # We didn't fake the avg-pool effect yet, so don't!
        if fake_dets is None:
            self.fake = None
        else:
            self.fake = FakeNeuralNewsNetwork(fake_dets, shape=fake_shape)

    def _scale_input_shape(self, shape):
        """Apply the configured input scaling to `shape`."""
        return lib.scale_shape(shape, self.input_scale_factor)

    def tick(self, *a, **kw):
        """Forwarded to the fake part only; a no-op without one."""
        if self.fake is not None:
            self.fake.tick(*a, **kw)

    def fake_camera(self, *a, **kw):
        """Forwarded to the fake part only; a no-op without one."""
        if self.fake is not None:
            self.fake.fake_camera(*a, **kw)

    def embed_crops(self, crops, *fakea, batchsize=32, **fakekw):
        """Embed fixed-size person crops, returning one row per crop."""
        assert all(self._scale_input_shape(crop.shape) == self.net.in_shape for crop in crops)
        inputs = np.array([lib.img2df(crop, shape=self.net.in_shape) for crop in crops])
        chunks = [self.net.forward(batch_in) for batch_in in batched(batchsize, inputs)]
        # The net emits Dx1x1 maps; strip the trailing singleton spatial dims.
        return np.concatenate(chunks)[:, :, 0, 0]

    def embeddings_cdist(self, embsA, embsB):
        """Pairwise Euclidean distances between two sets of embeddings."""
        return cdist(embsA, embsB)

    #@profile
    def embed_images(self, images, batch=True):
        """Run the hi-res twin over whole images, yielding embedding maps."""
        # TODO: batch=False
        inputs = np.array([lib.img2df(img, shape=self._scale_input_shape(img.shape)) for img in images])
        return self.net_hires.forward(inputs)

    def search_person(self, img_embs, person_emb, *fakea, **fakekw):
        """Per-location Euclidean distance between an embedding map and one
        person's embedding vector (broadcast over the spatial dims)."""
        delta = img_embs - person_emb[:, None, None]
        return np.sqrt(np.sum(delta**2, axis=0))
        #d[d > DIST_THRESH] = 9999 # Will go to zero/uniform in the softmin
        # Convert distance to probability.
        #return lib.softmin(d, T), d # TODO: Might be better to fit a sigmoid or something.

    def fix_shape(self, net_output, orig_shape, out_shape, fill_value=0):
        """Resize `net_output` to `out_shape` keeping aspect ratio, padding
        the border with `fill_value`."""
        orig_shape = self._scale_input_shape(orig_shape)
        # Scale to `out_shape` but keeping correct aspect ratio.
        target_h = net_output.shape[0]*self.net.scale_factor[0] /orig_shape[0]*out_shape[0]
        target_w = net_output.shape[1]*self.net.scale_factor[1] /orig_shape[1]*out_shape[1]
        resized = lib.resize_map(net_output, (int(target_h), int(target_w)))
        return lib.paste_into_middle_2d(resized, out_shape, fill_value)

    # THIS IS THE ONLY THING FAKE :(
    # TODO: Make semi-fake, by clearing out known_embs.
    def personness(self, image, known_embs, return_pose=False):
        """Delegates detection entirely to the fake part (must be present)."""
        assert self.fake is not None, "The world doesn't work that way my friend!"
        return self.fake.personness(image, known_embs, return_pose)