-
Notifications
You must be signed in to change notification settings - Fork 43
/
combine_csvs.py
114 lines (99 loc) · 4.46 KB
/
combine_csvs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2023 The tpu_graphs Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script joins most-recent .csv files generated by all collection baselines.
The final produced CSV file can be submitted to Kaggle.
"""
from collections.abc import Sequence
import csv
import glob
import os
import sys
from absl import app
from absl import flags
# Raise the csv module's per-field size cap so that very large fields in the
# result CSVs (presumably long TopConfigs lists) can be parsed.
try:
  csv.field_size_limit(sys.maxsize)
except OverflowError:
  # The limit is stored as a C long; where that is 32 bits (e.g. Windows),
  # sys.maxsize does not fit, so use the largest 32-bit signed value instead.
  csv.field_size_limit(2**31 - 1)
# Default glob patterns for each collection's result CSV, used when the
# corresponding flag is empty. Each pattern is defined once so the flag help
# text and _DEFAULT_INPUTS below cannot drift apart.
_TILE_PATTERN = '~/out/tpugraphs_tiles/results_*.csv'
_LAYOUT_NLP_RANDOM_PATTERN = '~/out/tpugraphs_layout/results_*_nlp_random.csv'
_LAYOUT_NLP_DEFAULT_PATTERN = (
    '~/out/tpugraphs_layout/results_*_nlp_default.csv')
_LAYOUT_XLA_RANDOM_PATTERN = '~/out/tpugraphs_layout/results_*_xla_random.csv'
_LAYOUT_XLA_DEFAULT_PATTERN = (
    '~/out/tpugraphs_layout/results_*_xla_default.csv')


def _layout_flag_help(source: str, search: str, pattern: str) -> str:
  """Builds the help text for a --layout_<source>_<search> flag.

  Args:
    source: Value passed to layout_train.py --source (e.g. 'nlp' or 'xla').
    search: Value passed to layout_train.py --search (e.g. 'random' or
      'default').
    pattern: Glob pattern used when the flag is left empty.

  Returns:
    Help string naming the collection, the command that produces its CSV and
    the fallback pattern.
  """
  return ('Path to csv containing ranked indices for the collection '
          f'layout:{source}:{search}. This can be generated using binary '
          f'"baselines/layout/layout_train.py --source {source} --search '
          f'{search}". If not given, it defaults to the largest timestamp '
          f'matching {pattern}')


_OUTPUT_CSV = flags.DEFINE_string(
    'output', '~/out/tpugraphs_submission.csv', 'Path to output CSV')
_TILE_CSV = flags.DEFINE_string(
    'tile', '',
    'Path to csv containing top indices for the tile collection. This can be '
    'generated using binary "baselines/tiles/tiles_train.py". If not given, '
    f'it defaults to the largest timestamp matching {_TILE_PATTERN}')
_LAYOUT_NLP_RANDOM_CSV = flags.DEFINE_string(
    'layout_nlp_random', '',
    _layout_flag_help('nlp', 'random', _LAYOUT_NLP_RANDOM_PATTERN))
_LAYOUT_NLP_DEFAULT_CSV = flags.DEFINE_string(
    'layout_nlp_default', '',
    _layout_flag_help('nlp', 'default', _LAYOUT_NLP_DEFAULT_PATTERN))
_LAYOUT_XLA_RANDOM_CSV = flags.DEFINE_string(
    'layout_xla_random', '',
    _layout_flag_help('xla', 'random', _LAYOUT_XLA_RANDOM_PATTERN))
_LAYOUT_XLA_DEFAULT_CSV = flags.DEFINE_string(
    'layout_xla_default', '',
    _layout_flag_help('xla', 'default', _LAYOUT_XLA_DEFAULT_PATTERN))

# Pairs each input flag with the glob used to locate its newest result CSV
# when the flag is empty. main() writes the inputs in this order.
_DEFAULT_INPUTS = [
    (_TILE_CSV, _TILE_PATTERN),
    (_LAYOUT_NLP_RANDOM_CSV, _LAYOUT_NLP_RANDOM_PATTERN),
    (_LAYOUT_NLP_DEFAULT_CSV, _LAYOUT_NLP_DEFAULT_PATTERN),
    (_LAYOUT_XLA_RANDOM_CSV, _LAYOUT_XLA_RANDOM_PATTERN),
    (_LAYOUT_XLA_DEFAULT_CSV, _LAYOUT_XLA_DEFAULT_PATTERN),
]
def get_flag_value_or_latest(flag: flags.FlagHolder, pattern: str) -> str:
  """Returns the CSV path given by `flag`, or the newest file matching `pattern`.

  Args:
    flag: String flag holding an explicit CSV path. An empty value means
      "not given". A leading `~` is expanded to the user's home directory.
    pattern: Glob pattern used when the flag is empty (a leading `~` is
      expanded). Matches are sorted lexicographically and the last one is
      taken. That is the most recent file only if the variable part of the
      filename (presumably a timestamp) sorts chronologically as a string.

  Returns:
    Path to an existing file.

  Raises:
    ValueError: If the flag names a path that is not an existing file, or if
      the flag is empty and nothing matches `pattern`.
  """
  csv_path = flag.value
  if csv_path:
    # Expand `~` so explicit paths behave like the `~`-based defaults. Shells
    # typically leave `~` unexpanded inside `--flag=~/...` arguments.
    csv_path = os.path.expanduser(csv_path)
    # isfile (not exists) so a directory is rejected here rather than failing
    # later when it is opened.
    if not os.path.isfile(csv_path):
      raise ValueError(f'File {csv_path} does not exist.')
    return csv_path
  files = sorted(glob.glob(os.path.expanduser(pattern)))
  if not files:
    raise ValueError(f'No files matching pattern {pattern}.')
  latest = files[-1]
  print(f'Using {latest}')
  return latest
def main(argv: Sequence[str]) -> None:
  """Concatenates the per-collection result CSVs into one submission CSV.

  Inputs are resolved from the flags listed in `_DEFAULT_INPUTS` (falling back
  to the newest file matching each default pattern). Every input must have
  `ID` and `TopConfigs` columns. Rows are copied in input order, under a
  single `ID,TopConfigs` header, to the path given by --output.

  Args:
    argv: Arguments left over after flag parsing; only the program name is
      accepted.

  Raises:
    app.UsageError: If extra positional arguments are passed.
    ValueError: If an input cannot be resolved or an ID appears twice.
    KeyError: If an input CSV lacks the `ID` or `TopConfigs` column.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  csv_filepaths = [get_flag_value_or_latest(flag, os.path.expanduser(pattern))
                   for flag, pattern in _DEFAULT_INPUTS]
  out_filepath = os.path.expanduser(_OUTPUT_CSV.value)
  out_dir = os.path.dirname(out_filepath)
  # dirname() is '' for a bare filename such as --output=sub.csv, and
  # os.makedirs('') raises, so only create a directory when there is one.
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)
  output_ids = set()
  with open(out_filepath, 'w', newline='') as out_file:
    # csv.writer re-quotes any field containing a delimiter or quote that the
    # DictReader below unquoted. Use '\n' row endings instead of csv's
    # default '\r\n'.
    writer = csv.writer(out_file, lineterminator='\n')
    writer.writerow(['ID', 'TopConfigs'])
    for filepath in csv_filepaths:
      # Close each input as soon as it has been copied.
      with open(filepath, newline='') as in_file:
        for record in csv.DictReader(in_file):
          record_id = record['ID']
          if record_id in output_ids:
            raise ValueError(f'Duplicate record with ID {record_id}')
          output_ids.add(record_id)
          writer.writerow([record_id, record['TopConfigs']])
  print(f'\n\n Wrote {out_filepath}\n\n')
# Script entry point: hand main() to absl's app.run, which handles command-line
# flag parsing before invoking it.
if __name__ == "__main__":
  app.run(main)