# tfhub.py
import tensorflow as tf
import tensorflow_hub as hub

# Quick smoke test: embed three strings with the NNLM sentence encoder.
# (Named embed_layer to avoid clashing with tensorflow_docs.vis.embed below.)
module_url = "https://tfhub.dev/google/nnlm-en-dim128/2"
embed_layer = hub.KerasLayer(module_url)
embeddings = embed_layer(["A long sentence.", "single-word",
                          "http://example.com"])
print(embeddings.shape)  # (3, 128)
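# A minimal sketch (not part of the original snippet) of one typical use of
# these embeddings: pairwise cosine similarity between the three inputs.
normed = embeddings / tf.norm(embeddings, axis=1, keepdims=True)
similarity = tf.matmul(normed, normed, transpose_b=True)
print(similarity.numpy())  # 3x3 matrix; the diagonal is ~1.0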
#@title Import the necessary modules
# TensorFlow and TF-Hub modules.
from absl import logging
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow_docs.vis import embed
logging.set_verbosity(logging.ERROR)
# Some modules to help with reading the UCF101 dataset.
import random
import re
import os
import tempfile
import cv2
import numpy as np
# Some modules to display an animation using imageio.
import imageio
from IPython import display
from urllib import request  # requires python3
#@title Helper functions for the UCF101 dataset
# Utilities to fetch videos from UCF101 dataset
UCF_ROOT = "http://crcv.ucf.edu/THUMOS14/UCF101/UCF101/"
_VIDEO_LIST = None
_CACHE_DIR = tempfile.mkdtemp()
def list_ucf_videos():
  """Lists the videos available in the UCF101 dataset."""
  global _VIDEO_LIST
  if not _VIDEO_LIST:
    index = request.urlopen(UCF_ROOT).read().decode("utf-8")
    videos = re.findall(r"(v_[\w_]+\.avi)", index)
    _VIDEO_LIST = sorted(set(videos))
  return list(_VIDEO_LIST)
def fetch_ucf_video(video):
  """Fetches a video and caches it in the local filesystem."""
  cache_path = os.path.join(_CACHE_DIR, video)
  if not os.path.exists(cache_path):
    urlpath = request.urljoin(UCF_ROOT, video)
    print("Fetching %s => %s" % (urlpath, cache_path))
    data = request.urlopen(urlpath).read()
    with open(cache_path, "wb") as f:
      f.write(data)
  return cache_path
# Utilities to open video files using OpenCV.
def crop_center_square(frame):
  """Crops the largest centered square out of a frame."""
  y, x = frame.shape[0:2]
  min_dim = min(y, x)
  start_x = (x // 2) - (min_dim // 2)
  start_y = (y // 2) - (min_dim // 2)
  return frame[start_y:start_y + min_dim, start_x:start_x + min_dim]
def load_video(path, max_frames=0, resize=(224, 224)):
  """Reads a video, center-crops and resizes each frame, returns floats in [0, 1]."""
  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      frame = frame[:, :, [2, 1, 0]]  # BGR (OpenCV default) -> RGB
      frames.append(frame)
      if len(frames) == max_frames:
        break
  finally:
    cap.release()
  return np.array(frames) / 255.0
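# Hypothetical usage sketch ("sample.avi" is a stand-in name, not a file this
# script provides); the UCF101 section below fetches a real clip instead.
# clip = load_video("sample.avi", max_frames=32)
# print(clip.shape)  # (32, 224, 224, 3), float values in [0, 1]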
def to_gif(images):
  """Saves a [0, 1] float frame array as a GIF and embeds it in the notebook."""
  converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)
  imageio.mimsave('./animation.gif', converted_images, fps=25)
  return embed.embed_file('./animation.gif')
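# Hypothetical usage (the original tutorial previews a clip this way in Colab):
# to_gif(sample_video)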
#@title Get the Kinetics-400 labels
# Get the Kinetics-400 action labels from the GitHub repository.
KINETICS_URL = "https://raw.githubusercontent.com/deepmind/kinetics-i3d/master/data/label_map.txt"
with request.urlopen(KINETICS_URL) as obj:
  labels = [line.decode("utf-8").strip() for line in obj.readlines()]
print("Found %d labels." % len(labels))
#
# # Get the list of videos in the dataset.
# ucf_videos = list_ucf_videos()
#
# categories = {}
# for video in ucf_videos:
#   category = video[2:-12]
#   if category not in categories:
#     categories[category] = []
#   categories[category].append(video)
# print("Found %d videos in %d categories." % (len(ucf_videos), len(categories)))
#
# for category, sequences in categories.items():
#   summary = ", ".join(sequences[:2])
#   print("%-20s %4d videos (%s, ...)" % (category, len(sequences), summary))
#
# # Get a sample cricket video.
# video_path = fetch_ucf_video("v_CricketShot_g04_c02.avi")
# sample_video = load_video(video_path)
#
# print(sample_video.shape)
# Load the I3D action-recognition model trained on Kinetics-400.
i3d = hub.load("https://tfhub.dev/deepmind/i3d-kinetics-400/1").signatures['default']
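# Optional sketch: inspect the expected model input. structured_input_signature
# is a standard attribute of the ConcreteFunction returned above.
print(i3d.structured_input_signature)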
def predict(sample_video):
  """Prints the top-5 Kinetics-400 actions for a video and returns the best one."""
  # Add a batch axis to the sample video.
  model_input = tf.constant(sample_video, dtype=tf.float32)[tf.newaxis, ...]
  logits = i3d(model_input)['default'][0]
  probabilities = tf.nn.softmax(logits)
  top_5 = np.argsort(probabilities)[::-1][:5]
  print("Top 5 actions:")
  for i in top_5:
    print(f"  {labels[i]:22}: {probabilities[i] * 100:5.2f}%")
  return labels[top_5[0]]
# Classify every MPEG-TS clip in the current directory (os is imported above).
for vf in os.listdir():
  if vf.endswith('.ts'):
    sample_video = load_video(vf)
    best = predict(sample_video)
    print(vf, best)