Commit 4e46f41 · committed by jeffbolznv · 1 parent: 9dcb047

vulkan: Use fp16 for the flash attention P*V multiplication (llama/12783)

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
 
-        O = coopMatMulAdd(P_A, V, O);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
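Note on the change: the flash attention running-output update rescales O by the row-wise correction factor eMdiag and adds the P·V contribution. Before this commit, the P_A * V cooperative-matrix multiply accumulated directly into O, which uses ACC_TYPE (presumably the shader's higher-precision accumulator type); after it, the product is accumulated in a separate float16 matrix and converted to ACC_TYPE only once, when it is folded into the rescaled O. A minimal restatement of the before/after pattern, using only identifiers that appear in the hunk above:

    // Before: P_A * V accumulated directly into the ACC_TYPE output matrix.
    O = eMdiag * O;
    O = coopMatMulAdd(P_A, V, O);

    // After: P_A * V accumulated in fp16, converted once, then added to the rescaled O.
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV =
        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
    PV = coopMatMulAdd(P_A, V, PV);
    O  = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);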