jeffbolznv committed on
Commit e0e73fa · 1 Parent(s): 4134077

vulkan: fix coopmat2 flash attention for non-contiguous inputs (llama/11281)


Add code similar to mul_mm_cm2 to force alignment of strides, to avoid
a performance regression.

Add noncontiguous FA tests in test-backend-ops.

Fixes #11268.
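As context for the host-side change below, selecting the faster "aligned" pipeline amounts to converting each input's row stride from bytes to elements and requiring that KV and all three strides be multiples of 8. A minimal, self-contained C++ sketch of that check follows; row_desc, fa_can_use_aligned and the concrete sizes are illustrative only, not ggml names.

    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-in for one input tensor's row layout. In ggml the byte
    // stride between rows (nb1) of a non-contiguous view need not equal
    // ne0 * type_size, so it has to be carried through explicitly.
    struct row_desc {
        uint64_t row_stride_bytes; // byte distance between consecutive rows
        uint64_t type_size;        // bytes per element
    };

    // The "aligned" pipeline may only be used when KV and every row stride,
    // measured in elements, is a multiple of the required alignment.
    static bool fa_can_use_aligned(uint32_t kv_len, uint32_t align,
                                   const row_desc & q, const row_desc & k, const row_desc & v) {
        const uint32_t q_stride = (uint32_t)(q.row_stride_bytes / q.type_size);
        const uint32_t k_stride = (uint32_t)(k.row_stride_bytes / k.type_size);
        const uint32_t v_stride = (uint32_t)(v.row_stride_bytes / v.type_size);
        return (kv_len % align) == 0 &&
               (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
    }

    int main() {
        const row_desc contiguous { 128 * 2, 2 }; // 128 f16 elements per row, densely packed
        const row_desc padded     { 130 * 2, 2 }; // rows 130 elements apart: not a multiple of 8
        std::printf("contiguous: %d\n", fa_can_use_aligned(4096, 32, contiguous, contiguous, contiguous));
        std::printf("padded:     %d\n", fa_can_use_aligned(4096, 32, padded, contiguous, contiguous));
        return 0;
    }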

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -386,10 +386,13 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;
 
+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;
@@ -4809,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
     assert(pipelines);
 
-    bool aligned = (KV % pipelines[1]->align) == 0;
+    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
+    const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
+    const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+
+    bool aligned = (KV % pipelines[1]->align) == 0 &&
+                   // the "aligned" shader variant will forcibly align strides, for performance
+                   (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
@@ -4845,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
-        ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
         Q_uma = d_Q != nullptr;
         K_uma = d_K != nullptr;
         V_uma = d_V != nullptr;
         D_uma = d_D != nullptr;
         if (mask) {
-            ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
+            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
     }
@@ -4891,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
+    const vk_flash_attn_push_constants pc = { N, KV,
+                                              (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
+                                              (uint32_t)neq2, (uint32_t)neq3,
+                                              (uint32_t)nek2, (uint32_t)nek3,
+                                              (uint32_t)nev2, (uint32_t)nev3,
+                                              nem1,
+                                              q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
+                                              k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
+                                              v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
+                                              nbm1,
+                                              scale, max_bias, logit_softcap,
+                                              mask != nullptr, n_head_log2, m0, m1 };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -8668,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
     ggml_tensor * src2 = tensor->src[2];
+    ggml_tensor * src3 = tensor->src[3];
 
     void * tensor_data = tensor->data;
 
@@ -8730,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
            std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+           std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
@@ -8774,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+           std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
@@ -8796,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+           std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
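The non-contiguous inputs this commit handles typically come from permuted views, where a tensor's nb byte strides no longer follow from its ne extents. The following sketch, which assumes only the public ggml C API (ggml_init, ggml_new_tensor_3d, ggml_permute, ggml_is_contiguous) and arbitrary sizes, shows how such a view arises; it is illustrative and not part of the commit:

    #include <cstdio>
    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*mem_size*/ 16u * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        ggml_context * ctx = ggml_init(params);

        // head_dim = 128, n_tokens = 64, n_head = 8, f16
        ggml_tensor * k      = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 128, 64, 8);
        // Swap dims 1 and 2: same data, but rows of the view are no longer densely packed.
        ggml_tensor * k_perm = ggml_permute(ctx, k, 0, 2, 1, 3);

        std::printf("contiguous: nb1=%zu, is_contiguous=%d\n", k->nb[1],      (int) ggml_is_contiguous(k));
        std::printf("permuted:   nb1=%zu, is_contiguous=%d\n", k_perm->nb[1], (int) ggml_is_contiguous(k_perm));

        ggml_free(ctx);
        return 0;
    }

The permuted view keeps nb[0] equal to the element size, but its nb[1] no longer equals ne[0] times the element size, which is why the flash attention path above now passes nb01/nb11/nb21 through the push constants instead of assuming densely packed rows.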
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
     uint32_t nev3;
     uint32_t nem1;
 
+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;
@@ -146,6 +149,23 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
+    // nb?1 are already divided by the type size and are in units of elements
+    uint32_t q_stride = p.nb01;
+    uint32_t k_stride = p.nb11;
+    uint32_t v_stride = p.nb21;
+    // hint to the compiler that strides are aligned for the aligned variant of the shader
+    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
+    {
+        q_stride &= ~7;
+#if !defined(BLOCK_SIZE)
+        k_stride &= ~7;
+        v_stride &= ~7;
+#endif
+    }
+    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
+    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
+    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
+
     coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
 
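One clarifying note on the `& ~7` above: masking rounds a stride down to the nearest multiple of 8, which would silently corrupt addressing for a genuinely unaligned stride. It is only safe here because the host never selects the non-clamping (aligned) variant unless the strides are already multiples of 8, so the mask is a no-op that merely lets the compiler assume the alignment. A tiny standalone C++ check of the arithmetic, purely illustrative:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t strides[] = { 128, 136, 130 };
        for (uint32_t s : strides) {
            // s & ~7u clears the low three bits, i.e. rounds down to a multiple of 8.
            // 128 and 136 are unchanged; 130 would become 128, so such a stride must
            // go through the unaligned (clamping) shader variant instead.
            std::printf("stride %u -> stride & ~7 = %u\n", s, s & ~7u);
        }
        return 0;
    }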