Expanded sharded support for alternative sharding mechanisms #680

Open
wants to merge 3 commits into base: main

Conversation

rsuderman
Contributor

Single-logical-multi-physical sharding allows tensor access between
different devices and tighter synchronization on execution. This means
sharding needs to support not only differing device ordinals but
also configuring multiple queues on the same device. Sharded tensor types
are reworked to track both the device a shard is assigned to AND the queue
it is enqueued on.

To support this, each sharded tensor now tracks the DeviceAffinity it is
associated with, and affinities can be reassigned after construction.
This allows pre-sharded models to have their affinities updated to use an
alternative transfer mechanism.

If a device affinity is not specified, the default arrangement assumes
separate device ordinals for each shard.
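
As a rough illustration of the placement model described above, here is a minimal Python sketch; DeviceAffinity and ShardedTensorSketch are hypothetical stand-ins used for illustration, not the sharktank API.

from dataclasses import dataclass, field


@dataclass(frozen=True)
class DeviceAffinity:
    """Hypothetical stand-in: a physical device ordinal plus the queues to enqueue on."""

    device_ordinal: int
    queues: tuple[int, ...] = ()


@dataclass
class ShardedTensorSketch:
    """Illustrative only: tracks one affinity per shard."""

    shard_count: int
    affinities: list[DeviceAffinity] = field(default_factory=list)

    def __post_init__(self):
        # Default arrangement: a separate device ordinal for each shard.
        if not self.affinities:
            self.affinities = [DeviceAffinity(i) for i in range(self.shard_count)]

    def reassign_affinities(self, affinities: list[DeviceAffinity]) -> None:
        # Lets a pre-sharded model be retargeted after construction.
        assert len(affinities) == self.shard_count
        self.affinities = list(affinities)


t = ShardedTensorSketch(shard_count=2)  # shard 0 -> device 0, shard 1 -> device 1
# Single-logical-multi-physical: both shards on device 0, on separate queues.
t.reassign_affinities([DeviceAffinity(0, (0,)), DeviceAffinity(0, (1,))])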

@@ -279,16 +283,15 @@ def main():
        tensor_parallelism_size=args.tensor_parallelism_size,
        fake_quant=args.fake_quant,
    )
    if config.tensor_parallelism_size > 1:
        dataset.root_theta = shard_theta(dataset.root_theta, config)
Collaborator

Can remove shard_theta import if unused.

@stbaione
Contributor

I saw the following error this morning when attempting to validate toy_llama_tp2 from iree-test-suites by exporting and compiling, with the intent to then verify with iree-run-module, just to make sure the patch worked. This was with shortfin/sharktank and a locally built IREE at HEAD:

  1. Obtain assets from iree-test-suites, specifically toy_llama_tp2.irpa, toy_llama_tp2.rank0.irpa, toy_llama_tp2.rank1.irpa
  2. Export to MLIR:
python -m sharktank.examples.export_paged_llm_v1 --bs=1 --irpa-file assets/toy_llama_tp2.irpa --output-mlir=llama.mlir --output-config=config.json --use-queue-affinities
  3. Attempt to compile to vmfb. Started with compiling the sharded llama for a single device for the simplest validation:
iree-compile llama.mlir -o llama.vmfb --iree-hip-target=gfx942 --iree-hal-target-device=hip[0]

Received the following error:

/toy_new/llama.mlir:4027:12: error: op affinity #hal.device.affinity<@__device_0> is not compatible with the partition affinity #hal.device.affinity<@__device_0, [0]>
    %153 = torch.prims.convert_element_type %1, %int5_87 : !torch.vtensor<[256,256],f32>, !torch.int -> !torch.vtensor<[256,256],f16>
           ^
./toy_new/llama.mlir:4027:12: note: see current operation: %190 = "stream.async.transfer"(%189, %10, %10) <{result_affinity = #hal.device.affinity<@__device_0>, source_affinity = #hal.device.affinity<@__device_0, [1]>}> : (!stream.resource<constant>, index, index) -> !stream.resource<constant>

Feedback from Rob this morning before sync:

Hmmm, see if you can figure out where the wrong affinity is.
Looks like something is not placed correctly.
Given it's an async transfer I would guess we need to strip the transfers in sharded_impls.py
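
The gist of that suggestion, as a hypothetical sketch (reusing the illustrative DeviceAffinity stand-in from the description above, not the actual sharded_impls.py code): a transfer whose source and result affinities share a device ordinal and differ only in queue needs no cross-device copy, so it is a candidate for stripping.

def can_elide_transfer(src: DeviceAffinity, dst: DeviceAffinity) -> bool:
    # Same physical device, possibly different queues: no cross-device copy is
    # required, though queue-level synchronization may still be needed.
    return src.device_ordinal == dst.device_ordinal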

self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
self.devices = devices
Collaborator

This is redundant with L34, can be removed.
