Starter code for training and producing CSV for layout collections.
PiperOrigin-RevId: 560771182
samihaija authored and mangpo committed Aug 28, 2023
1 parent e79b161 commit 5274760
Showing 8 changed files with 1,174 additions and 6 deletions.
47 changes: 44 additions & 3 deletions README.md
@@ -50,7 +50,22 @@ Removing the last pipe (`| bash`) shows the commands for downloading the dataset

## Running Baseline Models

### Tile Size Model
This repo hosts two training pipelines: `tiles_train.py` and `layout_train.py`,
respectively for training models on the collections `tile:xla` and
`layout:{nlp|xla}:{random|default}`. Both scripts train for a number of epochs
and then run inference on the test set. By default, each trained model is saved
alongside a `csv` file containing its predictions on the test set.

To combine all five inference files into one `csv`, you can run:

```sh
python combine_csvs.py
```

NOTE: The above command looks for the inference files produced by
`tiles_train.py` and `layout_train.py`.
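The combining step is conceptually simple: concatenate the per-collection prediction files while keeping a single header row. Below is a minimal, hypothetical sketch of that idea — the filenames, columns, and logic of the real `combine_csvs.py` may differ:

```python
import csv
import io

def combine_csvs(csv_texts):
  """Concatenates CSV files that share a header, keeping one header row."""
  header = None
  rows = []
  for text in csv_texts:
    reader = csv.reader(io.StringIO(text))
    file_header = next(reader)
    if header is None:
      header = file_header
    rows.extend(reader)
  return [header] + rows

# Hypothetical per-collection inference outputs.
tile_csv = "ID,TopConfigs\ntile:xla:graph1,0;1;2\n"
layout_csv = "ID,TopConfigs\nlayout:nlp:random:graph2,4;0;9\n"
combined = combine_csvs([tile_csv, layout_csv])
```
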

### Model on `tile:xla` collection

#### Python environment setup with Conda

@@ -151,9 +166,35 @@ This script will print out per-program top-K errors for kernels in the validation
}
```
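For intuition, one common way to compute a top-K error is to compare the best true runtime among the model's K highest-ranked configurations against the overall best runtime. A minimal sketch of that definition — the exact metric computed by the script may differ:

```python
import numpy as np

def top_k_error(pred_scores, runtimes, k):
  """Best runtime among the model's top-k configs, relative to the true best.

  Returns 0.0 when the top-k picks contain the fastest configuration.
  """
  top_k = np.argsort(pred_scores)[:k]  # lower predicted cost = better
  best_in_top_k = runtimes[top_k].min()
  return best_in_top_k / runtimes.min() - 1.0

runtimes = np.array([10.0, 8.0, 12.0, 9.0])  # measured runtimes per config
preds = np.array([0.3, 0.9, 0.1, 0.5])       # model's predicted costs
err = top_k_error(preds, runtimes, k=2)      # top-2 picks: configs 2 and 0
```
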

## Layout Model
### Model on `layout:{xla|nlp}:{random|default}` collections

You may run the GST model, which is available at https://github.com/kaidic/GST.

You may also run our baseline by invoking:

```sh
# As a test.
python layout_train.py --epochs 10 --toy_data=True

# On xla:random
python layout_train.py --source xla --search random --epochs 10 --max_configs 1000

# On xla:default
python layout_train.py --source xla --search default --epochs 10 --max_configs 1000

# On nlp:random
python layout_train.py --source nlp --search random --epochs 10 --max_configs 1000

# On nlp:default
python layout_train.py --source nlp --search default --epochs 10 --max_configs 1000
```

NOTE: For the NLP collections, the data is too large for our trainer script to
fit into memory. The flag `--max_configs 1000` makes training feasible by
sampling only that many configurations per graph. Alternatively, you may write
your own scalable implementation, modify ours, or run
GST: https://github.com/kaidic/GST.
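The per-graph sampling implied by `--max_configs` can be sketched as follows; this is an illustrative stand-alone version, not the actual implementation inside the trainer:

```python
import numpy as np

def sample_configs(config_features, runtimes, max_configs, seed=0):
  """Randomly keeps at most max_configs configurations for one graph.

  Sampling per graph bounds memory use when a graph has many configurations.
  """
  n = runtimes.shape[0]
  if n <= max_configs:
    return config_features, runtimes
  rng = np.random.default_rng(seed)
  keep = rng.choice(n, size=max_configs, replace=False)
  return config_features[keep], runtimes[keep]

# Toy graph with 10 configurations, 2 features each.
feats = np.arange(20, dtype=float).reshape(10, 2)
rts = np.linspace(1.0, 2.0, 10)
sub_feats, sub_rts = sample_configs(feats, rts, max_configs=4)
```
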

Instructions to train the baseline models for the layout collection can be found at https://github.com/kaidic/GST.

## Dataset File Description

54 changes: 54 additions & 0 deletions layout_train.py
@@ -0,0 +1,54 @@
# Copyright 2023 The tpu_graphs Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Binary for invoking the training loop.

# Usage Example

```sh
BIN='python baselines/layout/layout_train.py'
$BIN --source xla --search random --epochs 10 --max_configs 1000
$BIN --source xla --search default --epochs 10 --max_configs 1000
$BIN --source nlp --search random --epochs 10 --max_configs 1000
$BIN --source nlp --search default --epochs 10 --max_configs 1000
```
"""

from collections.abc import Sequence

from absl import app

from tpu_graphs.baselines.layout import train_args
from tpu_graphs.baselines.layout import train_lib


def main(unused_argv: Sequence[str]) -> None:
train_lib.train(train_args.get_args())


if __name__ == '__main__':
app.run(main)
