From 3b0d0c08da6f3cbc9afcc1b6edae965615a1539d Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Thu, 12 Sep 2024 17:31:53 +0000 Subject: [PATCH 01/18] make profiler work for GPU --- MaxText/profiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MaxText/profiler.py b/MaxText/profiler.py index faee511c5..49d230a14 100644 --- a/MaxText/profiler.py +++ b/MaxText/profiler.py @@ -47,6 +47,9 @@ def activate(self): return self.libcudart.cudaProfilerStart() elif self.mode == "xplane": + if self.output_path.startswith("gs://"): + self.output_path = "/scratch/jwyang-workspace/jax-profiles/" + self.output_path[4:] + print("Start profiling {}".format(self.output_path)) jax.profiler.start_trace(self.output_path) def deactivate(self): From 45bbf75ca82db900b5e9c31e9bb8c8c62254977b Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Thu, 31 Oct 2024 17:05:25 +0000 Subject: [PATCH 02/18] 1. add aqt fp8 quantization 2. Hacky transpose in MLP output layer to get rid of the slow matmul for llama2 70b model. --- MaxText/layers/linears.py | 9 ++++++++- MaxText/layers/quantizations.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index c0028a9a5..126697e12 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -115,7 +115,14 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): if self.quant: dot_general_cls = self.quant.dot_general_cls(mesh_axes=self.kernel_axes) dot_general = dot_general_cls() - return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) + print("shape info:", self.name, inputs.shape, kernel.shape, axis, contract_ind) + if self.name == "wo": + inputs = jnp.transpose(inputs, axes=[2, 1, 0]) + axis = 0 + output = dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) + if self.name == "wo": + output = jnp.transpose(output, [1, 0, 2]) + return output features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index 26842db3a..b9160b85f 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -190,6 +190,8 @@ def _get_int8_quant_config(config): drhs_accumulator_dtype=drhs_accumulator_dtype, ) +def _get_fp8_quant_config(config): + return aqt_config.config_fwd_fp8() def _get_weight_only_quant_config(lhs_bits=None, rhs_bits=None): return aqt_config.dot_general_make(lhs_bits=lhs_bits, rhs_bits=rhs_bits) @@ -219,6 +221,8 @@ def _get_quant_config(config): return None if config.quantization == "int8": return _get_int8_quant_config(config) + if config.quantization == "aqt_fp8": + return _get_fp8_quant_config(config) if config.quantization == "int8w": return _get_weight_only_quant_config(lhs_bits=None, rhs_bits=8) if config.quantization == "int4w": From dfb02bc8e6de158a11bfce16239a8bd51a8c391e Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Tue, 5 Nov 2024 18:36:50 +0000 Subject: [PATCH 03/18] fp8 quantization for inference --- MaxText/layers/linears.py | 2 +- MaxText/layers/quantizations.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 126697e12..7d62db602 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -118,7 +118,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): print("shape info:", self.name, inputs.shape, kernel.shape, axis, contract_ind) if self.name == "wo": inputs = 
jnp.transpose(inputs, axes=[2, 1, 0]) - axis = 0 + axis = (0,) output = dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) if self.name == "wo": output = jnp.transpose(output, [1, 0, 2]) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index b9160b85f..4839d1f62 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -190,7 +190,7 @@ def _get_int8_quant_config(config): drhs_accumulator_dtype=drhs_accumulator_dtype, ) -def _get_fp8_quant_config(config): +def _get_aqt_fp8_quant_config(config): return aqt_config.config_fwd_fp8() def _get_weight_only_quant_config(lhs_bits=None, rhs_bits=None): @@ -222,7 +222,7 @@ def _get_quant_config(config): if config.quantization == "int8": return _get_int8_quant_config(config) if config.quantization == "aqt_fp8": - return _get_fp8_quant_config(config) + return _get_aqt_fp8_quant_config(config) if config.quantization == "int8w": return _get_weight_only_quant_config(lhs_bits=None, rhs_bits=8) if config.quantization == "int4w": From 66f0effd94683b007ff69edbbfa4c2e73f03c834 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Tue, 5 Nov 2024 22:53:24 +0000 Subject: [PATCH 04/18] kv cache fp8 quantization. --- MaxText/layers/linears.py | 2 +- MaxText/layers/quantizations.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 7d62db602..7e0fe3b5b 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -115,7 +115,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): if self.quant: dot_general_cls = self.quant.dot_general_cls(mesh_axes=self.kernel_axes) dot_general = dot_general_cls() - print("shape info:", self.name, inputs.shape, kernel.shape, axis, contract_ind) + # print("shape info:", self.name, inputs.shape, kernel.shape, axis, contract_ind) if self.name == "wo": inputs = jnp.transpose(inputs, axes=[2, 1, 0]) axis = (0,) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index 4839d1f62..6686f6e93 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -33,6 +33,7 @@ MAX_INT8 = 127.5 MAX_INT4 = 7.5 +E4M3_MAX = jnp.finfo(e4m3).max.astype(f32) Array = common_types.Array Config = common_types.Config @@ -311,6 +312,8 @@ def _get_dtype(self, dtype_cfg: str): return jnp.int4 if dtype_cfg == "int8": return jnp.int8 + if dtype_cfg == "fp8": + return jnp.float8_e4m3fn raise ValueError(f"Invalid kv_quant_dtype: {dtype_cfg}") def _get_max_axis(self, axis_names: AxisNames): @@ -334,6 +337,9 @@ def quantize(self, kv: Array, axis_names: AxisNames): if self.dtype == jnp.int4: value = jnp.int4(jnp.rint(kv * (MAX_INT4 / scale))) return value, scale + if self.dtype == jnp.float8_e4m3fn: + value = jnp.float8_e4m3fn(kv * (E4M3_MAX / scale)) + return value, scale raise ValueError(f"Invalid KV quant dtype:{self.dtype}.") def einsum_fn_with_rhs_qtensor( From f9983e5b6be83bcaec0eff984d989d105a72fa2e Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Tue, 5 Nov 2024 23:17:37 +0000 Subject: [PATCH 05/18] kvcache fp8 quantization. 
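For context: together with the previous patch, this routes the KV cache through
float8_e4m3fn using a per-axis max-abs scale against the e4m3 dynamic range.
A minimal, self-contained sketch of that quantize step (not the MaxText code
itself; the helper name and the reduction axis are illustrative, while the real
KVQuant.quantize derives the axis from the cache axis names):

    import jax.numpy as jnp

    E4M3_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)

    def quantize_kv_e4m3(kv, axis=-1):
        # Per-axis max-abs scale, then rescale into the e4m3 representable range.
        scale = jnp.max(jnp.abs(kv), axis=axis, keepdims=True)
        qvalue = (kv * (E4M3_MAX / scale)).astype(jnp.float8_e4m3fn)
        return qvalue, scale

The quantized values and their scales are stored in the cache and undone at
attention time, either through an AQT einsum configured for e4m3 or by explicit
dequantization in later patches.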
--- MaxText/layers/quantizations.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index 6686f6e93..bfcc0d829 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -33,7 +33,7 @@ MAX_INT8 = 127.5 MAX_INT4 = 7.5 -E4M3_MAX = jnp.finfo(e4m3).max.astype(f32) +E4M3_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) Array = common_types.Array Config = common_types.Config @@ -351,10 +351,14 @@ def einsum_fn_with_rhs_qtensor( # Assumes kv is already quantized. einsum = jnp.einsum if isinstance(kv, aqt_tensor.QTensor): - num_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8 + lhs_bits = None + rhs_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8 + if kv.qvalue == jnp.float8_e4m3fn: + lhs_bits = 'e4m3' + rhs_bits = 'e4m3' kv_cfg = aqt_config.dot_general_make( - lhs_bits=None, - rhs_bits=num_bits, + lhs_bits=lhs_bits, + rhs_bits=rhs_bits, bwd_bits=None, use_fwd_quant=False, ) From 1380993f9a36a5dd3626cb4dc67787ddd37aab0b Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Tue, 5 Nov 2024 23:31:34 +0000 Subject: [PATCH 06/18] kvcache fp8 quantization. --- MaxText/layers/quantizations.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index bfcc0d829..f8ce3c98e 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -351,17 +351,18 @@ def einsum_fn_with_rhs_qtensor( # Assumes kv is already quantized. einsum = jnp.einsum if isinstance(kv, aqt_tensor.QTensor): - lhs_bits = None - rhs_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8 - if kv.qvalue == jnp.float8_e4m3fn: - lhs_bits = 'e4m3' - rhs_bits = 'e4m3' - kv_cfg = aqt_config.dot_general_make( - lhs_bits=lhs_bits, - rhs_bits=rhs_bits, - bwd_bits=None, - use_fwd_quant=False, - ) + if kv.qvalue.dtype != jnp.float8_e4m3fn: + lhs_bits = None + rhs_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8 + kv_cfg = aqt_config.dot_general_make( + lhs_bits=lhs_bits, + rhs_bits=rhs_bits, + bwd_bits=None, + use_fwd_quant=False, + ) + else: + kv_cfg = aqt_config.config_fwd_fp8() + if rhs_dequant_mode: aqt_config.set_fwd_dequant_mode( kv_cfg, rhs_dequant_mode=rhs_dequant_mode @@ -370,13 +371,14 @@ def einsum_fn_with_rhs_qtensor( aqt_config.set_fwd_calibration_mode( kv_cfg, rhs_calibration_mode=rhs_calibration_mode, - ) + ) einsum = aqt_flax.AqtEinsum( rhs_quant_mode=aqt_flax.QuantMode.TRAIN, lhs_freeze_mode=aqt_flax.FreezerMode.NONE, rhs_freeze_mode=aqt_flax.FreezerMode.NONE, cfg=kv_cfg ) + return einsum def einsum_fn_with_rhs_qtensor_and_dequant(self, value): From ee7328caafd9109e55dd9101352fbc3a800a3ac3 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Thu, 21 Nov 2024 22:15:57 +0000 Subject: [PATCH 07/18] Simple KV cache quantization. 
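This takes the simple route for the cached key/value: expand the KVTensor back
to bfloat16 immediately before the attention einsums instead of pushing the
quantized tensor through AqtEinsum. Because get_cached_values divides the
stored scale by the dtype maximum (E4M3_MAX for fp8, MAX_INT8/MAX_INT4 for the
integer dtypes), dequantization is a single elementwise multiply. A minimal
sketch of that step (illustrative helper, not the MaxText code):

    import jax.numpy as jnp

    def dequant_kv(qvalue, scale):
        # `scale` is the cache-resident scale, i.e. already divided by the
        # dtype max, so one multiply recovers the bfloat16 tensor.
        return qvalue.astype(jnp.bfloat16) * scale.astype(jnp.bfloat16)

qk_product can then use the ordinary jnp.einsum on the dequantized bfloat16
key.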
--- MaxText/configs/base.yml | 4 ++++ MaxText/layers/attentions.py | 15 ++++++++++++--- MaxText/layers/quantizations.py | 9 ++++++++- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 536372600..96baf21bd 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -210,6 +210,10 @@ logical_axis_rules: [ ['cache_heads', ['autoregressive', 'tensor']], ['cache_kv', []], ['cache_sequence', []], + ['cache_scale_batch', []], + ['cache_scale_heads', []], + ['cache_scale_kv', []], + ['cache_scale_sequence', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']] diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index e3ae1d856..124a6d0df 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -419,7 +419,9 @@ def qk_product(self, query: Array, key: Array| KVTensor, q_seq_len: int, model_m """ einsum = jnp.einsum if self.kv_quant: - einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) + # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) + if isinstance(key, KVTensor): + key = self.kv_quant.dequant(key) b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads @@ -459,9 +461,14 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s einsum = jnp.einsum if self.kv_quant: - einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) + # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) + if isinstance(value, KVTensor): + value = self.kv_quant.dequant(value) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): - out = einsum("bkgts,bskd->btkgd", attn_weights, value) + if self.kv_quant: + out = self.kv_quant.einsum_quantize_fp8("bkgts,bskd->btkgd", attn_weights, value) + else: + out = einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) elif self.compute_axis_order == (0,2,1,3): @@ -730,6 +737,8 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.A scale_value /= quantizations.MAX_INT8 elif dtype == jnp.int4: scale_value /= quantizations.MAX_INT4 + elif dtype == jnp.float8_e4m3fn: + scale_value /= quantizations.E4M3_MAX cache_value = KVTensor( qvalue=cache_value, diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index f8ce3c98e..d21228ff5 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -342,6 +342,13 @@ def quantize(self, kv: Array, axis_names: AxisNames): return value, scale raise ValueError(f"Invalid KV quant dtype:{self.dtype}.") + def dequant(self, kv: KVTensor): + assert isinstance(kv, KVTensor) + qvalue = kv.qvalue + scale = kv.scale[0] + dequant_kv = jnp.bfloat16(qvalue) * jnp.bfloat16(scale) + return dequant_kv + def einsum_fn_with_rhs_qtensor( self, kv: Array| aqt_tensor.QTensor, @@ -386,4 +393,4 @@ def einsum_fn_with_rhs_qtensor_and_dequant(self, value): value, rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT, rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS - ) + ) \ No newline at end of file From 8f6502400e550605633200b2fc76103800c495d4 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Thu, 21 Nov 2024 22:18:29 +0000 Subject: [PATCH 08/18] minor fix. 
--- MaxText/layers/attentions.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 124a6d0df..725cd4efc 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -465,10 +465,7 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s if isinstance(value, KVTensor): value = self.kv_quant.dequant(value) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): - if self.kv_quant: - out = self.kv_quant.einsum_quantize_fp8("bkgts,bskd->btkgd", attn_weights, value) - else: - out = einsum("bkgts,bskd->btkgd", attn_weights, value) + out = einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) elif self.compute_axis_order == (0,2,1,3): From e53257db38f5d6df2d01bdecaeef1e0aacd34a74 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Fri, 22 Nov 2024 23:48:36 +0000 Subject: [PATCH 09/18] test if kv cache shape transpose works. --- MaxText/layers/attentions.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 725cd4efc..826dd698d 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -422,6 +422,7 @@ def qk_product(self, query: Array, key: Array| KVTensor, q_seq_len: int, model_m # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) if isinstance(key, KVTensor): key = self.kv_quant.dequant(key) + key = self.reverse_transepose(key, self.ar_cache_axis_order) b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads @@ -462,8 +463,9 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s einsum = jnp.einsum if self.kv_quant: # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) - if isinstance(value, KVTensor): - value = self.kv_quant.dequant(value) + if isinstance(value, KVTensor): + value = self.kv_quant.dequant(value) + value = self.reverse_transepose(value, self.ar_cache_axis_order) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): out = einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape @@ -743,8 +745,9 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.A scale_t=None, dequant_dtype=target_dtype ) - cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value) - return cache_value_in_logical_shape + # cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value) + # return cache_value_in_logical_shape + return cache_value def kv_cache_autoregressive( self, From c5e5cb4bdd9536d1e3974b876325a9598a55f017 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Mon, 25 Nov 2024 19:45:32 +0000 Subject: [PATCH 10/18] reverse previous commit --- MaxText/layers/attentions.py | 7 +- vm_scipts/gpu_script.sh | 15 ++ vm_scipts/jax_script.py | 34 ++++ vm_scipts/tpu_script.sh | 382 +++++++++++++++++++++++++++++++++++ vm_scipts/xla_auto_tuning.sh | 54 +++++ 5 files changed, 487 insertions(+), 5 deletions(-) create mode 100644 vm_scipts/gpu_script.sh create mode 100644 vm_scipts/jax_script.py create mode 100644 vm_scipts/tpu_script.sh create mode 100644 vm_scipts/xla_auto_tuning.sh diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 826dd698d..4b202a38c 100644 --- 
a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -422,7 +422,6 @@ def qk_product(self, query: Array, key: Array| KVTensor, q_seq_len: int, model_m # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) if isinstance(key, KVTensor): key = self.kv_quant.dequant(key) - key = self.reverse_transepose(key, self.ar_cache_axis_order) b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads @@ -465,7 +464,6 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) if isinstance(value, KVTensor): value = self.kv_quant.dequant(value) - value = self.reverse_transepose(value, self.ar_cache_axis_order) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): out = einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape @@ -745,9 +743,8 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.A scale_t=None, dequant_dtype=target_dtype ) - # cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value) - # return cache_value_in_logical_shape - return cache_value + cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value) + return cache_value_in_logical_shape def kv_cache_autoregressive( self, diff --git a/vm_scipts/gpu_script.sh b/vm_scipts/gpu_script.sh new file mode 100644 index 000000000..8fcd0b5b8 --- /dev/null +++ b/vm_scipts/gpu_script.sh @@ -0,0 +1,15 @@ +GPU_NAME="vipannalla-mlperf-v41-a3-may" +ZONE="us-central1-a" +PROJECT="cloud-tpu-inference-test" + + +ssh_to_gpu() { + gcloud compute ssh --zone ${ZONE} ${GPU_NAME} --project ${PROJECT} -- -o ProxyCommand='corp-ssh-helper %h %p' +} + +copy_maxtext_files() { + gcloud compute scp --zone ${ZONE} --project ${PROJECT} \ + $PWD/MaxText/profiler.py \ + ${GPU_NAME}:/scratch/jwyang-workspace/maxtext/MaxText/ \ + --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" +} \ No newline at end of file diff --git a/vm_scipts/jax_script.py b/vm_scipts/jax_script.py new file mode 100644 index 000000000..83efd7738 --- /dev/null +++ b/vm_scipts/jax_script.py @@ -0,0 +1,34 @@ + + +import jax +import jax.numpy as jnp +import time + +import torch + +@jax.jit +def multiply_fusion_test(): + x = jnp.ones((1024, 1, 1280, 128), dtype=jnp.float8_e4m3fn) + y = jnp.ones((1024, 1, 1280, 128), dtype=jnp.bfloat16) + out = jnp.bfloat16(x) * y + return out + +def torch_fusion_test(): + x = torch.ones((1024, 1, 1280, 128), dtype=torch.float8_e4m3fn) + y = torch.ones((1024, 1, 1280, 128), dtype=torch.bfloat16) + out = x.to(torch.bfloat16) * y + return out + +multiply_fusion_test() # warmup run to compile +start_time = time.time_ns() +for i in range(100): + multiply_fusion_test() +end_time = time.time_ns() +print("number of ms taken for jax: ", (end_time - start_time)/(100 * 1000000)) + +assert torch.cuda.is_available() +start_time = time.time_ns() +for i in range(100): + torch_fusion_test() +end_time = time.time_ns() +print("number of ms taken for pytorch: ", (end_time - start_time)/(100 * 1000000)) \ No newline at end of file diff --git a/vm_scipts/tpu_script.sh b/vm_scipts/tpu_script.sh new file mode 100644 index 000000000..c6635652f --- /dev/null +++ b/vm_scipts/tpu_script.sh @@ -0,0 +1,382 @@ +#!/bin/bash + + +#!/bin/bash + +# Multi-Host vlp (TODO: replace these params for your own config) +NAME="jwyang-tpu-sh1" +# NAME="jwyang-v5p8-vm" 
+ACCELERATOR_TYPE="v5litepod-4" +# ACCELERATOR_TYPE="v5litepod-8" +# ACCELERATOR_TYPE="v5p-8" +RUNTIME_VERSION="v2-alpha-tpuv5-lite" +# PROJECT="tpu-prod-env-automated" +PROJECT="cloud-tpu-inference-test" +# PROJECT="tpu-prod-env-small" +# PROJECT="tpu-prod-env-large-cont" +# ZONE="us-east1-c" +ZONE="us-west1-c" +# ZONE="us-east5-a" + +USER=jwyang + +# (TODO: replace these params to your own config) +NUM_WORKERS=1 +TPU_NAME="t1v-n-63d3a09c" + +create_tpu() { + # A temporary solution to clean up the failed and suspended queued resources. + # Otherwise, there will be a quota error. + existing_qr=$(gcloud alpha compute tpus queued-resources list \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --quiet) + while read -r line; do + name=$(echo $line | awk '{print $1}') + status=$(echo $line | awk '{print $5}') + echo ${name} + echo ${status} + if [[ ${status} == "SUSPENDED" || ${status} == "FAILED" ]]; then + gcloud alpha compute tpus queued-resources delete ${name} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --quiet + fi + done <<< ${existing_qr} + + gcloud alpha compute tpus queued-resources create ${NAME} \ + --description noteardown \ + --node-id ${NAME} \ + --project=${PROJECT} \ + --zone=${ZONE} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --runtime-version=${RUNTIME_VERSION} \ + --reserved; +} + +list_tpu() { + gcloud compute tpus tpu-vm list --project=${PROJECT} --zone=${ZONE}; +} + +list_queue_resource() { + gcloud alpha compute tpus queued-resources list --project=${PROJECT} --zone=${ZONE}; +} + +delete_tpu() { + gcloud alpha compute tpus tpu-vm delete ${NAME} --project=${PROJECT} --zone=${ZONE}; + gcloud alpha compute tpus queued-resources delete ${NAME} --project=${PROJECT} --zone=${ZONE}; +} + +ssh_to_tpu() { + gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${1} --project ${PROJECT} -- -o ProxyCommand='corp-ssh-helper %h %p' +} + +create_disk() { + for ((i = 0; i < ${NUM_WORKERS}; i++)); do + TPU_WORKER_NAME=${TPU_NAME}-w-${i} + DISK_NAME=${NAME}-w${i}-ssd + + SIZE=35 + if [[ ${i} == 0 ]] + then + SIZE=512 + fi + + gcloud compute disks create ${DISK_NAME} \ + --size ${SIZE} \ + --zone ${ZONE} \ + --type pd-ssd \ + --project=${PROJECT} + + # attach disk to tpu + gcloud alpha compute instances attach-disk ${TPU_WORKER_NAME} \ + --zone=${ZONE} \ + --disk=${DISK_NAME} \ + --mode=rw \ + --project=${PROJECT} + + gcloud compute instances set-disk-auto-delete ${TPU_WORKER_NAME} \ + --zone=${ZONE} \ + --auto-delete \ + --disk=${DISK_NAME} \ + --project=${PROJECT} + + gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${i} --project=${PROJECT} \ + --command="sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb && + sudo mkdir -p /mnt/disks/persist && + sudo mount -o discard,defaults /dev/sdb /mnt/disks/persist" \ + -- -o ProxyCommand='corp-ssh-helper %h %p' + done +} + +detach_disks() { + for ((i = 0; i < ${NUM_WORKERS}; i++)); do + TPU_WORKER_NAME=${TPU_NAME}-w-${i} + DISK_NAME=${NAME}-w${i}-ssd + + # attach disk to tpu + gcloud alpha compute instances detach-disk ${TPU_WORKER_NAME} \ + --zone=${ZONE} \ + --disk=${DISK_NAME} \ + --project=${PROJECT} + done +} + +check_disks() { + set -o xtrace + dir_checks="" + for ((i = 0; i < ${NUM_WORKERS}; i++)); do + dir_checks="$dir_checks $( + gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${i} --project=${PROJECT} \ + --command="if [ -d /mnt/disks/persist ]; then echo "exists"; fi" \ + -- -o ProxyCommand='corp-ssh-helper %h %p' + )" + done + num_dir_exists=$(echo "$dir_checks" 
| wc -w) + echo "Number of workers with disks: $num_dir_exists" + set +o xtrace +} + + + +############### Scrach ################ +copy_relevant_files() { + # # kill model server process + # gcloud compute tpus tpu-vm ssh ${NAME} --zone=${ZONE} --worker=all --project=${PROJECT} \ + # --command="sudo rm /tmp/libtpu_lockfile && sudo lsof -t /dev/vfio/0 > tpu_process_pid.txt && sudo pkill -F tpu_process_pid.txt" \ + # -- -o ProxyCommand='corp-ssh-helper %h %p' + + + # gcloud compute tpus tpu-vm \ + # scp --zone=${ZONE} --project=${PROJECT} --worker=all \ + # $PWD/MaxText/maxengine.py \ + # ${NAME}:~/maxtext/MaxText/maxengine.py \ + # --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" + + gcloud compute tpus tpu-vm \ + scp --zone=${ZONE} --project=${PROJECT} --worker=all \ + $PWD/benchmarks/benchmark_serving.py \ + ${NAME}:~/JetStream/benchmarks/benchmark_serving.py \ + --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" + + gcloud compute tpus tpu-vm \ + scp --zone=${ZONE} --project=${PROJECT} --worker=all \ + $PWD/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \ + ${NAME}:~/JetStream/benchmarks/ \ + --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" + +} + + +# # Microbenchmark command +# # source .env/bin/activate +# your_run_name=jwyang_bs1_llama7b +# python MaxText/inference_microbenchmark.py \ +# MaxText/configs/base.yml \ +# base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \ +# run_name=${your_run_name} \ +# per_device_batch_size=12 \ +# save_config_to_gcs=true \ +# model_name=llama2-7b \ +# tokenizer_path=assets/tokenizer.llama2 \ +# inference_microbenchmark_prefill_lengths=1024 \ +# max_prefill_predict_length=1024 \ +# max_target_length=2048 \ +# ici_fsdp_parallelism=1 \ +# ici_tensor_parallelism=-1 \ +# ici_autoregressive_parallelism=1 \ +# weight_dtype=bfloat16 \ +# enable_profiler=true \ +# scan_layers=false \ +# quantization=int8 \ +# quantize_kvcache=true +# inference_mode=true + + +# LLaMA2-7B JetStream/Maxtext commands +export model_name=llama2-7b +export tokenizer_path=assets/tokenizer.llama2 +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" +export ici_tensor_parallelism=-1 +export ici_autoregressive_parallelism=1 +export per_device_batch_size=12 +export load_parameters_path_chat=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-18-28/checkpoints/0/items +export load_parameters_path=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-19-40/checkpoints/0/items + +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \ + load_parameters_path=${load_parameters_path_chat} \ + run_name=$(date +%Y-%m-%d-%H-%M) \ + save_config_to_gcs=true \ + model_name=${model_name} \ + tokenizer_path=${tokenizer_path} \ + inference_microbenchmark_log_file_path=microbenchmark.json \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=1000 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + per_device_batch_size=${per_device_batch_size} \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + enable_profiler=false \ + scan_layers=false \ + weight_dtype=bfloat16 + # quantization=int8 + # quantize_kvcache=True + + +export model_name=llama2-7b +export dataset_path=/home/jwyang/llama7b_chat_openorca_input.json +python 
JetStream/benchmarks/benchmark_serving.py \ + --tokenizer ~/maxtext/assets/tokenizer.llama2 \ + --warmup-first true \ + --save-result \ + --save-request-outputs \ + --request-outputs-file-path /home/jwyang/outputs.json \ + --num-prompts 1000 \ + --max-output-length 1024 \ + --dataset openorca \ + --dataset-path ${dataset_path} + + + +# # 13b model +export model_name=llama2-13b +export tokenizer_path=assets/tokenizer.llama2 +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" +export ici_tensor_parallelism=-1 +export ici_autoregressive_parallelism=1 +export per_device_batch_size=1 +export load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items + + +export experiment_time=$(date +%Y-%m-%d-%H-%M) +echo "export experiment_time=${experiment_time}" +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \ + model_name=llama2-13b \ + async_checkpointing=false \ + load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items \ + run_name=${experiment_time} \ + inference_microbenchmark_log_file_path=${run_name}.json \ + tokenizer_path=assets/tokenizer.llama2 \ + weight_dtype=bfloat16 \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=10 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=-1 \ + ici_autoregressive_parallelism=1 \ + enable_profiler=false \ + scan_layers=false \ + attention=dot_product \ + save_config_to_gcs=true \ + per_device_batch_size=1 + + + run_name=$(date +%Y-%m-%d-%H-%M) \ + save_config_to_gcs=true \ + model_name=${model_name} \ + tokenizer_path=${tokenizer_path} \ + inference_microbenchmark_log_file_path=microbenchmark.json \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=1000 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + per_device_batch_size=${per_device_batch_size} \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + enable_profiler=false \ + scan_layers=false \ + weight_dtype=bfloat16 + + +python MaxText/inference_microbenchmark.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \ + model_name=llama2-13b \ + async_checkpointing=false \ + load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items \ + run_name=${experiment_time} \ + inference_microbenchmark_log_file_path=${run_name}.json \ + tokenizer_path=assets/tokenizer.llama2 \ + weight_dtype=bfloat16 \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=10 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=-1 \ + ici_autoregressive_parallelism=1 \ + enable_profiler=false \ + scan_layers=false \ + attention=dot_product \ + save_config_to_gcs=true \ + per_device_batch_size=1 + + + +# # LLaMA2-70B commands +# # source .env/bin/activate +# your_run_name=jwyang_bs1_llama70b +# python MaxText/inference_microbenchmark.py \ +# MaxText/configs/base.yml \ +# 
base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \ +# run_name=${your_run_name} \ +# per_device_batch_size=1 \ +# save_config_to_gcs=true \ +# model_name=llama2-70b \ +# tokenizer_path=assets/tokenizer.llama2 \ +# inference_microbenchmark_prefill_lengths=32 \ +# max_prefill_predict_length=32 \ +# max_target_length=64 \ +# ici_fsdp_parallelism=1 \ +# ici_tensor_parallelism=-1 \ +# ici_autoregressive_parallelism=1 \ +# weight_dtype=bfloat16 \ +# enable_profiler=true \ +# scan_layers=false \ +# quantization=int8 \ +# quantize_kvcache=true + + +export model_name=llama2-70b +export tokenizer_path=assets/tokenizer.llama2 +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" +export ici_tensor_parallelism=-1 +export ici_autoregressive_parallelism=1 +export per_device_batch_size=1 +export prefill_length=16 +export target_length=32 + +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \ + run_name=$(date +%Y-%m-%d-%H-%M) \ + save_config_to_gcs=true \ + model_name=${model_name} \ + tokenizer_path=${tokenizer_path} \ + inference_microbenchmark_log_file_path=microbenchmark.json \ + inference_microbenchmark_prefill_lengths=${prefill_length} \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=1000 \ + max_prefill_predict_length=${prefill_length} \ + max_target_length=${target_length} \ + per_device_batch_size=${per_device_batch_size} \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + enable_profiler=false \ + scan_layers=false \ + weight_dtype=bfloat16 \ + quantization=int8 \ + quantize_kvcache=True \ No newline at end of file diff --git a/vm_scipts/xla_auto_tuning.sh b/vm_scipts/xla_auto_tuning.sh new file mode 100644 index 000000000..085e473c1 --- /dev/null +++ b/vm_scipts/xla_auto_tuning.sh @@ -0,0 +1,54 @@ +for xla_gpu_enable_pipelined_all_gather in true false +do + for xla_gpu_enable_pipelined_reduce_scatter in true false + do + for xla_gpu_enable_pipelined_all_reduce in true false + do + for xla_gpu_enable_while_loop_double_buffering in true false + do + for xla_gpu_enable_triton_softmax_fusion in true false + do + for xla_gpu_enable_all_gather_combine_by_dim in true false + do + for xla_gpu_enable_reduce_scatter_combine_by_dim in true false + do + export XLA_FLAGS="--xla_dump_to=/tmp/HLO_dumps/ --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=true + --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true + --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 + --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=${xla_gpu_enable_pipelined_all_gather} + --xla_gpu_enable_pipelined_reduce_scatter=${xla_gpu_enable_pipelined_reduce_scatter} --xla_gpu_enable_pipelined_all_reduce=${xla_gpu_enable_pipelined_all_reduce} + --xla_gpu_enable_while_loop_double_buffering=${xla_gpu_enable_while_loop_double_buffering} --xla_gpu_enable_triton_softmax_fusion=${xla_gpu_enable_triton_softmax_fusion} + --xla_gpu_enable_all_gather_combine_by_dim=${xla_gpu_enable_all_gather_combine_by_dim} --xla_gpu_enable_reduce_scatter_combine_by_dim=${xla_gpu_enable_reduce_scatter_combine_by_dim} + --xla_disable_hlo_passes=rematerialization" + echo ${XLA_FLAGS} + export TF_FORCE_GPU_ALLOW_GROWTH=true + export 
BASE_OUTPUT_DIRECTORY=gs://jwyang/maxtext + export ASYNC_CHECKPOINTING=false + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 + export PER_DEVICE_BATCH_SIZE=140 + python3 MaxText/inference_microbenchmark.py MaxText/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + model_name='llama2-70b' \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + attention=dot_product \ + scan_layers=false \ + hardware=gpu \ + async_checkpointing=${ASYNC_CHECKPOINTING} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=64 \ + run_name=$(date +%Y-%m-%d-%H-%M) \ + ici_fsdp_parallelism=1 \ + ici_autoregressive_parallelism=-1 \ + ici_tensor_parallelism=1 \ + weight_dtype=bfloat16 \ + quantization=int8 quantize_kvcache=True |& tee ${xla_gpu_enable_pipelined_all_gather}_${xla_gpu_enable_pipelined_reduce_scatter}_${xla_gpu_enable_pipelined_all_reduce}_${xla_gpu_enable_while_loop_double_buffering}_${xla_gpu_enable_triton_softmax_fusion}_${xla_gpu_enable_all_gather_combine_by_dim}_${xla_gpu_enable_reduce_scatter_combine_by_dim}.txt + done + done + done + done + done + done +done \ No newline at end of file From d0c0356d09057792904d5769c97b38fab8b89801 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Mon, 2 Dec 2024 22:03:08 +0000 Subject: [PATCH 11/18] test aqt fp8 integration --- MaxText/layers/quantizations.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index d21228ff5..a5195dc12 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -353,7 +353,9 @@ def einsum_fn_with_rhs_qtensor( self, kv: Array| aqt_tensor.QTensor, rhs_dequant_mode=None, - rhs_calibration_mode=None + rhs_calibration_mode=None, + lhs_dequant_mode=None, + lhs_calibration_mode=None, ): # Assumes kv is already quantized. 
einsum = jnp.einsum @@ -379,6 +381,15 @@ def einsum_fn_with_rhs_qtensor( kv_cfg, rhs_calibration_mode=rhs_calibration_mode, ) + if lhs_dequant_mode: + aqt_config.set_fwd_dequant_mode( + kv_cfg, lhs_dequant_mode=lhs_dequant_mode + ) + if lhs_calibration_mode: + aqt_config.set_fwd_calibration_mode( + kv_cfg, + lhs_calibration_mode=lhs_calibration_mode, + ) einsum = aqt_flax.AqtEinsum( rhs_quant_mode=aqt_flax.QuantMode.TRAIN, lhs_freeze_mode=aqt_flax.FreezerMode.NONE, @@ -391,6 +402,8 @@ def einsum_fn_with_rhs_qtensor( def einsum_fn_with_rhs_qtensor_and_dequant(self, value): return self.einsum_fn_with_rhs_qtensor( value, + lhs_dequant_mode=aqt_config.DequantMode.THIS_INPUT, + lhs_calirbation_mode=aqt_config.CalibrationMode.REMAINING_AXIS, rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT, rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS ) \ No newline at end of file From 8d4f203f0968bf93c0b3785b3c37ade947d7aff4 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Mon, 2 Dec 2024 22:03:56 +0000 Subject: [PATCH 12/18] fix --- MaxText/layers/attentions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 4b202a38c..7fdf52783 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -419,9 +419,9 @@ def qk_product(self, query: Array, key: Array| KVTensor, q_seq_len: int, model_m """ einsum = jnp.einsum if self.kv_quant: - # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) - if isinstance(key, KVTensor): - key = self.kv_quant.dequant(key) + einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) + # if isinstance(key, KVTensor): + # key = self.kv_quant.dequant(key) b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads @@ -461,9 +461,9 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s einsum = jnp.einsum if self.kv_quant: - # einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) - if isinstance(value, KVTensor): - value = self.kv_quant.dequant(value) + einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) + # if isinstance(value, KVTensor): + # value = self.kv_quant.dequant(value) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): out = einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape From f5ff903ec3df23bc37db2032b7eea0805ae2ff2d Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Tue, 21 Jan 2025 19:31:29 +0000 Subject: [PATCH 13/18] fix --- MaxText/layers/quantizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index a5195dc12..2dc22f1e6 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -403,7 +403,7 @@ def einsum_fn_with_rhs_qtensor_and_dequant(self, value): return self.einsum_fn_with_rhs_qtensor( value, lhs_dequant_mode=aqt_config.DequantMode.THIS_INPUT, - lhs_calirbation_mode=aqt_config.CalibrationMode.REMAINING_AXIS, + lhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS, rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT, rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS ) \ No newline at end of file From cbfe5aff8b9eaed41660c6d83d49c6b58265201a Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Wed, 5 Feb 2025 20:01:09 +0000 Subject: [PATCH 14/18] delete unused scripts. 
--- vm_scipts/gpu_script.sh | 15 -- vm_scipts/jax_script.py | 34 ---- vm_scipts/tpu_script.sh | 382 ----------------------------------- vm_scipts/xla_auto_tuning.sh | 54 ----- 4 files changed, 485 deletions(-) delete mode 100644 vm_scipts/gpu_script.sh delete mode 100644 vm_scipts/jax_script.py delete mode 100644 vm_scipts/tpu_script.sh delete mode 100644 vm_scipts/xla_auto_tuning.sh diff --git a/vm_scipts/gpu_script.sh b/vm_scipts/gpu_script.sh deleted file mode 100644 index 8fcd0b5b8..000000000 --- a/vm_scipts/gpu_script.sh +++ /dev/null @@ -1,15 +0,0 @@ -GPU_NAME="vipannalla-mlperf-v41-a3-may" -ZONE="us-central1-a" -PROJECT="cloud-tpu-inference-test" - - -ssh_to_gpu() { - gcloud compute ssh --zone ${ZONE} ${GPU_NAME} --project ${PROJECT} -- -o ProxyCommand='corp-ssh-helper %h %p' -} - -copy_maxtext_files() { - gcloud compute scp --zone ${ZONE} --project ${PROJECT} \ - $PWD/MaxText/profiler.py \ - ${GPU_NAME}:/scratch/jwyang-workspace/maxtext/MaxText/ \ - --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" -} \ No newline at end of file diff --git a/vm_scipts/jax_script.py b/vm_scipts/jax_script.py deleted file mode 100644 index 83efd7738..000000000 --- a/vm_scipts/jax_script.py +++ /dev/null @@ -1,34 +0,0 @@ - - -import jax -import jax.numpy as jnp -import time - -import torch - -@jax.jit -def multiply_fusion_test(): - x = jnp.ones((1024, 1, 1280, 128), dtype=jnp.float8_e4m3fn) - y = jnp.ones((1024, 1, 1280, 128), dtype=jnp.bfloat16) - out = jnp.bfloat16(x) * y - return out - -def torch_fusion_test(): - x = torch.ones((1024, 1, 1280, 128), dtype=torch.float8_e4m3fn) - y = torch.ones((1024, 1, 1280, 128), dtype=torch.bfloat16) - out = x.to(torch.bfloat16) * y - return out - -multiply_fusion_test() # warmup run to compile -start_time = time.time_ns() -for i in range(100): - multiply_fusion_test() -end_time = time.time_ns() -print("number of ms taken for jax: ", (end_time - start_time)/(100 * 1000000)) - -assert torch.cuda.is_available() -start_time = time.time_ns() -for i in range(100): - torch_fusion_test() -end_time = time.time_ns() -print("number of ms taken for pytorch: ", (end_time - start_time)/(100 * 1000000)) \ No newline at end of file diff --git a/vm_scipts/tpu_script.sh b/vm_scipts/tpu_script.sh deleted file mode 100644 index c6635652f..000000000 --- a/vm_scipts/tpu_script.sh +++ /dev/null @@ -1,382 +0,0 @@ -#!/bin/bash - - -#!/bin/bash - -# Multi-Host vlp (TODO: replace these params for your own config) -NAME="jwyang-tpu-sh1" -# NAME="jwyang-v5p8-vm" -ACCELERATOR_TYPE="v5litepod-4" -# ACCELERATOR_TYPE="v5litepod-8" -# ACCELERATOR_TYPE="v5p-8" -RUNTIME_VERSION="v2-alpha-tpuv5-lite" -# PROJECT="tpu-prod-env-automated" -PROJECT="cloud-tpu-inference-test" -# PROJECT="tpu-prod-env-small" -# PROJECT="tpu-prod-env-large-cont" -# ZONE="us-east1-c" -ZONE="us-west1-c" -# ZONE="us-east5-a" - -USER=jwyang - -# (TODO: replace these params to your own config) -NUM_WORKERS=1 -TPU_NAME="t1v-n-63d3a09c" - -create_tpu() { - # A temporary solution to clean up the failed and suspended queued resources. - # Otherwise, there will be a quota error. 
- existing_qr=$(gcloud alpha compute tpus queued-resources list \ - --project ${PROJECT} \ - --zone ${ZONE} \ - --quiet) - while read -r line; do - name=$(echo $line | awk '{print $1}') - status=$(echo $line | awk '{print $5}') - echo ${name} - echo ${status} - if [[ ${status} == "SUSPENDED" || ${status} == "FAILED" ]]; then - gcloud alpha compute tpus queued-resources delete ${name} \ - --project ${PROJECT} \ - --zone ${ZONE} \ - --quiet - fi - done <<< ${existing_qr} - - gcloud alpha compute tpus queued-resources create ${NAME} \ - --description noteardown \ - --node-id ${NAME} \ - --project=${PROJECT} \ - --zone=${ZONE} \ - --accelerator-type=${ACCELERATOR_TYPE} \ - --runtime-version=${RUNTIME_VERSION} \ - --reserved; -} - -list_tpu() { - gcloud compute tpus tpu-vm list --project=${PROJECT} --zone=${ZONE}; -} - -list_queue_resource() { - gcloud alpha compute tpus queued-resources list --project=${PROJECT} --zone=${ZONE}; -} - -delete_tpu() { - gcloud alpha compute tpus tpu-vm delete ${NAME} --project=${PROJECT} --zone=${ZONE}; - gcloud alpha compute tpus queued-resources delete ${NAME} --project=${PROJECT} --zone=${ZONE}; -} - -ssh_to_tpu() { - gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${1} --project ${PROJECT} -- -o ProxyCommand='corp-ssh-helper %h %p' -} - -create_disk() { - for ((i = 0; i < ${NUM_WORKERS}; i++)); do - TPU_WORKER_NAME=${TPU_NAME}-w-${i} - DISK_NAME=${NAME}-w${i}-ssd - - SIZE=35 - if [[ ${i} == 0 ]] - then - SIZE=512 - fi - - gcloud compute disks create ${DISK_NAME} \ - --size ${SIZE} \ - --zone ${ZONE} \ - --type pd-ssd \ - --project=${PROJECT} - - # attach disk to tpu - gcloud alpha compute instances attach-disk ${TPU_WORKER_NAME} \ - --zone=${ZONE} \ - --disk=${DISK_NAME} \ - --mode=rw \ - --project=${PROJECT} - - gcloud compute instances set-disk-auto-delete ${TPU_WORKER_NAME} \ - --zone=${ZONE} \ - --auto-delete \ - --disk=${DISK_NAME} \ - --project=${PROJECT} - - gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${i} --project=${PROJECT} \ - --command="sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb && - sudo mkdir -p /mnt/disks/persist && - sudo mount -o discard,defaults /dev/sdb /mnt/disks/persist" \ - -- -o ProxyCommand='corp-ssh-helper %h %p' - done -} - -detach_disks() { - for ((i = 0; i < ${NUM_WORKERS}; i++)); do - TPU_WORKER_NAME=${TPU_NAME}-w-${i} - DISK_NAME=${NAME}-w${i}-ssd - - # attach disk to tpu - gcloud alpha compute instances detach-disk ${TPU_WORKER_NAME} \ - --zone=${ZONE} \ - --disk=${DISK_NAME} \ - --project=${PROJECT} - done -} - -check_disks() { - set -o xtrace - dir_checks="" - for ((i = 0; i < ${NUM_WORKERS}; i++)); do - dir_checks="$dir_checks $( - gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${i} --project=${PROJECT} \ - --command="if [ -d /mnt/disks/persist ]; then echo "exists"; fi" \ - -- -o ProxyCommand='corp-ssh-helper %h %p' - )" - done - num_dir_exists=$(echo "$dir_checks" | wc -w) - echo "Number of workers with disks: $num_dir_exists" - set +o xtrace -} - - - -############### Scrach ################ -copy_relevant_files() { - # # kill model server process - # gcloud compute tpus tpu-vm ssh ${NAME} --zone=${ZONE} --worker=all --project=${PROJECT} \ - # --command="sudo rm /tmp/libtpu_lockfile && sudo lsof -t /dev/vfio/0 > tpu_process_pid.txt && sudo pkill -F tpu_process_pid.txt" \ - # -- -o ProxyCommand='corp-ssh-helper %h %p' - - - # gcloud compute tpus tpu-vm \ - # scp --zone=${ZONE} --project=${PROJECT} --worker=all \ - # 
$PWD/MaxText/maxengine.py \ - # ${NAME}:~/maxtext/MaxText/maxengine.py \ - # --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" - - gcloud compute tpus tpu-vm \ - scp --zone=${ZONE} --project=${PROJECT} --worker=all \ - $PWD/benchmarks/benchmark_serving.py \ - ${NAME}:~/JetStream/benchmarks/benchmark_serving.py \ - --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" - - gcloud compute tpus tpu-vm \ - scp --zone=${ZONE} --project=${PROJECT} --worker=all \ - $PWD/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \ - ${NAME}:~/JetStream/benchmarks/ \ - --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" - -} - - -# # Microbenchmark command -# # source .env/bin/activate -# your_run_name=jwyang_bs1_llama7b -# python MaxText/inference_microbenchmark.py \ -# MaxText/configs/base.yml \ -# base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \ -# run_name=${your_run_name} \ -# per_device_batch_size=12 \ -# save_config_to_gcs=true \ -# model_name=llama2-7b \ -# tokenizer_path=assets/tokenizer.llama2 \ -# inference_microbenchmark_prefill_lengths=1024 \ -# max_prefill_predict_length=1024 \ -# max_target_length=2048 \ -# ici_fsdp_parallelism=1 \ -# ici_tensor_parallelism=-1 \ -# ici_autoregressive_parallelism=1 \ -# weight_dtype=bfloat16 \ -# enable_profiler=true \ -# scan_layers=false \ -# quantization=int8 \ -# quantize_kvcache=true -# inference_mode=true - - -# LLaMA2-7B JetStream/Maxtext commands -export model_name=llama2-7b -export tokenizer_path=assets/tokenizer.llama2 -export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" -export ici_tensor_parallelism=-1 -export ici_autoregressive_parallelism=1 -export per_device_batch_size=12 -export load_parameters_path_chat=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-18-28/checkpoints/0/items -export load_parameters_path=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-19-40/checkpoints/0/items - -python MaxText/maxengine_server.py \ - MaxText/configs/base.yml \ - base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \ - load_parameters_path=${load_parameters_path_chat} \ - run_name=$(date +%Y-%m-%d-%H-%M) \ - save_config_to_gcs=true \ - model_name=${model_name} \ - tokenizer_path=${tokenizer_path} \ - inference_microbenchmark_log_file_path=microbenchmark.json \ - inference_microbenchmark_prefill_lengths=1024 \ - inference_microbenchmark_stages=prefill,generate \ - inference_microbenchmark_loop_iters=1000 \ - max_prefill_predict_length=1024 \ - max_target_length=2048 \ - per_device_batch_size=${per_device_batch_size} \ - ici_fsdp_parallelism=1 \ - ici_tensor_parallelism=${ici_tensor_parallelism} \ - ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ - enable_profiler=false \ - scan_layers=false \ - weight_dtype=bfloat16 - # quantization=int8 - # quantize_kvcache=True - - -export model_name=llama2-7b -export dataset_path=/home/jwyang/llama7b_chat_openorca_input.json -python JetStream/benchmarks/benchmark_serving.py \ - --tokenizer ~/maxtext/assets/tokenizer.llama2 \ - --warmup-first true \ - --save-result \ - --save-request-outputs \ - --request-outputs-file-path /home/jwyang/outputs.json \ - --num-prompts 1000 \ - --max-output-length 1024 \ - --dataset openorca \ - --dataset-path ${dataset_path} - - - -# # 13b model -export model_name=llama2-13b -export tokenizer_path=assets/tokenizer.llama2 -export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" -export ici_tensor_parallelism=-1 -export ici_autoregressive_parallelism=1 -export 
per_device_batch_size=1 -export load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items - - -export experiment_time=$(date +%Y-%m-%d-%H-%M) -echo "export experiment_time=${experiment_time}" -python MaxText/maxengine_server.py \ - MaxText/configs/base.yml \ - base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \ - model_name=llama2-13b \ - async_checkpointing=false \ - load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items \ - run_name=${experiment_time} \ - inference_microbenchmark_log_file_path=${run_name}.json \ - tokenizer_path=assets/tokenizer.llama2 \ - weight_dtype=bfloat16 \ - inference_microbenchmark_prefill_lengths=1024 \ - inference_microbenchmark_stages=prefill,generate \ - inference_microbenchmark_loop_iters=10 \ - max_prefill_predict_length=1024 \ - max_target_length=2048 \ - ici_fsdp_parallelism=1 \ - ici_tensor_parallelism=-1 \ - ici_autoregressive_parallelism=1 \ - enable_profiler=false \ - scan_layers=false \ - attention=dot_product \ - save_config_to_gcs=true \ - per_device_batch_size=1 - - - run_name=$(date +%Y-%m-%d-%H-%M) \ - save_config_to_gcs=true \ - model_name=${model_name} \ - tokenizer_path=${tokenizer_path} \ - inference_microbenchmark_log_file_path=microbenchmark.json \ - inference_microbenchmark_prefill_lengths=1024 \ - inference_microbenchmark_stages=prefill,generate \ - inference_microbenchmark_loop_iters=1000 \ - max_prefill_predict_length=1024 \ - max_target_length=2048 \ - per_device_batch_size=${per_device_batch_size} \ - ici_fsdp_parallelism=1 \ - ici_tensor_parallelism=${ici_tensor_parallelism} \ - ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ - enable_profiler=false \ - scan_layers=false \ - weight_dtype=bfloat16 - - -python MaxText/inference_microbenchmark.py \ - MaxText/configs/base.yml \ - base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \ - model_name=llama2-13b \ - async_checkpointing=false \ - load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items \ - run_name=${experiment_time} \ - inference_microbenchmark_log_file_path=${run_name}.json \ - tokenizer_path=assets/tokenizer.llama2 \ - weight_dtype=bfloat16 \ - inference_microbenchmark_prefill_lengths=1024 \ - inference_microbenchmark_stages=prefill,generate \ - inference_microbenchmark_loop_iters=10 \ - max_prefill_predict_length=1024 \ - max_target_length=2048 \ - ici_fsdp_parallelism=1 \ - ici_tensor_parallelism=-1 \ - ici_autoregressive_parallelism=1 \ - enable_profiler=false \ - scan_layers=false \ - attention=dot_product \ - save_config_to_gcs=true \ - per_device_batch_size=1 - - - -# # LLaMA2-70B commands -# # source .env/bin/activate -# your_run_name=jwyang_bs1_llama70b -# python MaxText/inference_microbenchmark.py \ -# MaxText/configs/base.yml \ -# base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \ -# run_name=${your_run_name} \ -# per_device_batch_size=1 \ -# save_config_to_gcs=true \ -# model_name=llama2-70b \ -# tokenizer_path=assets/tokenizer.llama2 \ -# inference_microbenchmark_prefill_lengths=32 \ -# max_prefill_predict_length=32 \ -# max_target_length=64 \ -# ici_fsdp_parallelism=1 \ -# ici_tensor_parallelism=-1 \ -# ici_autoregressive_parallelism=1 \ -# weight_dtype=bfloat16 \ -# enable_profiler=true \ -# scan_layers=false \ -# quantization=int8 \ -# quantize_kvcache=true - - -export model_name=llama2-70b -export 
tokenizer_path=assets/tokenizer.llama2 -export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" -export ici_tensor_parallelism=-1 -export ici_autoregressive_parallelism=1 -export per_device_batch_size=1 -export prefill_length=16 -export target_length=32 - -python MaxText/maxengine_server.py \ - MaxText/configs/base.yml \ - base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \ - run_name=$(date +%Y-%m-%d-%H-%M) \ - save_config_to_gcs=true \ - model_name=${model_name} \ - tokenizer_path=${tokenizer_path} \ - inference_microbenchmark_log_file_path=microbenchmark.json \ - inference_microbenchmark_prefill_lengths=${prefill_length} \ - inference_microbenchmark_stages=prefill,generate \ - inference_microbenchmark_loop_iters=1000 \ - max_prefill_predict_length=${prefill_length} \ - max_target_length=${target_length} \ - per_device_batch_size=${per_device_batch_size} \ - ici_fsdp_parallelism=1 \ - ici_tensor_parallelism=${ici_tensor_parallelism} \ - ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ - enable_profiler=false \ - scan_layers=false \ - weight_dtype=bfloat16 \ - quantization=int8 \ - quantize_kvcache=True \ No newline at end of file diff --git a/vm_scipts/xla_auto_tuning.sh b/vm_scipts/xla_auto_tuning.sh deleted file mode 100644 index 085e473c1..000000000 --- a/vm_scipts/xla_auto_tuning.sh +++ /dev/null @@ -1,54 +0,0 @@ -for xla_gpu_enable_pipelined_all_gather in true false -do - for xla_gpu_enable_pipelined_reduce_scatter in true false - do - for xla_gpu_enable_pipelined_all_reduce in true false - do - for xla_gpu_enable_while_loop_double_buffering in true false - do - for xla_gpu_enable_triton_softmax_fusion in true false - do - for xla_gpu_enable_all_gather_combine_by_dim in true false - do - for xla_gpu_enable_reduce_scatter_combine_by_dim in true false - do - export XLA_FLAGS="--xla_dump_to=/tmp/HLO_dumps/ --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=true - --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true - --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 - --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=${xla_gpu_enable_pipelined_all_gather} - --xla_gpu_enable_pipelined_reduce_scatter=${xla_gpu_enable_pipelined_reduce_scatter} --xla_gpu_enable_pipelined_all_reduce=${xla_gpu_enable_pipelined_all_reduce} - --xla_gpu_enable_while_loop_double_buffering=${xla_gpu_enable_while_loop_double_buffering} --xla_gpu_enable_triton_softmax_fusion=${xla_gpu_enable_triton_softmax_fusion} - --xla_gpu_enable_all_gather_combine_by_dim=${xla_gpu_enable_all_gather_combine_by_dim} --xla_gpu_enable_reduce_scatter_combine_by_dim=${xla_gpu_enable_reduce_scatter_combine_by_dim} - --xla_disable_hlo_passes=rematerialization" - echo ${XLA_FLAGS} - export TF_FORCE_GPU_ALLOW_GROWTH=true - export BASE_OUTPUT_DIRECTORY=gs://jwyang/maxtext - export ASYNC_CHECKPOINTING=false - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 - export PER_DEVICE_BATCH_SIZE=140 - python3 MaxText/inference_microbenchmark.py MaxText/configs/base.yml \ - base_output_directory=${BASE_OUTPUT_DIRECTORY} \ - model_name='llama2-70b' \ - max_prefill_predict_length=1024 \ - max_target_length=2048 \ - attention=dot_product \ - scan_layers=false \ - hardware=gpu \ - async_checkpointing=${ASYNC_CHECKPOINTING} \ - per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ - inference_microbenchmark_prefill_lengths=1024 \ - 
inference_microbenchmark_stages=prefill,generate \ - inference_microbenchmark_loop_iters=64 \ - run_name=$(date +%Y-%m-%d-%H-%M) \ - ici_fsdp_parallelism=1 \ - ici_autoregressive_parallelism=-1 \ - ici_tensor_parallelism=1 \ - weight_dtype=bfloat16 \ - quantization=int8 quantize_kvcache=True |& tee ${xla_gpu_enable_pipelined_all_gather}_${xla_gpu_enable_pipelined_reduce_scatter}_${xla_gpu_enable_pipelined_all_reduce}_${xla_gpu_enable_while_loop_double_buffering}_${xla_gpu_enable_triton_softmax_fusion}_${xla_gpu_enable_all_gather_combine_by_dim}_${xla_gpu_enable_reduce_scatter_combine_by_dim}.txt - done - done - done - done - done - done -done \ No newline at end of file From de43f817b645ae7e1c7c9732d164659caa68e8ef Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Wed, 5 Feb 2025 20:50:49 +0000 Subject: [PATCH 15/18] fix --- MaxText/layers/quantizations.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index acc83e364..104f0407b 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -270,10 +270,6 @@ def _get_quant_config(config): return _get_int8_quant_config(config) if config.quantization == "aqt_fp8": return _get_aqt_fp8_quant_config(config) - if config.quantization == "int8w": - return _get_weight_only_quant_config(lhs_bits=None, rhs_bits=8) - if config.quantization == "int4w": - return _get_weight_only_quant_config(lhs_bits=None, rhs_bits=4) if config.quantization == "intmp": assert config.quant_cfg_path, "Must specify quant_cfg for mixed precision quantization" with open(config.quant_cfg_path, "r") as config_file: From 738d93cec083c562a597fedc5f29a415a58bfcd4 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Wed, 5 Feb 2025 20:53:01 +0000 Subject: [PATCH 16/18] fix --- MaxText/profiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MaxText/profiler.py b/MaxText/profiler.py index 74cc2b041..10a115025 100644 --- a/MaxText/profiler.py +++ b/MaxText/profiler.py @@ -57,8 +57,9 @@ def activate(self, blocking_object=None, optional_postfix=""): return self.libcudart.cudaProfilerStart() elif self.mode == "xplane": + # when running on GPU, seems the gs bucket upload fails if self.output_path.startswith("gs://"): - self.output_path = "/scratch/jwyang-workspace/jax-profiles/" + self.output_path[4:] + self.output_path = "/tmp/jax-profiles/" + self.output_path[4:] print("Start profiling {}".format(self.output_path)) jax.profiler.start_trace(self.output_path) From 26427084f8b5386ca1db2320f13f4f9d4e475ed5 Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Wed, 5 Feb 2025 20:55:21 +0000 Subject: [PATCH 17/18] fix --- MaxText/maxengine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index c0c5835e1..9af3ffd10 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -815,7 +815,8 @@ def set_engine_vars_from_base_engine( """Set internal vars from base_engine, which has already loaded the checkpoint and has sharding, mesh, and kv cache related vars set. 
""" - engine.model.quant.quant_mode = base_engine.model.quant.quant_mode + if base_engine.model.quant is not None: + engine.model.quant.quant_mode = base_engine.model.quant.quant_mode engine.state_mesh_annotations = base_engine.state_mesh_annotations engine.abstract_params = base_engine.abstract_params engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access From deef5f2a99f3f9705431159ec3758981c282cb9d Mon Sep 17 00:00:00 2001 From: jwyang-google Date: Fri, 7 Feb 2025 21:59:23 +0000 Subject: [PATCH 18/18] fp8 inference fix --- MaxText/layers/attentions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 50e6e0a3c..0b5b0bd30 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -520,8 +520,6 @@ def qk_product(self, query: Array, key: Array | KVTensor, q_seq_len: int, model_ einsum = jnp.einsum if self.kv_quant: einsum = self.kv_quant.einsum_fn_with_rhs_qtensor(key) - # if isinstance(key, KVTensor): - # key = self.kv_quant.dequant(key) b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads @@ -561,6 +559,8 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s einsum = jnp.einsum if self.kv_quant: + if isinstance(value, KVTensor) and self.kv_quant.dtype == jnp.float8_e4m3fn: + value.qvalue = value.qvalue.astype(jnp.bfloat16) einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3): out = einsum("bkgts,bskd->btkgd", attn_weights, value)