Commit fb67da8

gchalump authored and facebook-github-bot committed

Add TBE data configuration reporter to TBE forward (#4130)
Summary:
X-link: facebookresearch/FBGEMM#1211

Pull Request resolved: #4130

# Add TBE data configuration reporter to TBE forward call

The reporter reports the TBE data configuration at the `SplitTableBatchedEmbeddingBagsCodegen` ***forward*** call. The output is a `TBEDataConfig` object, which is written to a JSON file(s). The environment variables that configure the reporter and an example of its usage are described below.

## JustKnobs for enablement

- fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS is added to enable the reporter (https://www.internalfb.com/intern/justknobs/?name=fbgemm_gpu%2Ffeatures)
  - The default is `False`; enable this flag to turn the reporter on.
  - To enable it locally, use:
    ```
    jk canary set fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS --on --ttl 600
    ```

## Environment Variables

The reporter relies on several environment variables to control its behavior. Each variable is described below:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL**:
  - **Description**: Determines the interval, in number of iterations, at which reports are generated.
  - **Example Value**: `1` (report every iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START**:
  - **Description**: Specifies the start of the iteration range to capture reports. Defaults to 0.
  - **Example Value**: `0` (start reporting from the first iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END**:
  - **Description**: Specifies the end of the iteration range to capture reports. Use `-1` to report until the last iteration. Defaults to -1.
  - **Example Value**: `-1` (report until the last iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET**:
  - **Description**: Specifies the name of the Manifold bucket where the report data will be saved.
  - **Example Value**: `tlparse_reports`
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX**:
  - **Description**: Defines the path prefix under which the report files will be stored.
  - **Example Value**: `tree/tests/`

## Example Usage

Below is an example command demonstrating how to use the FBGEMM reporter with specific environment variable settings:

```
FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2 FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3 FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/ buck2 run mode/opt //deeplearning/fbgemm/fbgemm_gpu/bench:split_table_batched_embeddings -- device --iters 2
```

**Explanation**

The settings above will report `iter 3` and `iter 5`:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2**: The reporter generates a report every 2 iterations.
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3**: The reporter starts generating reports at iteration 3.
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END=-1 (default)**: The reporter continues generating reports through the last iteration.
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports**: The reports are saved in the `tlparse_reports` bucket.
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/**: The reports are stored under the path prefix `tree/tests/`. For Manifold, make sure all folders within the path already exist.

**Note on the benchmark example**

With the `--iters 2` option, the benchmark executes 6 forward calls in total: 3 (2 iterations plus 1 warmup) for the forward benchmark and another 3 (2 iterations plus 1 warmup) for the backward benchmark. Iteration numbering starts from 0, so the settings above capture iterations 3 and 5; the sketch below shows how the reported iterations can be predicted.
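For reference, here is a minimal Python sketch of the reporting condition that `TBEBenchmarkParamsReporter.report_stats` applies (see the diff below); the helper name `reported_iterations` is illustrative, not part of the codebase:

```python
def reported_iterations(
    num_iters: int, interval: int, start: int = 0, end: int = -1
) -> list:
    """Return the 0-indexed iterations for which a report is written."""
    return [
        it
        for it in range(num_iters)
        if (it - start) % interval == 0
        and it >= start
        and (end == -1 or it <= end)
    ]

# The example above: 6 forward calls, interval=2, start=3, end=-1 -> [3, 5]
print(reported_iterations(6, interval=2, start=3))
```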
---

## Other changes included in this diff

- Updates build dependencies of the tbe_data_config* files.
- Removes the `shutil` and `numpy.random.default_rng` dependencies, as they caused incompatibility errors.
- Adds a non-OSS test that writes the extracted config data JSON file to Manifold.

Reviewed By: q10

Differential Revision: D73927918
1 parent a7246da commit fb67da8

File tree

9 files changed (+326, -66 lines)

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -60,6 +60,9 @@ def foo():
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
+    # Enable TBE input parameters extraction
+    TBE_REPORT_INPUT_PARAMS = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)
```
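The new enum member is gated through the same `FeatureGate` machinery as the other flags. A minimal check, assuming `FeatureGateName` is importable from `fbgemm_gpu.config` as elsewhere in the codebase:

```python
from fbgemm_gpu.config import FeatureGateName

# When the fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS knob is on,
# the forward call wires in the reporter (see the factory further down).
if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
    print("TBE input-params reporting is enabled")
```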

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -51,6 +51,7 @@
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
+from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
 from fbgemm_gpu.tbe_input_multiplexer import (
     TBEInfo,
     TBEInputInfo,
@@ -1441,6 +1442,11 @@ def __init__(  # noqa C901
             self._debug_print_input_stats_factory()
         )
 
+        # Get a reporter function pointer
+        self._report_input_params: Callable[..., None] = (
+            self.__report_input_params_factory()
+        )
+
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
@@ -1952,6 +1958,18 @@ def forward(  # noqa: C901
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)
 
+        # Extract and Write input stats if enable
+        self._report_input_params(
+            feature_rows=self.rows_per_table,
+            feature_dims=self.feature_dims,
+            iteration=self.iter.item() if hasattr(self, "iter") else 0,
+            indices=indices,
+            offsets=offsets,
+            op_id=self.uuid,
+            per_sample_weights=per_sample_weights,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+        )
+
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -3792,6 +3810,36 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
 
+    @torch.jit.ignore
+    def __report_input_params_factory(self) -> Callable[..., None]:
+        """
+        This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
+        - Reports input parameters (TBEDataConfig).
+        - Writes the output as a JSON file.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
+        """
+
+        @torch.jit.ignore
+        def __report_input_params_factory_null(
+            feature_rows: Tensor,
+            feature_dims: Tensor,
+            iteration: int,
+            indices: Tensor,
+            offsets: Tensor,
+            op_id: Optional[str] = None,
+            per_sample_weights: Optional[Tensor] = None,
+            batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        ) -> None:
+            pass
+
+        if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
+            reporter = TBEBenchmarkParamsReporter.create()
+            return reporter.report_stats
+        return __report_input_params_factory_null
+
 
 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
```

fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -11,8 +11,12 @@
 import torch
 import yaml
 
-from .tbe_data_config import TBEDataConfig
-from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+    TBEDataConfig,
+)
 
 
 class TBEDataConfigLoader:
```

fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py

Lines changed: 119 additions & 35 deletions
```diff
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import io
+import json
 import logging
 import os
 from typing import List, Optional
@@ -16,18 +17,18 @@
 import numpy as np  # usort:skip
 import torch  # usort:skip
 
-from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-    SplitTableBatchedEmbeddingBagsCodegen,
-)
-from fbgemm_gpu.tbe.bench import (
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
     BatchParams,
     IndicesParams,
     PoolingParams,
     TBEDataConfig,
 )
 
-# pyre-ignore[16]
-open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+try:
+    # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+    open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+except Exception:
+    open_source: bool = False
 
 if open_source:
     from fbgemm_gpu.utils import FileStore
@@ -43,7 +44,8 @@ class TBEBenchmarkParamsReporter:
     def __init__(
         self,
         report_interval: int,
-        report_once: bool = False,
+        report_iter_start: int = 0,
+        report_iter_end: int = -1,
         bucket: Optional[str] = None,
         path_prefix: Optional[str] = None,
     ) -> None:
@@ -52,15 +54,32 @@ def __init__(
 
         Args:
             report_interval (int): The interval at which reports are generated.
-            report_once (bool, optional): If True, reporting occurs only once. Defaults to False.
+            report_iter_start (int): The start of the iteration range to capture. Defaults to 0.
+            report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration).
             bucket (Optional[str], optional): The storage bucket for reports. Defaults to None.
             path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None.
         """
+        assert report_interval > 0, "report_interval must be greater than 0"
+        assert (
+            report_iter_start >= 0
+        ), "report_iter_start must be greater than or equal to 0"
+        assert (
+            report_iter_end >= -1
+        ), "report_iter_end must be greater than or equal to -1"
+        assert (
+            report_iter_end == -1 or report_iter_start <= report_iter_end
+        ), "report_iter_start must be less than or equal to report_iter_end"
+
         self.report_interval = report_interval
-        self.report_once = report_once
-        self.has_reported = False
+        self.report_iter_start = report_iter_start
+        self.report_iter_end = report_iter_end
+        self.path_prefix = path_prefix
 
-        default_bucket = "/tmp" if open_source else "tlparse_reports"
+        default_bucket = (
+            f"/tmp/{os.environ.get('USER', 'default_user')}"
+            if open_source
+            else "tlparse_reports"
+        )
         bucket = (
             bucket
             if bucket is not None
@@ -71,19 +90,59 @@ def __init__(
         self.logger: logging.Logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
 
+    @classmethod
+    def create(cls) -> "TBEBenchmarkParamsReporter":
+        """
+        This method returns an instance of TBEBenchmarkParamsReporter based on environment variables.
+
+        If the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable is set to a value greater than 0, it creates an instance that:
+        - Reports input parameters (TBEDataConfig).
+        - Writes the output as a JSON file.
+
+        Additionally, the following environment variables are considered:
+        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture.
+        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture.
+        - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting.
+        - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting.
+
+        Returns:
+            TBEBenchmarkParamsReporter: An instance configured based on the environment variables.
+        """
+        report_interval = int(
+            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
+        )
+        report_iter_start = int(
+            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
+        )
+        report_iter_end = int(
+            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
+        )
+        bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "")
+        path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "")
+
+        return cls(
+            report_interval=report_interval,
+            report_iter_start=report_iter_start,
+            report_iter_end=report_iter_end,
+            bucket=bucket,
+            path_prefix=path_prefix,
+        )
+
     def extract_params(
         self,
-        embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
+        feature_rows: torch.Tensor,
+        feature_dims: torch.Tensor,
         indices: torch.Tensor,
         offsets: torch.Tensor,
         per_sample_weights: Optional[torch.Tensor] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
     ) -> TBEDataConfig:
         """
-        Extracts parameters from the embedding operation, input indices and offsets to create a TBEDataConfig.
+        Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig.
 
         Args:
-            embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation.
+            feature_rows (torch.Tensor): Number of rows in each feature.
+            feature_dims (torch.Tensor): Number of dimensions in each feature.
             indices (torch.Tensor): The input indices tensor.
             offsets (torch.Tensor): The input offsets tensor.
             per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
@@ -92,24 +151,25 @@ def extract_params(
         Returns:
             TBEDataConfig: The configuration data for TBE benchmarking.
         """
+
+        Es = feature_rows.tolist()
+        Ds = feature_dims.tolist()
+
+        assert len(Es) == len(
+            Ds
+        ), "feature_rows and feature_dims must have the same length"
+
         # Transfer indices back to CPU for EEG analysis
         indices_cpu = indices.cpu()
 
-        # Extract embedding table specs
-        embedding_specs = [
-            embedding_op.embedding_specs[t] for t in embedding_op.feature_table_map
-        ]
-        rowcounts = [embedding_spec[0] for embedding_spec in embedding_specs]
-        dims = [embedding_spec[1] for embedding_spec in embedding_specs]
-
         # Set T to be the number of features we are looking at
-        T = len(embedding_op.feature_table_map)
+        T = len(Ds)
         # Set E to be the mean of the rowcounts to avoid biasing
-        E = rowcounts[0] if len(set(rowcounts)) == 1 else np.ceil((np.mean(rowcounts)))
+        E = Es[0] if len(set(Es)) == 1 else np.ceil((np.mean(Es)))
         # Set mixed_dim to be True if there are multiple dims
-        mixed_dim = len(set(dims)) > 1
+        mixed_dim = len(set(Ds)) > 1
         # Set D to be the mean of the dims to avoid biasing
-        D = dims[0] if not mixed_dim else np.ceil((np.mean(dims)))
+        D = Ds[0] if not mixed_dim else np.ceil((np.mean(Ds)))
 
         # Compute indices distribution parameters
         heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
@@ -160,34 +220,58 @@ def extract_params(
 
     def report_stats(
         self,
-        embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
+        feature_rows: torch.Tensor,
+        feature_dims: torch.Tensor,
+        iteration: int,
         indices: torch.Tensor,
         offsets: torch.Tensor,
+        op_id: str = "",
         per_sample_weights: Optional[torch.Tensor] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
     ) -> None:
         """
-        Reports the configuration of the embedding operation and input data then writes the TBE configuration to the filestore.
+        Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore.
 
         Args:
-            embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation.
+            feature_rows (torch.Tensor): Number of rows in each feature.
+            feature_dims (torch.Tensor): Number of dimensions in each feature.
+            iteration (int): The current iteration number.
             indices (torch.Tensor): The input indices tensor.
             offsets (torch.Tensor): The input offsets tensor.
+            op_id (str, optional): The operation identifier. Defaults to an empty string.
            per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
             batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
         """
-        if embedding_op.iter.item() % self.report_interval == 0 and (
-            not self.report_once or (self.report_once and not self.has_reported)
+        if (
+            (iteration - self.report_iter_start) % self.report_interval == 0
+            and (iteration >= self.report_iter_start)
+            and (self.report_iter_end == -1 or iteration <= self.report_iter_end)
         ):
             # Extract TBE config
             config = self.extract_params(
-                embedding_op, indices, offsets, per_sample_weights
+                feature_rows=feature_rows,
+                feature_dims=feature_dims,
+                indices=indices,
+                offsets=offsets,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
             )
 
+            config.json()
+
+            # Ad-hoc fix for adding Es and Ds to JSON output
+            # TODO: Remove this once we moved Es and Ds to be part of TBEDataConfig
+            adhoc_config = config.dict()
+            adhoc_config["Es"] = feature_rows.tolist()
+            adhoc_config["Ds"] = feature_dims.tolist()
+            if batch_size_per_feature_per_rank:
+                adhoc_config["Bs"] = [
+                    sum(batch_size_per_feature_per_rank[f])
+                    for f in range(len(adhoc_config["Es"]))
+                ]
+
             # Write the TBE config to FileStore
             self.filestore.write(
-                f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json",
-                io.BytesIO(config.json(format=True).encode()),
+                f"{self.path_prefix}tbe-{op_id}-config-estimation-{iteration}.json",
+                io.BytesIO(json.dumps(adhoc_config, indent=2).encode()),
             )
-
-            self.has_reported = True
```
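A hedged usage sketch of the refactored reporter API: `create()` reads the environment variables documented above, and `report_stats` now takes raw tensors plus an iteration number instead of the embedding op itself. The tensor shapes and `op_id` below are made-up illustration, and running it requires a build where `torch.ops.fbgemm.tbe_estimate_indices_distribution` is available:

```python
import os

import torch
from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

# Illustrative environment configuration (see the variable reference above).
os.environ["FBGEMM_REPORT_INPUT_PARAMS_INTERVAL"] = "2"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_ITER_START"] = "3"

reporter = TBEBenchmarkParamsReporter.create()

# Toy input: 2 features, 100 rows each, embedding dim 128; 64 bags total
# (32 per feature) with 2 indices per bag -> 128 indices, 65 offsets.
reporter.report_stats(
    feature_rows=torch.tensor([100, 100]),
    feature_dims=torch.tensor([128, 128]),
    iteration=3,  # within [ITER_START, ITER_END], so a report is written
    indices=torch.randint(0, 100, (128,)),
    offsets=torch.arange(0, 129, 2),
    op_id="example-op",
)
```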

fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py

Lines changed: 2 additions & 7 deletions
```diff
@@ -14,9 +14,6 @@
 import numpy.typing as npt
 import torch
 
-# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
-from numpy.random import default_rng
-
 from .common import get_device
 from .offsets import get_table_batched_offsets_from_dense
 
@@ -309,11 +306,9 @@ def generate_indices_zipf(
         indices, torch.tensor([0, L], dtype=torch.long), True
     )
     if deterministic_output:
-        rng = default_rng(12345)
-    else:
-        rng = default_rng()
+        np.random.seed(12345)
     permutation = torch.as_tensor(
-        rng.choice(E, size=indices.max().item() + 1, replace=False)
+        np.random.choice(E, size=indices.max().item() + 1, replace=False)
     )
     indices = permutation.gather(0, indices.flatten())
     indices = indices.to(get_device()).int()
```
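The `deterministic_output` path now seeds NumPy's legacy global RNG instead of constructing a `default_rng` Generator. A small sketch of the behavioral consequence: re-seeding reproduces the permutation sample, though the global RNG state is shared process-wide, unlike a local Generator:

```python
import numpy as np

np.random.seed(12345)
a = np.random.choice(10, size=5, replace=False)
np.random.seed(12345)
b = np.random.choice(10, size=5, replace=False)
assert (a == b).all()  # re-seeding the global RNG reproduces the sample
```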

fbgemm_gpu/fbgemm_gpu/utils/filestore.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -11,7 +11,6 @@
 import io
 import logging
 import os
-import shutil
 from dataclasses import dataclass
 from pathlib import Path
 from typing import BinaryIO, Union
@@ -76,7 +75,12 @@ def write(
         elif isinstance(raw_input, Path):
             if not os.path.exists(raw_input):
                 raise FileNotFoundError(f"File {raw_input} does not exist")
-            shutil.copyfile(raw_input, filepath)
+            # Open the source file and destination file, and copy the contents
+            with open(raw_input, "rb") as src_file, open(
+                filepath, "wb"
+            ) as dst_file:
+                while chunk := src_file.read(4096):  # Read 4 KB at a time
+                    dst_file.write(chunk)
 
         elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO):
             with open(filepath, "wb") as file:
```
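For context, the reporter above hands `FileStore.write` an `io.BytesIO` buffer with the serialized config. A minimal OSS-path sketch, assuming a directory-backed bucket such as the `/tmp/$USER` default; the payload and path below are illustrative, and intermediate folders may need to exist:

```python
import io
import json

from fbgemm_gpu.utils import FileStore

store = FileStore("/tmp")  # assuming a directory-backed bucket in OSS builds
payload = {"Es": [100, 100], "Ds": [128, 128]}
store.write(
    "tbe-example-config-estimation-3.json",
    io.BytesIO(json.dumps(payload, indent=2).encode()),
)
```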
