diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 598a03c758..767ad45a17 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -299,12 +299,6 @@ def q_dot_dq_impl( preferred_element_type, is_training ): - if precision != None or preferred_element_type != None: - warnings.warn( - "The function dot_general_with_precision will set the " - "precision/preferred_element_type and disregard any provided " - "values." - ) new_lhs_scale, new_lhs_amax_history = quantize_and_update( lhs, jnp.float8_e4m3fn,