# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation protocol to compute metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from disentanglement_lib.data.ground_truth import named_data
from disentanglement_lib.evaluation.metrics import beta_vae # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import dci # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import downstream_task # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import factor_vae # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import irs # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import mig # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import modularity_explicitness # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import reduced_downstream_task # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import sap_score # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import unsupervised_metrics # pylint: disable=unused-import
from disentanglement_lib.utils import results
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import gin.tf
from tensorflow.python.framework.errors_impl import NotFoundError


# This helper is duplicated here (instead of importing it from utils_pytorch)
# so that evaluating a TensorFlow model does not require the PyTorch utilities.
def get_dataset_name():
  """Reads the dataset name from the environment variable `AICROWD_DATASET_NAME`."""
  return os.getenv("AICROWD_DATASET_NAME", "cars3d")
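
# Illustrative sketch (not part of the original module): selecting the dataset
# via the environment before evaluation is launched. The dataset name below is
# only an example.
#
#   os.environ["AICROWD_DATASET_NAME"] = "mpi3d_toy"
#   assert get_dataset_name() == "mpi3d_toy"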


def evaluate_with_gin(model_dir,
                      output_dir,
                      overwrite=False,
                      gin_config_files=None,
                      gin_bindings=None,
                      eval_pytorch=False):
  """Evaluates a representation based on the provided gin configuration.

  This function sets the provided gin bindings, calls the evaluate()
  function and clears the gin config. Please see evaluate() for the required
  gin bindings.

  Args:
    model_dir: String with path to directory where the representation is saved.
    output_dir: String with the path where the evaluation should be saved.
    overwrite: Boolean indicating whether to overwrite the output directory.
    gin_config_files: List of gin config files to load.
    gin_bindings: List of gin bindings to use.
    eval_pytorch: Boolean indicating whether to prefer a PyTorch model over a
      TFHub module when both are present in `model_dir`.
  """
  if gin_config_files is None:
    gin_config_files = []
  if gin_bindings is None:
    gin_bindings = []
  gin.parse_config_files_and_bindings(gin_config_files, gin_bindings)
  evaluate(model_dir, output_dir, overwrite, eval_pytorch=eval_pytorch)
  gin.clear_config()
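
# Illustrative usage sketch (not part of the original module). It assumes a
# trained model under "output/model" containing a tfhub/ or pytorch_model.pt
# export, and uses MIG-style gin bindings as an example; adapt the bindings to
# the metric you actually want to compute.
#
#   gin_bindings = [
#       "evaluation.evaluation_fn = @mig",
#       "evaluation.random_seed = 0",
#       "dataset.name = 'auto'",
#       "mig.num_train = 10000",
#       "discretizer.discretizer_fn = @histogram_discretizer",
#       "discretizer.num_bins = 20",
#   ]
#   evaluate_with_gin(
#       model_dir="output/model",
#       output_dir="output/metrics/mig",
#       overwrite=True,
#       gin_bindings=gin_bindings)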


@gin.configurable(
    "evaluation", blacklist=["model_dir", "output_dir", "overwrite"])
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name="",
             eval_pytorch=False):
  """Loads a trained representation and computes disentanglement metrics.

  Args:
    model_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite the output directory.
    evaluation_fn: Function used to evaluate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for training.
    name: Optional string with name of the metric (can be used to name metrics).
    eval_pytorch: Boolean indicating whether to prefer a PyTorch model over a
      TFHub module when both are present in `model_dir`.
  """
  # We do not use the variable 'name'. Instead, it can be used to name scores
  # as it will be part of the saved gin config.
  del name

  # Delete the output directory if it already exists.
  if tf.gfile.IsDirectory(output_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(output_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Set up a timer to keep track of elapsed time in results.
  experiment_timer = time.time()

  try:
    # Automatically set the proper dataset if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the dataset
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
      # Obtain the dataset name from the gin config of the previous step.
      gin_config_file = os.path.join(model_dir, "results", "gin",
                                     "postprocess.gin")
      gin_dict = results.gin_dict(gin_config_file)
      with gin.unlock_config():
        gin.bind_parameter("dataset.name", gin_dict["dataset.name"].replace(
            "'", ""))
    dataset = named_data.get_named_ground_truth_data()
  except NotFoundError:
    # If the representation was not trained with disentanglement_lib, there is
    # no "previous step", so we fall back to the environment variable.
    if gin.query_parameter("dataset.name") == "auto":
      with gin.unlock_config():
        gin.bind_parameter("dataset.name", get_dataset_name())
    dataset = named_data.get_named_ground_truth_data()

  # Decide which backend to evaluate with.
  eval_tf = True
  if eval_pytorch and os.path.exists(os.path.join(model_dir, "pytorch_model.pt")):
    eval_tf = False

  if os.path.exists(os.path.join(model_dir, "tfhub")) and eval_tf:
    # Path to TFHub module of previously trained representation.
    module_path = os.path.join(model_dir, "tfhub")
    # Evaluate results with TensorFlow.
    results_dict = _evaluate_with_tensorflow(module_path, evaluation_fn,
                                             dataset, random_seed)
  elif os.path.exists(os.path.join(model_dir, "pytorch_model.pt")):
    # Path to PyTorch JIT module of previously trained representation.
    module_path = os.path.join(model_dir, "pytorch_model.pt")
    # Evaluate results with PyTorch.
    results_dict = _evaluate_with_pytorch(module_path, evaluation_fn,
                                          dataset, random_seed)
  elif os.path.exists(os.path.join(model_dir, "python_model.dill")):
    # Path to the dilled representation function.
    module_path = os.path.join(model_dir, "python_model.dill")
    # Evaluate results with NumPy.
    results_dict = _evaluate_with_numpy(module_path, evaluation_fn,
                                        dataset, random_seed)
  else:
    raise RuntimeError(
        "`model_dir` must contain a TFHub, PyTorch or dilled NumPy model.")

  # Save the results (and all previous results in the pipeline) on disk.
  original_results_dir = os.path.join(model_dir, "results")
  results_dir = os.path.join(output_dir, "results")
  results_dict["elapsed_time"] = time.time() - experiment_timer
  results.update_result_directory(results_dir, "evaluation", results_dict,
                                  original_results_dir)


def _evaluate_with_tensorflow(module_path, evaluation_fn, dataset, random_seed):
  """Computes scores for a representation exported as a TFHub module."""
  with hub.eval_function_for_module(module_path) as f:

    def _representation_function(x):
      """Computes representation vector for input images."""
      output = f(dict(images=x), signature="representation", as_dict=True)
      return np.array(output["default"])

    # Compute scores of the representation based on the evaluation_fn.
    results_dict = evaluation_fn(
        dataset,
        _representation_function,
        random_state=np.random.RandomState(random_seed))
  return results_dict
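
# Note (added for clarity, hedged): every backend in this file builds a
# `_representation_function` with the same contract that the metrics in
# disentanglement_lib.evaluation.metrics expect, i.e. a callable mapping a
# batch of input images of shape (num_points, height, width, channels) to a
# NumPy array of latent codes of shape (num_points, num_latents). A minimal,
# purely illustrative stand-in could look like:
#
#   def _mean_pixel_representation(x):
#       # Collapses each image to a single latent dimension (for testing only).
#       return np.mean(x.reshape(x.shape[0], -1), axis=1, keepdims=True)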


def _evaluate_with_pytorch(module_path, evaluation_fn, dataset, random_seed):
  """Computes scores for a representation exported as a PyTorch JIT module."""
  import utils_pytorch

  # Load the model and wrap it as a representation function.
  model = utils_pytorch.import_model(path=module_path)
  _representation_function = utils_pytorch.make_representor(model)

  # Compute scores of the representation based on the evaluation_fn.
  results_dict = evaluation_fn(
      dataset,
      _representation_function,
      random_state=np.random.RandomState(random_seed))
  return results_dict


def _evaluate_with_numpy(module_path, evaluation_fn, dataset, random_seed):
  """Computes scores for a representation saved as a dilled Python callable."""
  import utils_numpy

  # Load the dilled function and wrap it as a representation function.
  fn = utils_numpy.import_function(path=module_path)
  _representation_function = utils_numpy.make_representor(fn)

  # Compute scores of the representation based on the evaluation_fn.
  results_dict = evaluation_fn(
      dataset,
      _representation_function,
      random_state=np.random.RandomState(random_seed))
  return results_dict
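
# Illustrative sketch (hedged assumption) of how a `python_model.dill` file
# might be produced upstream; the exact serialization contract is defined by
# utils_numpy, which is not part of this file.
#
#   import dill
#
#   def represent(images):
#       # images: NumPy array of shape (num_points, height, width, channels).
#       return images.reshape(images.shape[0], -1)[:, :10]
#
#   with open("output/model/python_model.dill", "wb") as f:
#       dill.dump(represent, f)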