-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make get_iree_devices read IREE_DEVICE env var if provided #891
base: main
Are you sure you want to change the base?
Conversation
ed78ac0
to
4546492
Compare
This allows to inject what exact IREE device(s) are to be used without propagating all the way to program arguments.
4546492
to
5a28d36
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this change is specifically targeting CIs, our runners pick whichever gpu is available atm. We had a PR making this change for llama benchmark/ppl CIs. @ScottTodd can confirm.
It is not just for CIs, but when running locally as well. There should be some way of overriding the device list. |
I believe on servers with multiple GPUs, we have one runner per GPU, using That's my understanding anyways. @yamiyysu and @saienduri would know more, and we should write that all down somewhere and get the code used to configure the runners checked in too. |
def get_iree_devices( | ||
*, driver: str | None = None, device_count: int = 1 | ||
) -> List[iree.runtime.HalDevice]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a docstring to this function and mention the IREE_DEVICE
environment variable with a few examples? You can lift some text from the PR description
This allows to inject what exact IREE device(s) are to be used without propagating all the way to program arguments. Example:
IREE_DEVICE=hip://5,hip://3 pytest sharktank/tests
def get_iree_devices(
*, driver: str | None = None, device_count: int = 1
) -> List[iree.runtime.HalDevice]:
"""Gets a list of IREE HAL devices for the given driver.
The first available device_count devices will be created,
unless the IREE_DEVICE environment variable is set to an
explicit list of device URIs.
For example, to select HIP devices 5 and 3:
```
export IREE_DEVICE=hip://5,hip://3
python ... device=hip # (something here)
```
"""
(could wordsmith that a bit, just getting the general idea across)
@@ -22,7 +23,28 @@ | |||
from .tree import Tree | |||
|
|||
|
|||
def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: | |||
def get_iree_devices( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May also want to put pytest details in https://github.com/nod-ai/shark-ai/blob/main/docs/developer_guide.md#running-tests
This allows to inject what exact IREE device(s) are to be used without propagating all the way to program arguments.
Example: