jeffbolznv committed on
Commit
825889e
·
1 Parent(s): 5db8b21

vulkan: use aligned loads for flash attention mask (llama/12853)

Browse files

Rewrite the stride logic for the mask tensor in the FA shader to force the
stride to be aligned, to allow using more efficient loads.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -201,6 +201,11 @@ void main() {
201
  uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202
  uint32_t k_stride = p.nb11;
203
  uint32_t v_stride = p.nb21;
 
 
 
 
 
204
  // hint to the compiler that strides are aligned for the aligned variant of the shader
205
  if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
206
  {
@@ -209,6 +214,7 @@ void main() {
209
  k_stride &= ~7;
210
  v_stride &= ~7;
211
  #endif
 
212
  }
213
  tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
214
  tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
261
  if (p.mask != 0) {
262
  tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
263
  tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
264
- // When using grouped query attention, all rows use the same mask.
265
- if (p.gqa_ratio > 1) {
266
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
267
- }
268
 
269
  coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
270
 
 
201
  uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202
  uint32_t k_stride = p.nb11;
203
  uint32_t v_stride = p.nb21;
204
+ // When using grouped query attention, all rows use the same mask (stride 0).
205
+ // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
206
+ // that prevents the compiler from folding the "&" through the select
207
+ // and breaking the alignment detection.
208
+ uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
209
  // hint to the compiler that strides are aligned for the aligned variant of the shader
210
  if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
211
  {
 
214
  k_stride &= ~7;
215
  v_stride &= ~7;
216
  #endif
217
+ m_stride &= ~7;
218
  }
219
  tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
220
  tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
 
267
  if (p.mask != 0) {
268
  tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
269
  tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
270
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
 
 
271
 
272
  coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
273