Here it seems that the hard-coded `bfloat16` is used instead of `attend_dtype`. Also, `query` is not cast. I guess the correct behavior should be casting both `query` and `self.embedding` to `attend_dtype`?
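For illustration, a minimal sketch of what the proposed behavior could look like, assuming a Flax-style `Embed` module with an `attend` method. The module structure and initializer below are illustrative; only `attend_dtype`, `query`, and `self.embedding` come from the snippet under discussion.

```python
import jax.numpy as jnp
import flax.linen as nn


class Embed(nn.Module):
  """Illustrative embedding module, not the repo's actual code."""
  num_embeddings: int
  features: int
  attend_dtype: jnp.dtype = jnp.float32  # assumed config field

  def setup(self):
    self.embedding = self.param(
        'embedding',
        nn.initializers.normal(stddev=1.0),
        (self.num_embeddings, self.features),
    )

  def attend(self, query: jnp.ndarray) -> jnp.ndarray:
    # Proposed fix: cast both operands to attend_dtype instead of a
    # hard-coded bfloat16, so the logits dot product honors the configured dtype.
    query = jnp.asarray(query, self.attend_dtype)
    embedding = jnp.asarray(self.embedding, self.attend_dtype)
    return jnp.dot(query, embedding.T)
```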
I think we should figure out:
(a) does doing the dot in f32 help convergence (using the 1B runs)? (see the sketch below)
(b) does @ZhiyuLi-goog / MLPerf care?
(c) what does Anselm Levskaya think?
We should make the code consistent and as simple as possible. Also, why is our pylint/pytype not raising alarms on this? Unused vars are bad.
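For (a), one way to experiment with an f32 dot while keeping `bfloat16` activations is to request a `float32` result type from the matmul rather than casting the operands. `preferred_element_type` is a standard JAX argument; whether this matches the intended fix here is an assumption.

```python
import jax.numpy as jnp


def attend_logits_f32(query, embedding):
  # Inputs can stay in bfloat16; ask XLA for a float32 dot
  # (accumulation and result) instead of casting operands up front.
  return jnp.dot(query, embedding.T, preferred_element_type=jnp.float32)
```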