forked from Ahmednull/L2CS-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path utils.py
112 lines (96 loc) · 4.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import torch
import torch.nn as nn
import os
import scipy.io as sio
import cv2
import math
from math import cos, sin
from pathlib import Path
import subprocess
import re
from model import L2CS
import torchvision
import sys
def atoi(text):
    """Convert *text* to an int when it is a non-empty run of decimal digits.

    Used as the per-chunk key function in ``natural_keys``.

    ``str.isdecimal`` is used rather than ``str.isdigit``. ``isdigit`` also
    accepts characters such as superscript two (``'²'``) that ``int()``
    rejects, which made the original raise ``ValueError`` on such input.
    ``isdecimal`` matches exactly the digit characters ``int()`` accepts.

    Returns:
        ``int(text)`` if every character of *text* is a decimal digit,
        otherwise *text* unchanged (including the empty string).
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """Return a sort key that orders strings the way a human would.

    Runs of digits are compared numerically and everything else as text, so
    ``sorted(names, key=natural_keys)`` places ``'img2'`` before ``'img10'``.
    Based on http://nedbatchelder.com/blog/200712/human_sorting.html
    (Toothy's implementation in the comments).
    """
    # The capturing group makes re.split keep the digit runs, so the pieces
    # alternate text / digits / text / ...; atoi turns the digit runs into ints.
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))
def gazeto3d(gaze):
    """Convert a pair of gaze angles (radians) into a 3D direction vector.

    With ``g0 = gaze[0]`` and ``g1 = gaze[1]`` the result is
    ``[-cos(g1)*sin(g0), -sin(g1), -cos(g1)*cos(g0)]``. This has unit length,
    because ``cos(g1)**2 * (sin(g0)**2 + cos(g0)**2) + sin(g1)**2 == 1``.

    Returns:
        A float64 NumPy array of shape ``(3,)``.
    """
    first, second = gaze[0], gaze[1]
    cos_second = np.cos(second)

    vector = np.zeros([3])
    vector[0] = -cos_second * np.sin(first)
    vector[1] = -np.sin(second)
    vector[2] = -cos_second * np.cos(first)
    return vector
def angular(gaze, label):
    """Return the angle, in degrees, between vectors *gaze* and *label*.

    The cosine similarity is clamped to ``[-1.0, 0.9999999]`` before
    ``np.arccos``. Previously only the upper bound was applied. For
    (near-)opposite vectors, floating-point rounding can push the cosine just
    below -1, where ``arccos`` returns NaN (e.g. ``[1, 1, 1]`` vs
    ``[-1, -1, -1]``); the lower bound fixes that.

    The upper bound of 0.9999999 is kept unchanged. As a result, identical
    directions report about 0.026 degrees rather than exactly 0.

    NOTE(review): for NumPy inputs, a zero-length vector still gives 0/0 and
    therefore NaN; callers presumably never pass one.
    """
    cosine = np.sum(gaze * label) / (np.linalg.norm(gaze) * np.linalg.norm(label))
    # Clamp into arccos's domain; NaN (from 0/0) passes through min/max unchanged.
    cosine = max(min(cosine, 0.9999999), -1.0)
    return np.arccos(cosine) * 180 / np.pi
def draw_gaze(a, b, c, d, image_in, pitchyaw, thickness=2, color=(255, 255, 0), sclae=2.0):
    """Draw a gaze arrow on an image and return the image.

    The arrow starts at ``(int(a + c / 2), int(b + d / 2))``. This is
    presumably the centre of a face box given as x, y, width, height
    (TODO confirm against callers).

    With ``L = image width / 2``, the arrow's offset is::

        dx = -L * sin(pitchyaw[0]) * cos(pitchyaw[1])
        dy = -L * sin(pitchyaw[1])

    Grayscale input (2-D, or a single channel) is first converted to BGR with
    ``cv2.cvtColor``, which produces a new array. Otherwise the arrow is drawn
    directly into *image_in*, and that same array is returned.

    ``sclae`` is accepted for call compatibility but is not used.
    """
    _, width = image_in.shape[:2]
    arrow_len = width / 2
    origin = (int(a + c / 2.0), int(b + d / 2.0))

    canvas = image_in
    if len(canvas.shape) == 2 or canvas.shape[2] == 1:
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

    dx = -arrow_len * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
    dy = -arrow_len * np.sin(pitchyaw[1])
    tail = tuple(np.round(origin).astype(np.int32))
    head = tuple(np.round([origin[0] + dx, origin[1] + dy]).astype(int))
    cv2.arrowedLine(canvas, tail, head, color, thickness, cv2.LINE_AA, tipLength=0.18)
    return canvas
def select_device(device='', batch_size=None):
    """Pick the torch device to run on, setting CUDA visibility first.

    Args:
        device: ``'cpu'`` (case-insensitive) to force the CPU; ``''`` to use
            CUDA when available; or a comma-separated list of CUDA ids such
            as ``'0'`` or ``'0,1,2,3'``.
        batch_size: if given and more than one CUDA device is selected, it
            must be a multiple of the device count.

    Returns:
        ``torch.device('cuda:0')`` when CUDA is used, else
        ``torch.device('cpu')``.

    Raises:
        AssertionError: if specific devices were requested but CUDA is
            unavailable, or if *batch_size* is not a multiple of the GPU
            count. These are raised explicitly rather than with ``assert``,
            so the checks survive ``python -O``.

    Side effects:
        Sets ``os.environ['CUDA_VISIBLE_DEVICES']`` (``'-1'`` for CPU, or
        *device* itself).

    The original also built a YOLOv3-style banner string that was never
    returned or printed. That banner called ``date_modified()``, which is not
    defined anywhere, whenever ``git_describe()`` returned ``''`` (e.g.
    outside a git checkout), and the call crashed with NameError. The dead
    banner has been removed.
    """
    cpu = device.lower() == 'cpu'
    if cpu:
        # Hide every GPU so torch.cuda.is_available() reports False.
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    elif device:
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        if not torch.cuda.is_available():
            raise AssertionError(f'CUDA unavailable, invalid device {device} requested')

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else range(torch.cuda.device_count())
        n = len(devices)
        if n > 1 and batch_size and batch_size % n != 0:
            raise AssertionError(f'batch-size {batch_size} not multiple of GPU count {n}')
        # Query each selected device, as the old banner loop did. This
        # presumably fails early if a requested index is not actually visible.
        for i in range(n):
            torch.cuda.get_device_properties(i)
    return torch.device('cuda:0' if cuda else 'cpu')
def spherical2cartesial(x):
    """Convert a batch of angle pairs into 3D unit vectors.

    For each row, with ``a0 = x[:, 0]`` and ``a1 = x[:, 1]``, the output row
    is ``[cos(a1)*sin(a0), sin(a1), -cos(a1)*cos(a0)]``.

    The result is allocated with ``torch.zeros`` using default settings. It
    therefore lives on the CPU with the default dtype, regardless of where
    *x* lives.

    Returns:
        Tensor of shape ``(x.size(0), 3)``.
    """
    angle_0, angle_1 = x[:, 0], x[:, 1]
    cos_1 = torch.cos(angle_1)

    result = torch.zeros(x.size(0), 3)
    result[:, 0] = cos_1 * torch.sin(angle_0)
    result[:, 1] = torch.sin(angle_1)
    result[:, 2] = -cos_1 * torch.cos(angle_0)
    return result
def compute_angular_error(input, target):
    """Return the mean angle, in degrees, between two batches of gaze angles.

    Both arguments are mapped to unit vectors with ``spherical2cartesial``.
    That function indexes columns 0 and 1, so the inputs are presumably of
    shape ``(N, 2)`` (TODO confirm). The per-row dot products are then passed
    through ``acos``.

    The dot product is clamped to ``[-1, 1]`` first. For (nearly) identical
    or opposite directions, single-precision rounding can push it slightly
    outside that range. There ``torch.acos`` returns NaN, which previously
    poisoned the whole mean.

    Returns:
        0-dim tensor holding the mean angular error in degrees, detached
        from autograd via ``.data``.
    """
    # Parameter name `input` shadows the builtin but is kept for callers.
    input = spherical2cartesial(input).view(-1, 3, 1)
    target = spherical2cartesial(target).view(-1, 1, 3)
    # (N, 1, 3) @ (N, 3, 1) -> (N, 1, 1): one dot product per row.
    cosine = torch.bmm(target, input).view(-1)
    angles = torch.acos(torch.clamp(cosine, -1.0, 1.0)).data
    return 180 * torch.mean(angles) / math.pi
def softmax_temperature(tensor, temperature):
    """Return the softmax of ``tensor / temperature`` along dim 1.

    This is mathematically the same as the original
    ``exp(x / T) / sum(exp(x / T), dim=1)``, but it delegates to
    ``torch.softmax``, which does not overflow for large inputs. The
    hand-rolled version produced ``exp(1000.) == inf`` and hence
    ``inf / inf == NaN``.

    Args:
        tensor: scores with at least two dimensions, normalized along dim 1.
        temperature: divisor applied before the softmax. For positive
            values, larger temperatures flatten the distribution and smaller
            ones sharpen it.

    Returns:
        Tensor of the same shape whose entries along dim 1 sum to 1.
    """
    return torch.softmax(tensor / temperature, dim=1)
def git_describe(path=Path(__file__).parent):  # path must be a directory
    """Return ``git describe --tags --long --always`` for *path*, or ``''``.

    Example output: ``v5.0-5-g3e25f1e``
    (see https://git-scm.com/docs/git-describe).

    Changes from the original:

    - git is invoked with an argument list instead of a shell string. A
      *path* containing spaces or shell metacharacters is therefore passed
      through intact.
    - Without a shell, a missing ``git`` executable raises ``OSError``
      (``FileNotFoundError``) instead of exiting non-zero, so that is caught
      too. As before, ``''`` is returned whenever no description is
      available.
    - stderr is discarded so git warnings cannot end up in the returned
      string.
    """
    cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
    try:
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        return ''  # not a git repository, or git is not installed
    return out.decode().strip()