Commit 0ab9aba
Parent(s): d8664e4
CUDA: attention sinks for mma FlashAttention (llama/15157)
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED

@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
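
What the new sinks_f argument carries: an attention sink is a per-head learned
logit s that enters the softmax normalization without a corresponding value
row. For one row of scaled logits x_k, with m = max(max_k x_k, s), the output
becomes

    out = sum_k exp(x_k - m) * v_k  /  ( sum_k exp(x_k - m) + exp(s - m) )

so, as the next hunk implements, only the running maximum (KQ_max) and the
denominator (KQ_rowsum) need a final adjustment; the V*softmax(QK) numerator
is only ever re-scaled.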

@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
+    // so it's being done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
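
For intuition, here is the same update on plain host data: a minimal
standalone sketch, not code from this commit, with merge_sink and the toy
state being invented names. It mirrors the KQ_max / KQ_rowsum / VKQ_C update
above: the sink acts like one extra logit folded into the online softmax after
all K/V tiles have been processed, and it contributes to the denominator only.

    // Minimal sketch of merging a sink into online-softmax state (m, rowsum, acc).
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void merge_sink(float sink, float & m, float & rowsum, std::vector<float> & acc) {
        const float m_new = std::max(m, sink);
        const float scale = std::exp(m - m_new);        // re-scale what was accumulated so far
        for (float & a : acc) {
            a *= scale;                                 // numerator is only re-scaled ...
        }
        rowsum = scale*rowsum + std::exp(sink - m_new); // ... the denominator gains the sink term
        m      = m_new;
    }

    int main() {
        const float logits[] = {0.5f, 2.0f, -1.0f};     // one attention row, processed value by value
        const float values[] = {1.0f, 2.0f,  3.0f};

        float m = -INFINITY, rowsum = 0.0f;
        std::vector<float> acc = {0.0f};
        for (int i = 0; i < 3; ++i) {                   // plain online softmax, no sink yet
            const float m_new = std::max(m, logits[i]);
            const float s     = std::exp(m - m_new);
            const float p     = std::exp(logits[i] - m_new);
            acc[0] = acc[0]*s + p*values[i];
            rowsum = rowsum*s + p;
            m      = m_new;
        }

        merge_sink(1.0f, m, rowsum, acc);               // sink logit of this head
        std::printf("out = %f\n", acc[0]/rowsum);       // the sink shrinks every softmax weight
        return 0;
    }

One detail the sketch omits: the kernel flushes KQ_max_scale to exactly zero
via the uint32_t bit trick (multiplying the float's bit pattern by 0 or 1)
whenever KQ_max_diff falls below SOFTMAX_FTZ_THRESHOLD.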

@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-        const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+        const int head0 = zt * ncols2;
+
+        const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
             (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+        float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         int kb0_stop_kernel = kb0_stop * kb_niter;
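
The new index arithmetic is a mixed-radix decomposition of the flat work
counter kbc: the highest digit selects the sequence, zt selects a group of
ncols2 heads (hence head0 = zt*ncols2 as the first real head, and
head0 / gqa_ratio as the shared K/V head under grouped-query attention), and
jt selects the tile of query columns. A standalone sketch, with all sizes made
up for illustration:

    #include <cstdio>

    int main() {
        // Illustrative sizes only, not taken from the kernel's launch parameters.
        const int iter_k = 4;   // K/V blocks per tile
        const int iter_j = 3;   // tiles of query columns
        const int ne02   = 8;   // total number of heads
        const int ncols2 = 2;   // heads handled jointly by one tile
        const int groups = ne02/ncols2;

        for (int kbc = 0; kbc < 2*groups*iter_j*iter_k; kbc += 7) {
            const int sequence = kbc / (iter_k*iter_j*groups);
            const int zt       = (kbc - iter_k*iter_j*groups*sequence) / (iter_k*iter_j);
            const int jt       = (kbc - iter_k*iter_j*groups*sequence - iter_k*iter_j*zt) / iter_k;
            const int head0    = zt * ncols2;
            std::printf("kbc=%3d -> sequence=%d zt=%d jt=%d head0=%d\n",
                        kbc, sequence, zt, jt, head0);
        }
        return 0;
    }

The lowest digit, kbc % iter_k, is the position along the K/V dimension; the
surrounding, unchanged code tracks it through kb0_start / kb0_stop.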

@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 

@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
     }
 
     const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+    const int head0 = zt * ncols2;
 
-    const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-    const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+    const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2 = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
         (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+    float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
    const int kb0_start_kernel = kb0_start * kb_niter;
    int kb0_stop_kernel = kb0_stop * kb_niter;

@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);

ggml/src/ggml-cuda/fattn.cu
CHANGED

@@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks) {
+    if (sinks && !fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
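
Effect of this one-line change: sinks no longer force the vector kernels
unconditionally. Only devices without fp16 tensor-core MMA still take the vec
fallback; MMA-capable GPUs now fall through to the regular kernel selection
and can use the mma FlashAttention path extended above. A restatement of the
gate as a standalone predicate (the helper name is invented, not repo code):

    #include <cstdio>

    // When must sinks still be routed to the vec kernels?
    static bool force_vec_for_sinks(bool has_sinks, bool has_fp16_mma) {
        return has_sinks && !has_fp16_mma;
    }

    int main() {
        std::printf("%d\n", force_vec_for_sinks(true,  false)); // 1: pre-MMA GPU, vec fallback
        std::printf("%d\n", force_vec_for_sinks(true,  true));  // 0: mma path is now eligible
        std::printf("%d\n", force_vec_for_sinks(false, true));  // 0: no sinks, normal dispatch
        return 0;
    }
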
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED

@@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
         }
         // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-        if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
+        if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+            && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
             return false;
         }
         if (op->src[0]->ne[0] == 192) {
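
Read with De Morgan's law, the widened condition means a FLASH_ATTN_EXT op
with a sinks tensor (op->src[4]) is now accepted whenever fp16 MMA is
available, or when the head size op->src[0]->ne[0] is one of the vec-kernel
sizes 64 and 128; everything else is still rejected. A standalone paraphrase
(the helper is illustrative, not repo code):

    #include <cstdio>

    // Paraphrase of the support check for ops that carry attention sinks.
    static bool sinks_supported(bool has_fp16_mma, long head_size) {
        return has_fp16_mma || head_size == 64 || head_size == 128;
    }

    int main() {
        std::printf("%d\n", sinks_supported(true,  96));  // 1: the mma kernel covers it
        std::printf("%d\n", sinks_supported(false, 64));  // 1: vec kernel head size
        std::printf("%d\n", sinks_supported(false, 96));  // 0: op is rejected
        return 0;
    }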