Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cross entropy loss for xla autocast with FP32 precision #7992

Merged
merged 3 commits into from
Sep 26, 2024

Conversation

avizon-aws
Copy link
Collaborator

There are many operators in XLA autocast that have been commented, but these operators are casted in the GPU, in order to maintain consistency, we need to support these operators as well. For cross_entropy_loss, it is currently commented in the xla autocast, so there will be no casting occuring, and it will execute based on its input’s dtype.

The output type is bf16, which is expected because linear layer is specified in xla autocast. loss dtype is fp32, which is correct, but there’s a catch, there was no autocasting done for the crossEntropyLoss, the reason the dtype is FP32 is because of the target’s dtype, which is FP32. There is a multiplication which happens in crossentropyloss between the generated output and the target, all the exponentiation/log etc. is done in BF16, but only because of the final multiplication, we get the result in FP32, because it casts to the higher precision (FP32). This is not the expected behavior, all the exponentiation/logs i.e. all ops related to crossentropyloss should be executed in FP32, the reason it is not happening is because crossentropyloss is not specified in xla autocast. This finding is based after detailed analysis of the HLO outputs which is attached below.

Before this change:
Exp1.

device = 'xla' # Get the XLA device (e.g., TPU or GPU)
model = torch.nn.Linear(10, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = torch.randn(16, 10).to(torch.bfloat16).(to(device)
target = torch.randn(16, 10).to(device)
print(device, torch.__version__)
for epoch in range(1):
    optimizer.zero_grad()
    # debugpy.breakpoint() 
    with torch.autocast('xla'):
        output = model(data)
        loss = torch.nn.CrossEntropyLoss()(output, target)
        print(output.dtype, loss.dtype, target.dtype)
    # loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

HLO:

ENTRY %SyncTensorsGraph.62 (p0.1: f32[], p1.2: f32[16,10], p2.3: f32[10], p3.12: f32[10,10], p4.22: f32[16,10]) -> (f32[]) {
  %p4.22 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.23 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.22), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.12 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.2 = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %p3.12), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.20 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %custom-call.2), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.21 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.20), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.24 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.23, bf16[10,10]{0,1} %transpose.21), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.3 = f32[10]{0} custom-call(f32[10]{0} %p2.3), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.11 = bf16[10]{0} convert(f32[10]{0} %custom-call.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.28 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.11), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.29 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.24, bf16[16,10]{1,0} %broadcast.28), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  **%constant.32 = bf16[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.37 = bf16[16]{0} reduce(bf16[16,10]{1,0} %add.29, bf16[] %constant.32), dimensions={1}, to_apply=%MaxComputation.33, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.38 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %reduce.37), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.39 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %add.29, bf16[16,10]{1,0} %broadcast.38), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.40 = bf16[16,10]{1,0} exponential(bf16[16,10]{1,0} %subtract.39), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.41 = bf16[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.46 = bf16[16]{0} reduce(bf16[16,10]{1,0} %exponential.40, bf16[] %constant.41), dimensions={1}, to_apply=%AddComputation.42, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.47 = bf16[16]{0} log(bf16[16]{0} %reduce.46), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.48 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %log.47), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.49 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %subtract.39, bf16[16,10]{1,0} %broadcast.48), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %convert.50 = f32[16,10]{1,0} convert(bf16[16,10]{1,0} %subtract.49), metadata={op_type="aten__mul" op_name="aten__mul"}
  %p1.2 = f32[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.51 = f32[16,10]{1,0} multiply(f32[16,10]{1,0} %convert.50, f32[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}**
  %constant.52 = f32[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum"}
  %reduce.58 = f32[] reduce(f32[16,10]{1,0} %multiply.51, f32[] %constant.52), dimensions={0,1}, to_apply=%AddComputation.54, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.59 = f32[] negate(f32[] %reduce.58), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.60 = f32[] divide(f32[] %negate.59, f32[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.61 = (f32[]) tuple(f32[] %divide.60), frontend_attributes={neff_output_names="output0"}
}

Exp2.
The target dtype if also bf16 in this case. This experiment was done to prove that the dtype of the target was the true cause of the FP32 output as shown below.

Code change from previous experiment:
target = torch.randn(16, 10).to(torch.bfloat16).to(device)

HLO

ENTRY %SyncTensorsGraph.61 (p0.1: bf16[], p1.2: bf16[16,10], p2.3: f32[10], p3.12: f32[10,10], p4.22: f32[16,10]) -> (bf16[]) {
  %p4.22 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.23 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.22), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.12 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.2 = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %p3.12), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.20 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %custom-call.2), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.21 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.20), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.24 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.23, bf16[10,10]{0,1} %transpose.21), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.3 = f32[10]{0} custom-call(f32[10]{0} %p2.3), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.11 = bf16[10]{0} convert(f32[10]{0} %custom-call.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.28 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.11), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.29 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.24, bf16[16,10]{1,0} %broadcast.28), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %constant.32 = bf16[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.37 = bf16[16]{0} reduce(bf16[16,10]{1,0} %add.29, bf16[] %constant.32), dimensions={1}, to_apply=%MaxComputation.33, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.38 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %reduce.37), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.39 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %add.29, bf16[16,10]{1,0} %broadcast.38), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.40 = bf16[16,10]{1,0} exponential(bf16[16,10]{1,0} %subtract.39), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.41 = bf16[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.46 = bf16[16]{0} reduce(bf16[16,10]{1,0} %exponential.40, bf16[] %constant.41), dimensions={1}, to_apply=%AddComputation.42, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.47 = bf16[16]{0} log(bf16[16]{0} %reduce.46), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.48 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %log.47), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.49 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %subtract.39, bf16[16,10]{1,0} %broadcast.48), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %p1.2 = bf16[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.50 = bf16[16,10]{1,0} multiply(bf16[16,10]{1,0} %subtract.49, bf16[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %constant.51 = bf16[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum"}
  %reduce.57 = bf16[] reduce(bf16[16,10]{1,0} %multiply.50, bf16[] %constant.51), dimensions={0,1}, to_apply=%AddComputation.53, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.58 = bf16[] negate(bf16[] %reduce.57), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = bf16[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.59 = bf16[] divide(bf16[] %negate.58, bf16[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.60 = (bf16[]) tuple(bf16[] %divide.59), frontend_attributes={neff_output_names="output0"}
}

After uncommenting the CrossEntropyLoss in the XLA autocast as done in this PR:

Exp3:
The input and target are in FP32, so the output of the linear layer will be in BF16, and then it should be upcasted to FP32 for the Crossentropyloss as seen in the HLO.

  %p4.8 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.9 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.8), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.5 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.6 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %p3.5), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.7 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.6), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.10 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.9, bf16[10,10]{0,1} %transpose.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.4 = bf16[10]{0} convert(f32[10]{0} %p2.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.14 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.4), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.15 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.10, bf16[16,10]{1,0} %broadcast.14), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %convert.16 = f32[16,10]{1,0} convert(bf16[16,10]{1,0} %add.15), metadata={op_type="xla__cast" op_name="xla__cast"}
  %constant.17 = f32[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.22 = f32[16]{0} reduce(f32[16,10]{1,0} %convert.16, f32[] %constant.17), dimensions={1}, to_apply=%MaxComputation.18, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.23 = f32[16,10]{1,0} broadcast(f32[16]{0} %reduce.22), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.24 = f32[16,10]{1,0} subtract(f32[16,10]{1,0} %convert.16, f32[16,10]{1,0} %broadcast.23), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.25 = f32[16,10]{1,0} exponential(f32[16,10]{1,0} %subtract.24), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.26 = f32[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.31 = f32[16]{0} reduce(f32[16,10]{1,0} %exponential.25, f32[] %constant.26), dimensions={1}, to_apply=%AddComputation.27, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.32 = f32[16]{0} log(f32[16]{0} %reduce.31), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.33 = f32[16,10]{1,0} broadcast(f32[16]{0} %log.32), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.34 = f32[16,10]{1,0} subtract(f32[16,10]{1,0} %subtract.24, f32[16,10]{1,0} %broadcast.33), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %p1.2 = f32[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.35 = f32[16,10]{1,0} multiply(f32[16,10]{1,0} %subtract.34, f32[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %reduce.42 = f32[] reduce(f32[16,10]{1,0} %multiply.35, f32[] %constant.26), dimensions={0,1}, to_apply=%AddComputation.38, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.43 = f32[] negate(f32[] %reduce.42), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.44 = f32[] divide(f32[] %negate.43, f32[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.45 = (f32[]) tuple(f32[] %divide.44), frontend_attributes={neff_output_names="output0"}
}

The executions of the operators are as expected.

@JackCaoG JackCaoG merged commit e438a5b into r2.1_aws_neuron Sep 26, 2024
1 check passed
@JackCaoG JackCaoG deleted the autocast_bf16_neuron branch September 26, 2024 16:11
@jeffhataws
Copy link
Collaborator

@avizon-aws Need cherry-pick to master

@jeffhataws
Copy link
Collaborator

Also request cherry-pick to 2.5 here #7977 (comment)

@avizon-aws
Copy link
Collaborator Author

Created PR to master: #8094

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants