-
Notifications
You must be signed in to change notification settings - Fork 2
/
supervised.py
127 lines (107 loc) · 5.99 KB
/
supervised.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import json
import argparse
from collections import OrderedDict
import torch
from src.utils import bool_flag, initialize_exp
from src.models import build_model
from src.trainer import Trainer
from src.evaluation import Evaluator
# Name of the validation metric handed to `trainer.save_best` at the end of
# every iteration. The trailing `{}` placeholder is filled in with the last
# target language (`params.tgt_lang[-1]`) before the name is used.
# Alternative criteria, kept for reference (currently disabled):
#VALIDATION_METRIC = 'precision_at_1-nn'
#VALIDATION_METRIC = 'precision_at_1-csls_knn_10_{}'
VALIDATION_METRIC = 'mean_cosine-csls_knn_10-S2T-10000_{}'
# default criterion in the MUSE code: 'precision_at_1-csls_knn_10'
# Command-line interface.
# Every option is described by one (flag, type, default, help) row; the rows
# are registered on the parser in the order listed, grouped by topic.
_ARGUMENT_SPECS = (
    # main
    ("--seed", int, -1, "Initialization seed"),
    ("--verbose", int, 2, "Verbose level (2:debug, 1:info, 0:warning)"),
    ("--exp_path", str, "", "Where to store experiment logs and models"),
    ("--exp_name", str, "debug", "Experiment name"),
    ("--exp_id", str, "", "Experiment ID"),
    ("--cuda", bool_flag, True, "Run on GPU"),
    ("--export", str, "txt", "Export embeddings after training (txt / pth)"),
    # data
    ("--src_lang", str, 'en', "Source language"),
    ("--tgt_lang", str, 'es', "Target language"),
    ("--emb_dim", int, 300, "Embedding dimension"),
    ("--max_vocab", int, 200000, "Maximum vocabulary size (-1 to disable)"),
    # mapping
    ("--map_id_init", bool_flag, True, "Initialize the mapping as an identity matrix"),
    # training refinement
    ("--n_refinement", int, 5, "Number of refinement iterations (0 to disable the refinement procedure)"),
    ("--generalized", bool_flag, False, "Use GPA"),
    ("--fine_tuning", int, 0, "Number of fine-tuning iterations (0 to disable); subtracted from n_refinement"),
    # dictionary creation parameters (for refinement)
    ("--dico_train", str, "default", "Path to training dictionary (default: use identical character strings)"),
    ("--dico_eval", str, "default", "Path to evaluation dictionary"),
    ("--dico_method", str, 'csls_knn_10', "Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)"),
    ("--dico_build", str, 'S2T&T2S', "S2T,T2S,S2T|T2S,S2T&T2S"),
    ("--dico_threshold", float, 0, "Threshold confidence for dictionary generation"),
    ("--dico_max_rank", int, 10000, "Maximum dictionary words rank (0 to disable)"),
    ("--dico_min_size", int, 0, "Minimum generated dictionary size (0 to disable)"),
    ("--dico_max_size", int, 0, "Maximum generated dictionary size (0 to disable)"),
    # reload pre-trained embeddings
    ("--src_emb", str, '', "Reload source embeddings"),
    ("--tgt_emb", str, '', "Reload target embeddings"),
    ("--normalize_embeddings", str, "", "Normalize embeddings before training"),
)

parser = argparse.ArgumentParser(description='Supervised training')
for flag, arg_type, default_value, help_text in _ARGUMENT_SPECS:
    parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

# parse parameters
params = parser.parse_args()

# --tgt_lang and --tgt_emb may each carry several values in one string,
# separated by single spaces; turn both into lists (a single value becomes
# a one-element list).
params.tgt_lang = params.tgt_lang.strip().split(' ')
params.tgt_emb = params.tgt_emb.strip().split(' ')
# check parameters
# Validation raises explicit exceptions instead of using `assert`: assert
# statements are removed when Python runs with -O, which would silently skip
# every check below, and a bare assert gives the user no hint about which
# option is wrong.
if params.cuda and not torch.cuda.is_available():
    raise RuntimeError("--cuda is enabled but no CUDA device is available")

if not (params.dico_train in ["identical_char", "default", "identical_num"]
        or os.path.isfile(params.dico_train)):
    raise ValueError(
        "--dico_train must be 'identical_char', 'default', 'identical_num' "
        "or an existing file, got %r" % params.dico_train
    )

if params.dico_build not in ["S2T", "T2S", "S2T|T2S", "S2T&T2S"]:
    raise ValueError(
        "--dico_build must be one of S2T, T2S, S2T|T2S, S2T&T2S, got %r"
        % params.dico_build
    )

# a non-zero maximum dictionary size has to fit strictly between the minimum
# size and the maximum rank
if params.dico_max_size != 0 and params.dico_max_size >= params.dico_max_rank:
    raise ValueError(
        "--dico_max_size (%i) must be 0 or smaller than --dico_max_rank (%i)"
        % (params.dico_max_size, params.dico_max_rank)
    )
if params.dico_max_size != 0 and params.dico_max_size <= params.dico_min_size:
    raise ValueError(
        "--dico_max_size (%i) must be 0 or greater than --dico_min_size (%i)"
        % (params.dico_max_size, params.dico_min_size)
    )

# embedding files must exist
if not os.path.isfile(params.src_emb):
    raise FileNotFoundError("--src_emb is not an existing file: %r" % params.src_emb)
missing_tgt_emb = [emb for emb in params.tgt_emb if not os.path.isfile(emb)]
if missing_tgt_emb:
    raise FileNotFoundError(
        "--tgt_emb contains paths that are not existing files: %r" % missing_tgt_emb
    )

# the evaluation dictionary check is left disabled, as in the original script
#assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval)

if params.export not in ["", "txt", "pth"]:
    raise ValueError("--export must be '', 'txt' or 'pth', got %r" % params.export)

# one embedding file per target language
if len(params.tgt_lang) != len(params.tgt_emb):
    raise ValueError(
        "--tgt_lang lists %i language(s) but --tgt_emb lists %i file(s)"
        % (len(params.tgt_lang), len(params.tgt_emb))
    )

# several target languages are only accepted together with --generalized
if len(params.tgt_lang) != 1 and not params.generalized:
    raise ValueError(
        "%i target languages were given; this requires --generalized"
        % len(params.tgt_lang)
    )

if params.fine_tuning > params.n_refinement:
    raise ValueError(
        "--fine_tuning (%i) must not exceed --n_refinement (%i)"
        % (params.fine_tuning, params.n_refinement)
    )
# Set up the experiment: logger first, then the model pieces, then the
# trainer and the evaluator that wraps it.
logger = initialize_exp(params)
# the fourth value returned by build_model is not used by this script
src_emb, tgt_emb, mapping, _ = build_model(params, False)
trainer = Trainer(src_emb, tgt_emb, mapping, None, params)
evaluator = Evaluator(trainer)

# Generalized training starts with support enabled; plain training never
# uses it.
support = bool(params.generalized)

# Load the training dictionary. When no dictionary path is given, either a
# default one is used ("default") or one is created from identical character
# strings ("identical_char").
trainer.load_training_dico(params.dico_train, support)
"""
Learning loop for Procrustes Iterative Learning
"""
# Iterative Procrustes refinement: iterations 0 .. n_refinement inclusive.
for n_iter in range(params.n_refinement + 1):

    # once past (n_refinement - fine_tuning), switch support off for the
    # remaining iterations
    in_fine_tuning = n_iter > params.n_refinement - params.fine_tuning
    if in_fine_tuning:
        support = False

    logger.info('Starting iteration %i...' % n_iter)

    # Rebuild the dictionary from the aligned embeddings, except on the very
    # first iteration when the trainer already holds an initial dictionary.
    if not (n_iter == 0 and hasattr(trainer, 'dico')):
        trainer.build_dictionary(support)

    # apply the Procrustes solution (simple or generalized variant)
    if not params.generalized:
        trainer.simple_procrustes()
    else:
        # second argument flags the first iteration
        trainer.generalized_procrustes(support, n_iter == 0)

    # embeddings evaluation; results are accumulated in `to_log`
    to_log = OrderedDict([('n_iter', n_iter)])
    evaluator.all_eval(to_log, True)  # True: the `biling_dict` flag

    # JSON log / save best model / end of epoch
    logger.info("__log__:%s" % json.dumps(to_log))
    metric_name = VALIDATION_METRIC.format(params.tgt_lang[-1])
    trainer.save_best(to_log, metric_name)
    logger.info('End of iteration %i.\n\n' % n_iter)


# Export embeddings: reload the best saved model first, then write it out.
# Skipped when --export is the empty string.
if params.export:
    trainer.reload_best()
    trainer.export()