forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bazel-sharding.py
114 lines (90 loc) · 3.66 KB
/
bazel-sharding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
#
# Copyright 2021 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# FROM https://github.com/philwo/bazel-utils/blob/main/sharding/sharding.py
import argparse
import os
import subprocess
import sys
def partition_targets(targets):
    """Split target patterns into included and excluded patterns.

    A pattern that starts with "-" is an exclusion; the leading "-" is
    dropped before it is recorded. Relative order is kept in both outputs.

    Args:
      targets: iterable of target-pattern strings.

    Returns:
      A tuple ``(included, excluded)`` of lists of strings.
    """
    # Key each pattern by "is this an exclusion?" and bucket it in one pass.
    buckets = {False: [], True: []}
    for pattern in targets:
        negated = pattern.startswith("-")
        buckets[negated].append(pattern[1:] if negated else pattern)
    return buckets[False], buckets[True]
def quote_targets(targets):
    """Render target patterns as single-quoted, space-separated query words.

    Args:
      targets: sequence of target-pattern strings (may be empty or None).

    Returns:
      e.g. ``"'//a' '//b:c'"``; an empty string when ``targets`` is falsy.
    """
    if not targets:
        return ""
    return " ".join(f"'{t}'" for t in targets)
def get_target_expansion_query(targets, tests_only, exclude_manual):
    """Build the ``bazel query`` expression that expands ``targets``.

    Args:
      targets: target patterns; a leading "-" marks a pattern to subtract.
      tests_only: if true, wrap both the included set and the excluded set
        in ``tests(...)``.
      exclude_manual: if true, additionally subtract
        ``tests(attr("tags", "manual", set(<included patterns>)))``.

    Returns:
      The query expression, with the pieces chained by ``except``.
    """
    wanted, unwanted = partition_targets(targets)
    wanted_quoted = quote_targets(wanted)
    unwanted_quoted = quote_targets(unwanted)

    def scope(expr):
        # Restrict an expression to test targets only when requested.
        return "tests({})".format(expr) if tests_only else expr

    # First element is the base set; every later element is subtracted.
    parts = [scope("set({})".format(wanted_quoted))]
    if unwanted_quoted:
        parts.append(scope("set({})".format(unwanted_quoted)))
    if exclude_manual:
        parts.append(
            'tests(attr("tags", "manual", set({})))'.format(wanted_quoted))
    return " except ".join(parts)
def run_bazel_query(query, debug):
    """Run ``bazel query`` and return its stdout as a list of lines.

    Args:
      query: the query expression; passed as a single argv element (no
        shell is involved).
      debug: if true, echo the command to stderr before running it.

    Returns:
      stdout with surrounding whitespace trimmed, split into lines; ``[]``
      when stdout is blank. Undecodable bytes are replaced, not raised.

    Raises:
      subprocess.CalledProcessError: if bazel exits with a non-zero status.
    """
    args = ["bazel", "query", query]
    if debug:
        print("$ {}".format(" ".join(args)), file=sys.stderr)
        sys.stderr.flush()
    # Execute exactly the argv that was echoed above, so the debug line can
    # never drift from the command that actually runs.
    p = subprocess.run(
        args,
        check=True,
        stdout=subprocess.PIPE,
        errors="replace",
        universal_newlines=True,
    )
    output = p.stdout.strip()
    return output.splitlines() if output else []
def get_targets_for_shard(targets, index, count):
    """Return the slice of ``targets`` that belongs to shard ``index``.

    Targets are sorted so that every shard agrees on one global order, then
    dealt out round-robin: shard ``index`` receives sorted positions
    ``index, index + count, index + 2*count, ...``.

    This deliberately ignores test size or duration; it balances only by
    target count.
    """
    ordered = sorted(targets)
    return ordered[index::count]
def main(argv=None):
    """Expand Bazel target patterns and print this shard's share of them.

    Args:
      argv: command-line arguments without the program name; when None,
        argparse falls back to ``sys.argv[1:]``.

    Returns:
      Process exit status, 0 on success. Invalid arguments terminate via
      ``parser.error`` (SystemExit with status 2). A failing ``bazel query``
      raises ``subprocess.CalledProcessError``.
    """
    parser = argparse.ArgumentParser(
        description="Expand and shard Bazel targets.")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--tests_only", action="store_true")
    parser.add_argument("--exclude_manual", action="store_true")
    # Shard indices are 0-based (valid range is 0..count-1). Outside a
    # Buildkite parallel job, the defaults describe a single shard that
    # receives every target. The previous default index of 1 made the
    # no-env, no-flag invocation always fail validation.
    parser.add_argument(
        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 0))
    parser.add_argument(
        "--count",
        type=int,
        default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1))
    parser.add_argument("targets", nargs="+")
    # Exclusion patterns such as "-//foo:bar" look like options to argparse,
    # so they (and any positionals after them) come back as unknown args;
    # fold them back into the target list.
    args, extra_args = parser.parse_known_args(argv)
    args.targets = list(args.targets) + list(extra_args)

    if args.count < 1:
        parser.error("--count must be at least 1")
    # Reject negative indices too: a negative slice start would silently
    # select the wrong targets instead of failing.
    if not 0 <= args.index < args.count:
        parser.error("--index must be between 0 and {}".format(args.count - 1))

    query = get_target_expansion_query(args.targets, args.tests_only,
                                       args.exclude_manual)
    expanded_targets = run_bazel_query(query, args.debug)
    my_targets = get_targets_for_shard(expanded_targets, args.index,
                                       args.count)
    print(" ".join(my_targets))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())