Spaces:
Running
Running
Commit
·
4e46f41
1
Parent(s):
9dcb047
vulkan: Use fp16 for the flash attention P*V multiplication (llama/12783)
Browse files
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
|
@@ -330,9 +330,11 @@ void main() {
|
|
| 330 |
// resize eM by using smear/reduce
|
| 331 |
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
| 332 |
|
| 333 |
-
O = eMdiag * O;
|
|
|
|
|
|
|
| 334 |
|
| 335 |
-
O = coopMatMulAdd(P_A, V, O);
|
| 336 |
}
|
| 337 |
|
| 338 |
// If there is split_k, then the split_k resolve shader does the final
|
|
|
|
| 330 |
// resize eM by using smear/reduce
|
| 331 |
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
| 332 |
|
| 333 |
+
// multiply with fp16 accumulation, then add to O.
|
| 334 |
+
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
| 335 |
+
PV = coopMatMulAdd(P_A, V, PV);
|
| 336 |
|
| 337 |
+
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
|
| 338 |
}
|
| 339 |
|
| 340 |
// If there is split_k, then the split_k resolve shader does the final
|