JohannesGaessler committed
Commit 0ab9aba · 1 Parent(s): d8664e4

CUDA: attention sinks for mma FlashAttention (llama/15157)

ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
+    // so it's being done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-        const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+        const int head0 = zt * ncols2;
+
+        const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
             (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
     }
 
     const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+    const int head0 = zt * ncols2;
 
-    const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+    const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
         (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+    float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
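The block added to flash_attn_ext_f16_process_tile folds a per-head attention sink into the online-softmax state just before the accumulators are combined and written out: the running maximum is raised to at least the sink value, the existing row sum and VKQ accumulators are rescaled by exp(old_max - new_max) (flushed to zero below the FTZ threshold), and exp(sink - new_max) is added to the row sum without adding any value vector. A minimal scalar sketch of that per-column update, with illustrative names rather than the kernel's tile types:

#include <algorithm>
#include <cmath>

// Scalar model of one column's softmax state in the tile code above:
//   m   - running maximum of the KQ logits   (KQ_max[col])
//   l   - running sum of exp(logit - m)      (KQ_rowsum[col])
//   acc - running V*softmax(KQ) accumulator  (VKQ_C / VKQ_C_16)
// The flush-to-zero cutoff value is illustrative; the kernel uses SOFTMAX_FTZ_THRESHOLD.
static void apply_attention_sink(float sink, float & m, float & l, float & acc,
                                 float ftz_threshold = -20.0f) {
    const float m_new = std::max(m, sink);
    float scale = std::exp(m - m_new);        // rescale factor for the old state
    if (m - m_new < ftz_threshold) {
        scale = 0.0f;                         // same effect as the kernel's bit trick on KQ_max_scale
    }
    l   = scale*l + std::exp(sink - m_new);   // the sink adds probability mass to the denominator...
    acc = scale*acc;                          // ...but contributes no value vector
    m   = m_new;
}

Once the epilogue divides the accumulators by KQ_rowsum, each output row is effectively scaled by sum_k exp(logit_k - m) / (sum_k exp(logit_k - m) + exp(sink - m)), which is the intended effect of an attention sink when all key logits are small.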
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks) {
+    if (sinks && !fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
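With the mma kernel now accepting sinks, the vec-only fallback above is restricted to devices without fp16 MMA support, so devices with tensor cores can fall through to the rest of the kernel selection in ggml_cuda_flash_attn_ext, which includes the mma kernel extended above. A one-line sketch of just that condition, with an illustrative helper name (not part of ggml):

// Illustrative helper: whether attention sinks still force the vec kernels
// in ggml_cuda_flash_attn_ext after this change.
static bool sinks_force_vec_kernels(bool has_sinks, bool has_fp16_mma) {
    return has_sinks && !has_fp16_mma;
}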
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
             // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
                 return false;
             }
             if (op->src[0]->ne[0] == 192) {
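The supports_op check mirrors the dispatch change: attention sinks now trigger this early rejection only on devices without fp16 MMA, and then only for head sizes other than 64 and 128 (the cases the vec kernels handle); the remaining checks in the function still run afterwards. A hedged restatement of that predicate as a standalone helper (illustrative name, not ggml code):

#include <cstdint>

// Illustrative helper: the early-rejection rule for attention sinks (op->src[4])
// implied by the hunk above. Other supports_op checks still apply afterwards.
static bool attn_sinks_pass_early_check(bool has_sinks, bool has_fp16_mma, int64_t head_size) {
    if (!has_sinks) {
        return true;             // no sinks, no extra restriction
    }
    return has_fp16_mma          // mma FlashAttention handles sinks, no early rejection
        || head_size == 64       // otherwise only the vec kernel head sizes are allowed
        || head_size == 128;
}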