Commit e0e73fa · Parent(s): 4134077
vulkan: fix coopmat2 flash attention for non-contiguous inputs (llama/11281)
Add code similar to mul_mm_cm2 to force alignment of strides, to avoid a performance regression.

Add noncontiguous FA tests in test-backend-ops.

Fixes #11268.
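The core of the fix is that the Q/K/V row strides (nb?1, in elements) are now passed through the push constants and applied to the coopmat2 tensor layouts, and the faster "aligned" shader variant is only selected when those strides permit it. Below is a minimal standalone sketch of that selection rule; it is not part of the patch, and the helper names (row_stride_elems, fa_can_use_aligned_pipeline) are hypothetical:

// Standalone sketch only -- not taken from the patch; names are hypothetical.
// It restates the alignment rule the diff below applies.
#include <cstdint>
#include <cstddef>

// ggml stores the row stride nb1 in bytes; the shader wants it in elements.
static uint32_t row_stride_elems(size_t nb1_bytes, size_t type_size) {
    return (uint32_t)(nb1_bytes / type_size);
}

// The "aligned" pipeline variant is only valid when KV is a multiple of the
// pipeline's alignment and the Q/K/V row strides are multiples of 8 elements,
// because the shader masks the low three bits of each stride (stride &= ~7)
// as an alignment hint for the compiler.
static bool fa_can_use_aligned_pipeline(uint32_t KV, uint32_t align,
                                        uint32_t q_stride,
                                        uint32_t k_stride,
                                        uint32_t v_stride) {
    return (KV % align) == 0 &&
           (q_stride & 7) == 0 &&
           (k_stride & 7) == 0 &&
           (v_stride & 7) == 0;
}

As a worked example with a hypothetical value: an F16 K tensor viewed with a row stride of nb11 = 576 bytes gives k_stride = 576 / 2 = 288 elements; 288 is a multiple of 8, so such a non-contiguous view can still take the aligned path, which is how the patch avoids the performance regression mentioned above.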
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED

@@ -386,10 +386,13 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;

+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;

@@ -4809,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
     assert(pipelines);

-
+    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
+    const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
+    const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+
+    bool aligned = (KV % pipelines[1]->align) == 0 &&
+        // the "aligned" shader variant will forcibly align strides, for performance
+        (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);

@@ -4845,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
-        ggml_vk_host_get(ctx->device, k->data, d_K,
-        ggml_vk_host_get(ctx->device, v->data, d_V,
-        ggml_vk_host_get(ctx->device, dst->data, d_D,
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
         Q_uma = d_Q != nullptr;
         K_uma = d_K != nullptr;
         V_uma = d_V != nullptr;
         D_uma = d_D != nullptr;
         if (mask) {
-            ggml_vk_host_get(ctx->device, mask->data, d_M,
+            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
     }

@@ -4891,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }

-    const vk_flash_attn_push_constants pc = { N, KV,
+    const vk_flash_attn_push_constants pc = { N, KV,
+        (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
+        (uint32_t)neq2, (uint32_t)neq3,
+        (uint32_t)nek2, (uint32_t)nek3,
+        (uint32_t)nev2, (uint32_t)nev3,
+        nem1,
+        q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
+        k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
+        v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
+        nbm1,
+        scale, max_bias, logit_softcap,
+        mask != nullptr, n_head_log2, m0, m1 };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},

@@ -8668,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
     ggml_tensor * src2 = tensor->src[2];
+    ggml_tensor * src3 = tensor->src[3];

     void * tensor_data = tensor->data;

@@ -8730,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
             std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+            std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);

@@ -8774,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
             std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);

@@ -8796,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
             std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED

@@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
     uint32_t nev3;
     uint32_t nem1;

+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;

@@ -146,6 +149,23 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

+    // nb?1 are already divided by the type size and are in units of elements
+    uint32_t q_stride = p.nb01;
+    uint32_t k_stride = p.nb11;
+    uint32_t v_stride = p.nb21;
+    // hint to the compiler that strides are aligned for the aligned variant of the shader
+    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
+    {
+        q_stride &= ~7;
+#if !defined(BLOCK_SIZE)
+        k_stride &= ~7;
+        v_stride &= ~7;
+#endif
+    }
+    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
+    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
+    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
+
     coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;