diff --git a/pyproject.toml b/pyproject.toml
index 667072cb..f9a4bc34 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
   "tensorboard-plugin-profile==2.18.0",
   "tf_keras==2.18.0",
   "protobuf==4.25.5",
+  "fire",
 ]
 
 [project.optional-dependencies]
diff --git a/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json b/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json
index 48b5c719..38c76296 100644
--- a/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json
+++ b/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json
@@ -18,5 +18,5 @@
   "qk_nope_head_dim": 128,
   "qk_rope_head_dim": 64,
   "v_head_dim": 128,
-  "dtype": "fp8"
+  "dtype": "bfloat16"
 }
\ No newline at end of file
diff --git a/torchprime/experimental/torchax_models/deepseek_v3/model.py b/torchprime/experimental/torchax_models/deepseek_v3/model.py
index 57248bd4..eb0b264b 100644
--- a/torchprime/experimental/torchax_models/deepseek_v3/model.py
+++ b/torchprime/experimental/torchax_models/deepseek_v3/model.py
@@ -2,8 +2,8 @@
 from dataclasses import dataclass
 from typing import Literal
 
+import jax
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
 
@@ -382,7 +382,7 @@ def __init__(self, args: ModelArgs):
   def forward(
     self,
     x: torch.Tensor,
-    start_pos: int,
+    input_pos: torch.Tensor,
     freqs_cis: torch.Tensor,
     mask: torch.Tensor | None,
   ):
@@ -399,7 +399,6 @@ def forward(
       torch.Tensor: Output tensor with the same shape as the input.
     """
     bsz, seqlen, _ = x.size()
-    end_pos = start_pos + seqlen
     q = self.wq(x) if self.q_lora_rank == 0 else self.wq_b(self.q_norm(self.wq_a(x)))
     q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
     q_nope, q_pe = torch.split(
@@ -417,12 +416,9 @@ def forward(
       )
       k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
       k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-      self.k_cache[:bsz, start_pos:end_pos] = k
-      self.v_cache[:bsz, start_pos:end_pos] = v
-      scores = (
-        torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
-        * self.softmax_scale
-      )
+      # self.k_cache[:bsz, start_pos:end_pos] = k
+      # self.v_cache[:bsz, start_pos:end_pos] = v
+      scores = torch.einsum("bshd,bthd->bsht", q, k) * self.softmax_scale
     else:
       wkv_b = (
         self.wkv_b.weight
@@ -431,19 +427,22 @@ def forward(
       )
       wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
       q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim])
-      self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-      self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
+      # self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
+      # self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
+      kv_cache = self.kv_norm(kv)
+      pe_cache = k_pe.squeeze(2)
       scores = (
-        torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
-        + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+        torch.einsum("bshc,btc->bsht", q_nope, kv_cache)
+        + torch.einsum("bshr,btr->bsht", q_pe, pe_cache)
       ) * self.softmax_scale
     if mask is not None:
       scores += mask.unsqueeze(1)
     scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
     if attn_impl == "naive":
-      x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+      x = torch.einsum("bsht,bthd->bshd", scores, v)
     else:
-      x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
+      kv_cache = self.kv_norm(kv)
+      x = torch.einsum("bsht,btc->bshc", scores, kv_cache)
     x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
     x = self.wo(x.flatten(2))
     return x
@@ -544,6 +543,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
       else:
         group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
       indices = group_scores.topk(self.topk_groups, dim=-1)[1]
+      print('i am here')
       mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
       scores = (scores * mask.unsqueeze(-1)).flatten(1)
     indices = torch.topk(scores, self.topk, dim=-1)[1]
@@ -590,71 +590,79 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
-class MoE(nn.Module):
-  """
-  Mixture-of-Experts (MoE) module.
-
-  Attributes:
-    dim (int): Dimensionality of input features.
-    n_routed_experts (int): Total number of experts in the model.
-    n_local_experts (int): Number of experts handled locally in distributed systems.
-    n_activated_experts (int): Number of experts activated for each input.
-    gate (nn.Module): Gating mechanism to route inputs to experts.
-    experts (nn.ModuleList): List of expert modules.
-    shared_experts (nn.Module): Shared experts applied to all inputs.
-  """
-
-  def __init__(self, args: ModelArgs):
-    """
-    Initializes the MoE module.
-
-    Args:
-      args (ModelArgs): Model arguments containing MoE parameters.
-    """
+class ConditionalFeedForward(torch.nn.Module):
+  def __init__(self, config):
     super().__init__()
-    self.dim = args.dim
-    assert args.n_routed_experts % world_size == 0
-    self.n_routed_experts = args.n_routed_experts
-    self.n_local_experts = args.n_routed_experts // world_size
-    self.n_activated_experts = args.n_activated_experts
-    self.experts_start_idx = rank * self.n_local_experts
-    self.experts_end_idx = self.experts_start_idx + self.n_local_experts
-    self.gate = Gate(args)
-    self.experts = nn.ModuleList(
-      [
-        Expert(args.dim, args.moe_inter_dim)
-        if self.experts_start_idx <= i < self.experts_end_idx
-        else None
-        for i in range(self.n_routed_experts)
-      ]
+    # TODO(How to enable quantization?)
+    self.w1 = nn.Parameter(
+      torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim)
+    )
+    self.w2 = nn.Parameter(
+      torch.empty(config.n_routed_experts, config.dim, config.moe_inter_dim)
+    )
+    self.w3 = nn.Parameter(
+      torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim)
     )
-    self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
+    self.config = config
+
+  def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
+    return self.forward_for_long_seq_len(x, expert_indices)
+
+  def forward_for_long_seq_len(self, x, expert_indices):
+    seqlen = x.shape[0]
+    self.w1.shape[0]
+
+    # e = total num of exp = 8
+    # t = seqlen
+    # o = config.moe_inter_dim (intermediate size)
+    # i = config.dim
+    with jax.named_scope("conditional_ff"):
+      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1))
+      x3 = torch.einsum("ti, eoi-> teo", x, self.w3)
+      expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2)
+      # e = 8; need to reduce to 2
+      seq_indexes = torch.arange(seqlen, device=x.device).unsqueeze(1)
+      return expert_outs[seq_indexes, expert_indices]
+
+
+class MoE(torch.nn.Module):
+  def __init__(self, model_args) -> None:
+    super().__init__()
+    self.dim = model_args.dim
+    self.model_args = model_args
+    # assert args.n_routed_experts % world_size == 0
+    # self.n_routed_experts = args.n_routed_experts
+    # self.n_local_experts = args.n_routed_experts // world_size
+    # self.n_activated_experts = args.n_activated_experts
+    # self.experts_start_idx = rank * self.n_local_experts
+    # self.experts_end_idx = self.experts_start_idx + self.n_local_experts
+    self.gate = Gate(model_args)
+    # self.experts = nn.ModuleList(
+    #   [
+    #     Expert(args.dim, args.moe_inter_dim)
+    #     if self.experts_start_idx <= i < self.experts_end_idx
+    #     else None
+    #     for i in range(self.n_routed_experts)
+    #   ]
+    # )
+    self.shared_experts = MLP(
+      model_args.dim, model_args.n_shared_experts * model_args.moe_inter_dim
+    )
+    self.cond_ffn = ConditionalFeedForward(model_args)
 
   def forward(self, x: torch.Tensor) -> torch.Tensor:
-    """
-    Forward pass for the MoE module.
-
-    Args:
-      x (torch.Tensor): Input tensor.
-
-    Returns:
-      torch.Tensor: Output tensor after expert routing and computation.
-    """
-    shape = x.size()
+    bsz, seq, hidden = x.shape
+    # [B, T, D], combine BT, for prefill B = 1, for decode, T = 1
     x = x.view(-1, self.dim)
+    # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
+    # x: [T, D]
+    self.gate(x)  # [T, E]
     weights, indices = self.gate(x)
-    y = torch.zeros_like(x)
-    counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
-    for i in range(self.experts_start_idx, self.experts_end_idx):
-      if counts[i] == 0:
-        continue
-      expert = self.experts[i]
-      idx, top = torch.where(indices == i)
-      y[idx] += expert(x[idx]) * weights[idx, top, None]
-    z = self.shared_experts(x)
-    if world_size > 1:
-      dist.all_reduce(y)
-    return (y + z).view(shape)
+    expert_outs = self.cond_ffn(x, indices)
+    expert_outs = torch.einsum("tai,ta -> ti", expert_outs, weights)
+    # Changes back to [B, T, D]
+    expert_outs = expert_outs.reshape(bsz, seq, hidden)
+    return expert_outs
 
 
 class Block(nn.Module):
@@ -687,7 +695,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
   def forward(
     self,
     x: torch.Tensor,
-    start_pos: int,
+    input_pos: torch.Tensor,
     freqs_cis: torch.Tensor,
     mask: torch.Tensor | None,
   ) -> torch.Tensor:
@@ -703,7 +711,7 @@ def forward(
 
     Returns:
       torch.Tensor: Output tensor after block computation.
""" - x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask) + x = x + self.attn(self.attn_norm(x), input_pos, freqs_cis, mask) x = x + self.ffn(self.ffn_norm(x)) return x @@ -728,9 +736,6 @@ def __init__(self, args: ModelArgs): Args: args (ModelArgs): Model arguments containing transformer parameters. """ - global world_size, rank - world_size = dist.get_world_size() if dist.is_initialized() else 1 - rank = dist.get_rank() if dist.is_initialized() else 0 Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 super().__init__() self.max_seq_len = args.max_seq_len @@ -742,10 +747,10 @@ def __init__(self, args: ModelArgs): self.head = ColumnParallelLinear( args.dim, args.vocab_size, dtype=torch.get_default_dtype() ) - self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) + self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=True) @torch.inference_mode() - def forward(self, tokens: torch.Tensor, start_pos: int = 0): + def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): """ Forward pass for the Transformer model. @@ -756,18 +761,14 @@ def forward(self, tokens: torch.Tensor, start_pos: int = 0): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ - seqlen = tokens.size(1) + tokens.size(1) h = self.embed(tokens) - freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + freqs_cis = self.freqs_cis[input_pos] mask = None - if seqlen > 1: - mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) + # if seqlen > 1: + # mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) + h = layer(h, input_pos, freqs_cis, mask) h = self.norm(h)[:, -1] logits = self.head(h) - if world_size > 1: - all_logits = [torch.empty_like(logits) for _ in range(world_size)] - dist.all_gather(all_logits, logits) - logits = torch.cat(all_logits, dim=-1) return logits diff --git a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py index 110c69a1..690c5fba 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py @@ -1,41 +1,195 @@ import functools +import json import time import jax +import jax.numpy as jnp +import model as ds_model import torch import torchax import torchax.interop +import torchax.ops.mappings as tx_mappings +from jax.experimental.mesh_utils import create_device_mesh +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from model import ModelArgs, Transformer +from torchax import interop from torchax.interop import JittableModule -from .model import ( - ModelArgs, - Transformer, -) - -def single_device_compile(): - print("======= single_device_compile =======") +def _process_sharding_name(name): + """Replace integers in param name with *. + + Presumably all layers should have the same sharding. 
+ """ + + def is_integer(t): + try: + int(t) + return True + # pylint: disable-next=all + except: # noqa: E722 + return False + + tokens = name.split(".") + for i, t in enumerate(tokens): + if is_integer(t): + tokens[i] = "*" + return ".".join(tokens) + + +def _get_sharding_sepc(sharding_map, name): + sharding_spec = sharding_map.get(name) + if sharding_spec is not None: + return sharding_spec + sharding_spec = sharding_map.get(_process_sharding_name(name)) + return sharding_spec + + +def make_weight_shard(weight_meta, slice_index): + weight_shard_meta = weight_meta[slice_index] + with torchax.default_env(): + return interop.jax_view( + torch.randn(weight_shard_meta.shape, dtype=weight_shard_meta.dtype) + ) + + +def make_cache_shard(weight_meta, slice_index): + weight_shard_meta = weight_meta[slice_index] + return jnp.zeros( + weight_shard_meta.shape, dtype=tx_mappings.t2j_dtype(weight_shard_meta.dtype) + ) + + +def create_sharded_weights(model, mesh, sharding_map, env): + res = {} + for name, weight_meta in model.state_dict().items(): + sharding_spec = _get_sharding_sepc(sharding_map, name) + if sharding_spec is None: + print("Skipping weight:", name, weight_meta.shape) + continue + sharding = NamedSharding(mesh, P(*sharding_spec)) + res[name] = env.j2t_iso( + jax.make_array_from_callback( + weight_meta.shape, sharding, functools.partial(make_weight_shard, weight_meta) + ) + ) + return res + + +def create_sharded_kv_cache(cache_dict, mesh, env): + res = {} + # shard at num device + sharding = NamedSharding(mesh, P(None, None, name0, None)) + for name, weight_meta in cache_dict.items(): + if name.endswith("_cache"): + res[name] = env.j2t_iso( + jax.make_array_from_callback( + weight_meta.shape, sharding, functools.partial(make_cache_shard, weight_meta) + ) + ) + return res + + +name0 = "tp0" +# name1 = "tp1" +sharding_map_1d_tp = { + "embed.weight": (name0, None), + "layers.*.attn.wq.weight": (None, name0), + "layers.*.attn.wq.bias": (name0,), + "layers.*.attn.wkv_a.weight": (None, None), + "layers.*.attn.kv_norm.weight": (name0,), + "layers.*.attn.wkv_b.weight": (name0, None), + "layers.*.attn.wkv_b.bias": (name0,), + "layers.*.attn.wo.weight": (name0, None), + "layers.*.attn.wo.bias": (name0, None), + + "layers.*.attn.wq_a.weight": (None, None), + "layers.*.attn.q_norm.weight": (), + "layers.*.attn.wq_b.weight": (name0, None), + "layers.*.attn.wq_b.bias": (name0,), + + "layers.*.ffn.w1.weight": (name0, None), + "layers.*.ffn.w1.bias": (name0,), + "layers.*.ffn.w2.weight": (None, name0), + "layers.*.ffn.w2.bias": (name0,), + "layers.*.ffn.w3.weight": (name0, None), + "layers.*.ffn.w3.bias": (name0,), + "layers.*.ffn.cond_ffn.w1": (None, name0, None), + "layers.*.ffn.cond_ffn.w2": (None, None, name0), + "layers.*.ffn.cond_ffn.w3": (None, name0, None), + "layers.*.ffn.gate.weight": (None, name0), + "layers.*.ffn.gate.bias": (name0,), + "layers.*.ffn.shared_experts.w1.weight": (name0, None), + "layers.*.ffn.shared_experts.w1.bias": (name0,), + "layers.*.ffn.shared_experts.w2.weight": (None, name0), + "layers.*.ffn.shared_experts.w2.bias": (name0,), + "layers.*.ffn.shared_experts.w3.weight": (name0, None), + "layers.*.ffn.shared_experts.w3.bias": (name0,), + "layers.*.attn_norm.weight": (name0,), + "layers.*.ffn_norm.weight": (name0,), + "norm.weight": (name0,), + "head.weight": (name0, None), + "head.bias": (name0,), + "freqs_cis": (), +} + + +def _replicate(x, env, mesh): + with jax.default_device(jax.devices("cpu")[0]): + xj = env.to_xla(x).jax() + xj = env.j2t_iso( + 
jax.make_array_from_callback(xj.shape, NamedSharding(mesh, P()), lambda a: xj) + ) + return xj + + +def main(config=None, seqlen=2048, batch_size=1): + config_dict = None + if config is not None: + with open(config) as f: + config_dict = json.load(f) + + print("======= multi_device =======") torch.set_default_dtype(torch.bfloat16) env = torchax.default_env() + config_dict = config_dict or {} + + env.config.use_torch_native_for_cpu_tensor = False + torch.manual_seed(42) torchax.enable_performance_mode() + torchax.enable_globally() + torchax.default_env().config.debug_print_each_op = True + args = ModelArgs(**config_dict) + args.max_batch_size = 1 - args = ModelArgs() + dev_array = create_device_mesh((len(jax.devices()),), allow_split_physical_axes=True) + mesh = Mesh(dev_array, (name0,)) - with torch.no_grad(), env: - x = torch.randint(0, args.vocab_size, (1, 2048)) - x = x.to("jax") + torch.set_default_device("meta") + with env, torch.device("meta"): model = Transformer(args) - model.to("jax") - model.embed = JittableModule(model.embed) - # for i in range(len(model.layers)): - # model.layers[i] = JittableModule(model.layers[i]) - model.norm = JittableModule(model.norm) - model.head = JittableModule(model.head) + jitted = JittableModule(model) + freqs_cis = ds_model.precompute_freqs_cis(args) + freqs_cis = _replicate(freqs_cis, env, mesh) + jitted.buffers["freqs_cis"] = freqs_cis + + print(model) + caches_dict = create_sharded_kv_cache(jitted.buffers, mesh, env) + sharded_weights = create_sharded_weights(model, mesh, sharding_map_1d_tp, env) + + jitted.params = sharded_weights + jitted.buffers.update(caches_dict) + + with mesh: + x = torch.ones((1, 2048), dtype=torch.int32) + x = _replicate(x, env, mesh) + input_pos = torch.arange(2048, device="jax") for i in range(5): step_start = time.perf_counter() - logits = model(x, 0) + logits = jitted(x, input_pos) jax.block_until_ready(torchax.tensor.t2j(logits)) step_end = time.perf_counter() print( @@ -44,46 +198,21 @@ def single_device_compile(): step_end - step_start, ) - -def single_device_eager(): - print("======= single_device_eager =======") - torch.set_default_dtype(torch.bfloat16) - env = torchax.default_env() - torch.manual_seed(42) - torchax.enable_performance_mode() - - args = ModelArgs() - - with torch.no_grad(), env: - x = torch.randint(0, args.vocab_size, (1, 2048)) - x = x.to("jax") - model = Transformer(args) - model.to("jax") - weights = model.state_dict() - model_forward = functools.partial(torch.func.functional_call, model) - # model_forward = torchax.interop.jax_jit(model_forward) - + x = torch.ones((1, 1), dtype=torch.int32) + x = _replicate(x, env, mesh) + input_pos = torch.arange(2048, 2049, device="jax") for i in range(5): step_start = time.perf_counter() - logits = model_forward(weights, (x, 0)) + logits = jitted(x, input_pos) jax.block_until_ready(torchax.tensor.t2j(logits)) step_end = time.perf_counter() print( i, - "step latency: ", + "decode step latency: ", step_end - step_start, ) -def main(option="single_device_eager"): - if option == "single_device_eager": - single_device_eager() - elif option == "single_device_compile": - single_device_compile() - else: - raise Exception("Invalid option") - - if __name__ == "__main__": import fire diff --git a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py index ac53ef8d..abaebc51 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py +++ 
b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py @@ -1,23 +1,22 @@ import pytest -# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/75): Fix the failure on torch 2.6, -# then enable the test unconditionally. @pytest.mark.deepseek -def test_single_device_compile(): - from torchprime.experimental.torchax_models.deepseek_v3.prefill_benchmark import ( - single_device_compile, - ) +def test_moe_can_jit(): + import torch + import torchax + import torchax.interop - single_device_compile() + from torchprime.experimental.torchax_models.deepseek_v3 import model as ds_model + torchax.enable_globally() + torch.manual_seed(42) + max_seq_len = 512 # 8192 + with torch.no_grad(): + x = torch.ones((1, max_seq_len, 2048), dtype=torch.float32, device="jax") + model_args = ds_model.ModelArgs() + model = ds_model.MoE(model_args).to("jax") -# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/75): Fix the failure on torch 2.6, -# then enable the test unconditionally. -@pytest.mark.deepseek -def test_single_device_eager(): - from torchprime.experimental.torchax_models.deepseek_v3.prefill_benchmark import ( - single_device_eager, - ) - - single_device_eager() + jitted = torchax.interop.JittableModule(model) + print(jitted(x)) + torchax.disable_globally() diff --git a/torchprime/launcher/Dockerfile b/torchprime/launcher/Dockerfile index f4678e7c..49a1b6a2 100644 --- a/torchprime/launcher/Dockerfile +++ b/torchprime/launcher/Dockerfile @@ -26,6 +26,7 @@ WORKDIR /workspaces # Install torchax RUN git clone https://github.com/pytorch/xla.git WORKDIR /workspaces/xla/torchax +RUN git checkout hanq_torchax1 RUN pip install torch_xla[pallas] \ -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html @@ -42,5 +43,7 @@ RUN pip install -e . COPY . /workspaces/torchprime # This should not install any packages. Only symlink the source code. RUN pip install --no-deps -e . +RUN pip install --force-reinstall --upgrade torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu +RUN pip uninstall torchvision -y ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" diff --git a/torchprime/launcher/buildpush.py b/torchprime/launcher/buildpush.py index 51746eca..f50dbff6 100755 --- a/torchprime/launcher/buildpush.py +++ b/torchprime/launcher/buildpush.py @@ -47,7 +47,7 @@ def buildpush( # Build, tag, and push Docker image try: _run( - f"{sudo_cmd} docker build --network=host --progress=auto -t {docker_tag} {context_dir} -f {docker_file}", + f"{sudo_cmd} docker build --no-cache --network=host --progress=auto -t {docker_tag} {context_dir} -f {docker_file}", ) _run( f"{sudo_cmd} docker tag {docker_tag} {docker_url}",