Spaces:
Running
Running
Commit
·
825889e
1
Parent(s):
5db8b21
vulkan: use aligned loads for flash attention mask (llama/12853)
Browse files
Rewrite the stride logic for the mask tensor in the FA shader to force the
stride to be aligned, to allow using more efficient loads.
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
|
@@ -201,6 +201,11 @@ void main() {
|
|
| 201 |
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
| 202 |
uint32_t k_stride = p.nb11;
|
| 203 |
uint32_t v_stride = p.nb21;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
| 205 |
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
| 206 |
{
|
|
@@ -209,6 +214,7 @@ void main() {
|
|
| 209 |
k_stride &= ~7;
|
| 210 |
v_stride &= ~7;
|
| 211 |
#endif
|
|
|
|
| 212 |
}
|
| 213 |
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
| 214 |
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
|
@@ -261,10 +267,7 @@ void main() {
|
|
| 261 |
if (p.mask != 0) {
|
| 262 |
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
| 263 |
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
| 264 |
-
|
| 265 |
-
if (p.gqa_ratio > 1) {
|
| 266 |
-
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
|
| 267 |
-
}
|
| 268 |
|
| 269 |
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
| 270 |
|
|
|
|
| 201 |
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
| 202 |
uint32_t k_stride = p.nb11;
|
| 203 |
uint32_t v_stride = p.nb21;
|
| 204 |
+
// When using grouped query attention, all rows use the same mask (stride 0).
|
| 205 |
+
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
| 206 |
+
// that prevents the compiler from folding the "&" through the select
|
| 207 |
+
// and breaking the alignment detection.
|
| 208 |
+
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
| 209 |
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
| 210 |
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
| 211 |
{
|
|
|
|
| 214 |
k_stride &= ~7;
|
| 215 |
v_stride &= ~7;
|
| 216 |
#endif
|
| 217 |
+
m_stride &= ~7;
|
| 218 |
}
|
| 219 |
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
| 220 |
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
|
|
|
| 267 |
if (p.mask != 0) {
|
| 268 |
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
| 269 |
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
| 270 |
+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
| 273 |
|