Hi @kyegomez,
Thank you very much for open-sourcing this; it is extremely helpful!
I am going through the code and I noticed a difference from the algorithm described in the FlashAttention-2 paper.
In Algorithm 1, the authors multiply the partial output `oc` by the inverse of the diagonal of `exp_row_max_diff`.
However, in your implementation, the partial output `oc` is multiplied by the non-inverted diagonal of `exp_row_max_diff`.
Line 104 ([here](https://github.com/kyegomez/FlashAttention20/blob/9ec1d9340f023d9d51037e8ddcbdd2d4d207a001/attention.py#L104C25-L104C41)):
`oc.mul_(exp_row_max_diff)` [...]
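
To make the difference concrete, here is a minimal sketch of the rescaling step in question. This is not the repository's actual code; the function name, shapes, and the side-by-side comparison are assumptions added purely for illustration, and only the identifiers `oc` and `exp_row_max_diff` come from the snippet above.

```python
import torch

# Minimal illustrative sketch of the rescaling step (hypothetical helper;
# shapes are assumptions, not the repository's actual code).
def rescale_partial_output(oc, row_maxes, new_row_maxes):
    # oc:            (q_block, d)  partial output accumulated so far
    # row_maxes:     (q_block, 1)  running row maxima before this block
    # new_row_maxes: (q_block, 1)  updated row maxima after this block
    exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

    oc_as_in_repo = oc * exp_row_max_diff    # non-inverted factor (line 104)
    oc_as_in_paper = oc / exp_row_max_diff   # inverse factor, as Algorithm 1 reads

    return oc_as_in_repo, oc_as_in_paper
```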
I understand that, because `row_maxes` is initialized to `-inf`, inverting the exponential values would result in `+inf` (or `NaN`) values, and thus would not work.
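
A tiny numeric check of that point, as a sketch assuming the usual initialization of `oc` to zeros and `row_maxes` to `-inf` (the tensor shapes are arbitrary):

```python
import torch

row_maxes = torch.full((1, 1), float("-inf"))  # initial running row maxima
oc = torch.zeros(1, 4)                         # initial partial output
new_row_maxes = torch.full((1, 1), 0.5)        # some finite block maximum

exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)  # exp(-inf) == 0

print(oc * exp_row_max_diff)        # zeros, as expected
print(oc * (1 / exp_row_max_diff))  # 0 * inf -> NaN
```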
So I wonder: is this an issue in the original algorithm from the paper?
Or did I miss something?