Skip to content

Commit

Permalink
[ControlNet][MLSD] Fix F.interpolate when align_corners=True
Browse files Browse the repository at this point in the history
  • Loading branch information
Nuullll committed Mar 17, 2024
1 parent ce80c96 commit f37d560
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 52 deletions.
52 changes: 0 additions & 52 deletions ipex_hijack/controlnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,6 @@ def override_annotator_model_device(orig_func, self, *args, **kwargs):
self.device = devices.cpu


# Adapted from https://github.com/Mikubill/sd-webui-controlnet/blob/4cf15d1c9c565b8d0c5f782a89c5a6286dc6e6ff/annotator/mlsd/utils.py#L48
@hijack_message("Offloading pred_lines() to cpu")
def pred_lines_cpu(image, model,
input_shape=[512, 512],
score_thr=0.10,
dist_thr=20.0):
from annotator.mlsd.utils import np, cv2, torch, deccode_output_score_and_ptss
h, w, _ = image.shape
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]

resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)

resized_image = resized_image.transpose((2,0,1))
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
batch_image = (batch_image / 127.5) - 1.0

batch_image = torch.from_numpy(batch_image).float().to(devices.cpu)
model = model.to(devices.cpu)
outputs = model(batch_image).to(devices.cpu)
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
start = vmap[:, :, :2]
end = vmap[:, :, 2:]
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))

segments_list = []
for center, score in zip(pts, pts_score):
y, x = center
distance = dist_map[y, x]
if score > score_thr and distance > dist_thr:
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
x_start = x + disp_x_start
y_start = y + disp_y_start
x_end = x + disp_x_end
y_end = y + disp_y_end
segments_list.append([x_start, y_start, x_end, y_end])

lines = 2 * np.array(segments_list) # 256 > 512
lines[:, 0] = lines[:, 0] * w_ratio
lines[:, 1] = lines[:, 1] * h_ratio
lines[:, 2] = lines[:, 2] * w_ratio
lines[:, 3] = lines[:, 3] * h_ratio

return lines


def apply_controlnet_hijacks():
def is_controlnet_device_xpu(*args, **kwargs):
return devices.get_device_for("controlnet").type == "xpu"
Expand All @@ -72,12 +26,6 @@ def is_controlnet_device_xpu(*args, **kwargs):
is_controlnet_device_xpu,
)

CondFunc(
'annotator.mlsd.pred_lines', # mlsd
lambda _, *args, **kwargs: pred_lines_cpu(*args, **kwargs),
is_controlnet_device_xpu,
)

CondFunc(
'annotator.uniformer.inference_segmentor', # seg_ufade20k
lambda orig_func, model, *args, **kwargs: orig_func(model.to(devices.cpu), *args, **kwargs),
Expand Down
6 changes: 6 additions & 0 deletions scripts/ipex_enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def apply_general_hijacks():
CondFunc('torch.nn.functional.batch_norm',
lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: orig_func(input.half(), running_mean.half(), running_var.half(), weight=weight.half() if weight is not None else None, bias=bias.half() if bias is not None else None, training=training, momentum=momentum, eps=eps).to(input.dtype),
lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: input.device.type == 'xpu' and input.dtype == torch.float)

# IPEX: incorrect interpolate result with XPU when align_corner=True, move to cpu instead
# TODO: file an issue to IPEX
CondFunc('torch.nn.functional.interpolate',
lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: orig_func(input.cpu(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).to(input.device),
lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: input.device.type == 'xpu' and align_corners)

log("Registered hijacks for IPEX")

Expand Down

0 comments on commit f37d560

Please sign in to comment.