Spaces:
Running
Running
OpenCL: add initial FA support (llama/14987)
Browse files* add F16/F16 fa support
* fix kernel init
* use mad instead of fma
* use inline function
* mark FA with sinks as unsupported for now
* add pragma unroll to loops
ggml/src/ggml-opencl/CMakeLists.txt
CHANGED
|
@@ -112,6 +112,9 @@ set(GGML_OPENCL_KERNELS
|
|
| 112 |
mul_mat_f16_f32
|
| 113 |
conv2d
|
| 114 |
conv2d_f16_f32
|
|
|
|
|
|
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
foreach (K ${GGML_OPENCL_KERNELS})
|
|
|
|
| 112 |
mul_mat_f16_f32
|
| 113 |
conv2d
|
| 114 |
conv2d_f16_f32
|
| 115 |
+
flash_attn_f32_f16
|
| 116 |
+
flash_attn_f16
|
| 117 |
+
flash_attn_f32
|
| 118 |
)
|
| 119 |
|
| 120 |
foreach (K ${GGML_OPENCL_KERNELS})
|
ggml/src/ggml-opencl/ggml-opencl.cpp
CHANGED
|
@@ -25,6 +25,7 @@
|
|
| 25 |
#include <vector>
|
| 26 |
#include <string>
|
| 27 |
#include <cmath>
|
|
|
|
| 28 |
#include <memory>
|
| 29 |
#include <charconv>
|
| 30 |
#include <mutex>
|
|
@@ -424,6 +425,14 @@ struct ggml_backend_opencl_context {
|
|
| 424 |
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
| 425 |
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
| 426 |
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
|
| 428 |
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
|
| 429 |
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
|
|
@@ -1308,6 +1317,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|
| 1308 |
GGML_LOG_CONT(".");
|
| 1309 |
}
|
| 1310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1311 |
// argsort
|
| 1312 |
{
|
| 1313 |
#ifdef GGML_OPENCL_EMBED_KERNELS
|
|
@@ -2636,6 +2712,45 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
| 2636 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 2637 |
case GGML_OP_SUM_ROWS:
|
| 2638 |
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2639 |
default:
|
| 2640 |
return false;
|
| 2641 |
}
|
|
@@ -5451,6 +5566,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
|
|
| 5451 |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
|
| 5452 |
}
|
| 5453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5454 |
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 5455 |
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 5456 |
|
|
@@ -7607,6 +7849,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
|
| 7607 |
}
|
| 7608 |
func = ggml_cl_sum_rows;
|
| 7609 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7610 |
default:
|
| 7611 |
return false;
|
| 7612 |
}
|
|
|
|
| 25 |
#include <vector>
|
| 26 |
#include <string>
|
| 27 |
#include <cmath>
|
| 28 |
+
#include <map>
|
| 29 |
#include <memory>
|
| 30 |
#include <charconv>
|
| 31 |
#include <mutex>
|
|
|
|
| 425 |
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
| 426 |
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
| 427 |
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
| 428 |
+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
|
| 429 |
+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
|
| 430 |
+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
|
| 431 |
+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
|
| 432 |
+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
|
| 433 |
+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
|
| 434 |
+
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
|
| 435 |
+
std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
|
| 436 |
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
|
| 437 |
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
|
| 438 |
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
|
|
|
|
| 1317 |
GGML_LOG_CONT(".");
|
| 1318 |
}
|
| 1319 |
|
| 1320 |
+
// flash_attn
|
| 1321 |
+
{
|
| 1322 |
+
#ifdef GGML_OPENCL_EMBED_KERNELS
|
| 1323 |
+
const std::string kernel_src_f16 {
|
| 1324 |
+
#include "flash_attn_f16.cl.h"
|
| 1325 |
+
};
|
| 1326 |
+
const std::string kernel_src_f32 {
|
| 1327 |
+
#include "flash_attn_f32.cl.h"
|
| 1328 |
+
};
|
| 1329 |
+
const std::string kernel_src_f32_f16 {
|
| 1330 |
+
#include "flash_attn_f32_f16.cl.h"
|
| 1331 |
+
};
|
| 1332 |
+
#else
|
| 1333 |
+
const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
|
| 1334 |
+
const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
|
| 1335 |
+
const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
|
| 1336 |
+
#endif
|
| 1337 |
+
|
| 1338 |
+
if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
|
| 1339 |
+
const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
|
| 1340 |
+
{ 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
|
| 1341 |
+
{112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
|
| 1342 |
+
{192, 192, 16, 16}, {256, 256, 16, 16},
|
| 1343 |
+
};
|
| 1344 |
+
|
| 1345 |
+
for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
|
| 1346 |
+
const int dk = fa_dims[i].dk;
|
| 1347 |
+
const int dv = fa_dims[i].dv;
|
| 1348 |
+
const int bm = fa_dims[i].bm;
|
| 1349 |
+
const int bn = fa_dims[i].bn;
|
| 1350 |
+
std::string OPTS = compile_opts +
|
| 1351 |
+
" -D DK=" + std::to_string(dk) +
|
| 1352 |
+
" -D DV=" + std::to_string(dv) +
|
| 1353 |
+
" -D BLOCK_M=" + std::to_string(bm) +
|
| 1354 |
+
" -D BLOCK_N=" + std::to_string(bn);
|
| 1355 |
+
|
| 1356 |
+
cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
|
| 1357 |
+
cl_kernel k_f16, k_f16_q1;
|
| 1358 |
+
CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
|
| 1359 |
+
CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
|
| 1360 |
+
backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
|
| 1361 |
+
backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
|
| 1362 |
+
CL_CHECK(clReleaseProgram(prog_f16));
|
| 1363 |
+
|
| 1364 |
+
cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
|
| 1365 |
+
cl_kernel k_f32, k_f32_q1;
|
| 1366 |
+
CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
|
| 1367 |
+
CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
|
| 1368 |
+
backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
|
| 1369 |
+
backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
|
| 1370 |
+
CL_CHECK(clReleaseProgram(prog_f32));
|
| 1371 |
+
|
| 1372 |
+
cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
|
| 1373 |
+
cl_kernel k_f32_f16, k_f32_f16_q1;
|
| 1374 |
+
CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
|
| 1375 |
+
CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
|
| 1376 |
+
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
|
| 1377 |
+
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
|
| 1378 |
+
CL_CHECK(clReleaseProgram(prog_f32_f16));
|
| 1379 |
+
|
| 1380 |
+
backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
|
| 1381 |
+
backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
|
| 1382 |
+
}
|
| 1383 |
+
GGML_LOG_CONT(".");
|
| 1384 |
+
}
|
| 1385 |
+
}
|
| 1386 |
+
|
| 1387 |
// argsort
|
| 1388 |
{
|
| 1389 |
#ifdef GGML_OPENCL_EMBED_KERNELS
|
|
|
|
| 2712 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 2713 |
case GGML_OP_SUM_ROWS:
|
| 2714 |
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
| 2715 |
+
case GGML_OP_FLASH_ATTN_EXT:
|
| 2716 |
+
{
|
| 2717 |
+
if (op->src[4]) {
|
| 2718 |
+
return false;
|
| 2719 |
+
}
|
| 2720 |
+
|
| 2721 |
+
const ggml_tensor * q = op->src[0];
|
| 2722 |
+
const ggml_tensor * k = op->src[1];
|
| 2723 |
+
const ggml_tensor * v = op->src[2];
|
| 2724 |
+
|
| 2725 |
+
const int dk = q->ne[0];
|
| 2726 |
+
const int dv = v->ne[0];
|
| 2727 |
+
|
| 2728 |
+
const struct { int dk; int dv; } supported_dims[] = {
|
| 2729 |
+
{ 64, 64}, { 80, 80}, { 96, 96},
|
| 2730 |
+
{112, 112}, {128, 128}, {192, 128},
|
| 2731 |
+
{192, 192}, {256, 256},
|
| 2732 |
+
};
|
| 2733 |
+
|
| 2734 |
+
bool dims_supported = false;
|
| 2735 |
+
for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
|
| 2736 |
+
if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
|
| 2737 |
+
dims_supported = true;
|
| 2738 |
+
break;
|
| 2739 |
+
}
|
| 2740 |
+
}
|
| 2741 |
+
if (!dims_supported) {
|
| 2742 |
+
return false;
|
| 2743 |
+
}
|
| 2744 |
+
|
| 2745 |
+
const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
|
| 2746 |
+
v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
| 2747 |
+
const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
|
| 2748 |
+
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
|
| 2749 |
+
const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
|
| 2750 |
+
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
|
| 2751 |
+
|
| 2752 |
+
return is_f32_f32 || is_f16_f16 || is_f32_f16;
|
| 2753 |
+
}
|
| 2754 |
default:
|
| 2755 |
return false;
|
| 2756 |
}
|
|
|
|
| 5566 |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
|
| 5567 |
}
|
| 5568 |
|
| 5569 |
+
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
|
| 5570 |
+
const ggml_tensor * v = dst->src[2];
|
| 5571 |
+
const ggml_tensor * mask = dst->src[3];
|
| 5572 |
+
GGML_ASSERT(q->extra);
|
| 5573 |
+
GGML_ASSERT(k->extra);
|
| 5574 |
+
GGML_ASSERT(v->extra);
|
| 5575 |
+
GGML_ASSERT(dst->extra);
|
| 5576 |
+
if (mask) {
|
| 5577 |
+
GGML_ASSERT(mask->extra);
|
| 5578 |
+
}
|
| 5579 |
+
|
| 5580 |
+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 5581 |
+
|
| 5582 |
+
const int n_q = q->ne[1];
|
| 5583 |
+
const int n_kv = k->ne[1];
|
| 5584 |
+
const int d_head_q = q->ne[0];
|
| 5585 |
+
const int d_head_v = v->ne[0];
|
| 5586 |
+
const int n_head = q->ne[2];
|
| 5587 |
+
const int n_head_kv = k->ne[2];
|
| 5588 |
+
const int n_batch = q->ne[3];
|
| 5589 |
+
|
| 5590 |
+
cl_kernel kernel = NULL;
|
| 5591 |
+
|
| 5592 |
+
const bool is_f16 = q->type == GGML_TYPE_F16;
|
| 5593 |
+
const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
|
| 5594 |
+
const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
|
| 5595 |
+
|
| 5596 |
+
if (n_q == 1) {
|
| 5597 |
+
if (is_mixed) {
|
| 5598 |
+
kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
|
| 5599 |
+
} else if (is_f16) {
|
| 5600 |
+
kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
|
| 5601 |
+
} else {
|
| 5602 |
+
kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
|
| 5603 |
+
}
|
| 5604 |
+
} else {
|
| 5605 |
+
if (is_mixed) {
|
| 5606 |
+
kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
|
| 5607 |
+
} else if (is_f16) {
|
| 5608 |
+
kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
|
| 5609 |
+
} else {
|
| 5610 |
+
kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
|
| 5611 |
+
}
|
| 5612 |
+
}
|
| 5613 |
+
GGML_ASSERT(kernel != NULL);
|
| 5614 |
+
|
| 5615 |
+
ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
|
| 5616 |
+
ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
|
| 5617 |
+
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
|
| 5618 |
+
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
|
| 5619 |
+
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
|
| 5620 |
+
|
| 5621 |
+
cl_ulong offset_q = extra_q->offset + q->view_offs;
|
| 5622 |
+
cl_ulong offset_k = extra_k->offset + k->view_offs;
|
| 5623 |
+
cl_ulong offset_v = extra_v->offset + v->view_offs;
|
| 5624 |
+
cl_ulong offset_o = extra_o->offset + dst->view_offs;
|
| 5625 |
+
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
|
| 5626 |
+
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
|
| 5627 |
+
|
| 5628 |
+
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
|
| 5629 |
+
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
|
| 5630 |
+
const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
|
| 5631 |
+
const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
|
| 5632 |
+
const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
|
| 5633 |
+
const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
|
| 5634 |
+
const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
|
| 5635 |
+
const int mask_ne2 = mask ? mask->ne[2] : 0;
|
| 5636 |
+
const int mask_ne3 = mask ? mask->ne[3] : 0;
|
| 5637 |
+
|
| 5638 |
+
float scale, max_bias, logit_softcap;
|
| 5639 |
+
const float * params = (const float *)dst->op_params;
|
| 5640 |
+
scale = params[0];
|
| 5641 |
+
max_bias = params[1];
|
| 5642 |
+
logit_softcap = params[2];
|
| 5643 |
+
|
| 5644 |
+
const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
|
| 5645 |
+
|
| 5646 |
+
const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
|
| 5647 |
+
const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
|
| 5648 |
+
const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
|
| 5649 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
|
| 5650 |
+
|
| 5651 |
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
|
| 5652 |
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
|
| 5653 |
+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
|
| 5654 |
+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
|
| 5655 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
|
| 5656 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
|
| 5657 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
|
| 5658 |
+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
|
| 5659 |
+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
|
| 5660 |
+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
|
| 5661 |
+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
|
| 5662 |
+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
|
| 5663 |
+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
|
| 5664 |
+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
|
| 5665 |
+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
|
| 5666 |
+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
|
| 5667 |
+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
|
| 5668 |
+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
|
| 5669 |
+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
|
| 5670 |
+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
|
| 5671 |
+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
|
| 5672 |
+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
|
| 5673 |
+
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
|
| 5674 |
+
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
|
| 5675 |
+
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
|
| 5676 |
+
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
|
| 5677 |
+
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
|
| 5678 |
+
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
|
| 5679 |
+
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
|
| 5680 |
+
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
|
| 5681 |
+
|
| 5682 |
+
if (n_q == 1) {
|
| 5683 |
+
const size_t wg_size = 64;
|
| 5684 |
+
size_t local_work_size[] = { wg_size, 1 };
|
| 5685 |
+
size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
|
| 5686 |
+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
| 5687 |
+
} else {
|
| 5688 |
+
const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
|
| 5689 |
+
const size_t wg_size = block_m;
|
| 5690 |
+
size_t local_work_size[] = { wg_size, 1 };
|
| 5691 |
+
size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
|
| 5692 |
+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
| 5693 |
+
}
|
| 5694 |
+
}
|
| 5695 |
+
|
| 5696 |
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 5697 |
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 5698 |
|
|
|
|
| 7849 |
}
|
| 7850 |
func = ggml_cl_sum_rows;
|
| 7851 |
break;
|
| 7852 |
+
case GGML_OP_FLASH_ATTN_EXT:
|
| 7853 |
+
if (!any_on_device) {
|
| 7854 |
+
return false;
|
| 7855 |
+
}
|
| 7856 |
+
ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
|
| 7857 |
+
return true;
|
| 7858 |
default:
|
| 7859 |
return false;
|
| 7860 |
}
|
ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
| 2 |
+
|
| 3 |
+
#define ACC_TYPE float
|
| 4 |
+
#define ACC_TYPE4 float4
|
| 5 |
+
#define DATA_TYPE half
|
| 6 |
+
#define DATA_TYPE4 half4
|
| 7 |
+
#define CONVERT_ACC4(x) convert_float4(x)
|
| 8 |
+
#define CONVERT_DATA4(x) convert_half4(x)
|
| 9 |
+
|
| 10 |
+
#define DK_VEC (DK/4)
|
| 11 |
+
#define DV_VEC (DV/4)
|
| 12 |
+
#define WG_SIZE (BLOCK_M)
|
| 13 |
+
#define Q1_WG_SIZE 64
|
| 14 |
+
|
| 15 |
+
inline float get_alibi_slope(
|
| 16 |
+
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
| 17 |
+
) {
|
| 18 |
+
if (max_bias <= 0.0f) {
|
| 19 |
+
return 1.0f;
|
| 20 |
+
}
|
| 21 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 22 |
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 23 |
+
|
| 24 |
+
return pow(base, exph);
|
| 25 |
+
}
|
| 26 |
+
__kernel void flash_attn_f16(
|
| 27 |
+
const global void * q_void, ulong q_offset,
|
| 28 |
+
const global void * k_void, ulong k_offset,
|
| 29 |
+
const global void * v_void, ulong v_offset,
|
| 30 |
+
global void * o_void, ulong o_offset,
|
| 31 |
+
const float scale,
|
| 32 |
+
const int n_q,
|
| 33 |
+
const int n_kv,
|
| 34 |
+
const int is_causal,
|
| 35 |
+
const int n_head,
|
| 36 |
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
| 37 |
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
| 38 |
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
| 39 |
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
| 40 |
+
const float max_bias,
|
| 41 |
+
const float m0,
|
| 42 |
+
const float m1,
|
| 43 |
+
const int n_head_log2,
|
| 44 |
+
const float logit_softcap,
|
| 45 |
+
const int n_head_kv,
|
| 46 |
+
const global void* mask_void,
|
| 47 |
+
const ulong mask_offset,
|
| 48 |
+
const ulong mask_nb1,
|
| 49 |
+
const ulong mask_nb2,
|
| 50 |
+
const ulong mask_nb3,
|
| 51 |
+
const int mask_ne2,
|
| 52 |
+
const int mask_ne3
|
| 53 |
+
) {
|
| 54 |
+
const int tid = get_local_id(0);
|
| 55 |
+
const int block_q_idx = get_group_id(0);
|
| 56 |
+
const int head_batch_idx = get_global_id(1);
|
| 57 |
+
|
| 58 |
+
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
| 59 |
+
|
| 60 |
+
const int batch_idx = head_batch_idx / n_head;
|
| 61 |
+
const int head_idx = head_batch_idx % n_head;
|
| 62 |
+
|
| 63 |
+
const int gqa_ratio = n_head / n_head_kv;
|
| 64 |
+
const int head_kv_idx = head_idx / gqa_ratio;
|
| 65 |
+
|
| 66 |
+
const global char* q_base = (const global char*)q_void + q_offset;
|
| 67 |
+
const global char* k_base = (const global char*)k_void + k_offset;
|
| 68 |
+
const global char* v_base = (const global char*)v_void + v_offset;
|
| 69 |
+
global char* o_base = (global char*)o_void + o_offset;
|
| 70 |
+
|
| 71 |
+
const global char* mask_base = NULL;
|
| 72 |
+
if (mask_void != NULL) {
|
| 73 |
+
const int mask_head_idx = head_idx % mask_ne2;
|
| 74 |
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
| 75 |
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
ACC_TYPE4 q_priv[DK_VEC];
|
| 79 |
+
if (my_query_row < n_q) {
|
| 80 |
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
| 81 |
+
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
| 82 |
+
#pragma unroll
|
| 83 |
+
for (int i = 0; i < DK_VEC; ++i) {
|
| 84 |
+
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
ACC_TYPE4 o_acc[DV_VEC];
|
| 89 |
+
#pragma unroll
|
| 90 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 91 |
+
o_acc[i] = (ACC_TYPE4)(0.0f);
|
| 92 |
+
}
|
| 93 |
+
ACC_TYPE m_i = -INFINITY;
|
| 94 |
+
ACC_TYPE l_i = 0.0f;
|
| 95 |
+
|
| 96 |
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
| 97 |
+
|
| 98 |
+
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
| 99 |
+
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
| 100 |
+
|
| 101 |
+
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
| 102 |
+
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
| 103 |
+
const int row = i / DK_VEC;
|
| 104 |
+
const int col = i % DK_VEC;
|
| 105 |
+
const int k_row_idx = k_start + row;
|
| 106 |
+
if (k_row_idx < n_kv) {
|
| 107 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
| 108 |
+
l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
| 112 |
+
const int row = i / DV_VEC;
|
| 113 |
+
const int col = i % DV_VEC;
|
| 114 |
+
const int v_row_idx = k_start + row;
|
| 115 |
+
if (v_row_idx < n_kv) {
|
| 116 |
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
| 117 |
+
l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 121 |
+
|
| 122 |
+
if (my_query_row >= n_q) {
|
| 123 |
+
continue;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
for (int j = 0; j < BLOCK_N; j += 2) {
|
| 127 |
+
const int k_row0 = k_start + j;
|
| 128 |
+
const int k_row1 = k_start + j + 1;
|
| 129 |
+
|
| 130 |
+
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
| 131 |
+
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
| 132 |
+
#pragma unroll
|
| 133 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 134 |
+
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
| 135 |
+
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
| 136 |
+
}
|
| 137 |
+
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
| 138 |
+
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
| 139 |
+
|
| 140 |
+
if (is_causal) {
|
| 141 |
+
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
| 142 |
+
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
if (k_row0 >= n_kv) score0 = -INFINITY;
|
| 146 |
+
if (k_row1 >= n_kv) score1 = -INFINITY;
|
| 147 |
+
|
| 148 |
+
if (mask_base != NULL) {
|
| 149 |
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
| 150 |
+
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
| 151 |
+
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
if (logit_softcap > 0.0f) {
|
| 155 |
+
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
| 156 |
+
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
| 160 |
+
const ACC_TYPE p0 = exp(score0 - m_new);
|
| 161 |
+
const ACC_TYPE p1 = exp(score1 - m_new);
|
| 162 |
+
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
| 163 |
+
|
| 164 |
+
#pragma unroll
|
| 165 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 166 |
+
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
| 167 |
+
}
|
| 168 |
+
l_i = l_i * scale_prev + p0 + p1;
|
| 169 |
+
m_i = m_new;
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if (my_query_row < n_q) {
|
| 174 |
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
| 175 |
+
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
| 176 |
+
if (l_i > 0.0f) {
|
| 177 |
+
const ACC_TYPE l_inv = 1.0f / l_i;
|
| 178 |
+
#pragma unroll
|
| 179 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 180 |
+
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
| 181 |
+
}
|
| 182 |
+
} else {
|
| 183 |
+
#pragma unroll
|
| 184 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 185 |
+
o_row[i] = (DATA_TYPE4)(0.0f);
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
__kernel void flash_attn_f16_q1(
|
| 192 |
+
const global void * q_void, ulong q_offset,
|
| 193 |
+
const global void * k_void, ulong k_offset,
|
| 194 |
+
const global void * v_void, ulong v_offset,
|
| 195 |
+
global void * o_void, ulong o_offset,
|
| 196 |
+
const float scale,
|
| 197 |
+
const int n_q,
|
| 198 |
+
const int n_kv,
|
| 199 |
+
const int is_causal,
|
| 200 |
+
const int n_head,
|
| 201 |
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
| 202 |
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
| 203 |
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
| 204 |
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
| 205 |
+
const float max_bias,
|
| 206 |
+
const float m0,
|
| 207 |
+
const float m1,
|
| 208 |
+
const int n_head_log2,
|
| 209 |
+
const float logit_softcap,
|
| 210 |
+
const int n_head_kv,
|
| 211 |
+
const global void* mask_void,
|
| 212 |
+
const ulong mask_offset,
|
| 213 |
+
const ulong mask_nb1,
|
| 214 |
+
const ulong mask_nb2,
|
| 215 |
+
const ulong mask_nb3,
|
| 216 |
+
const int mask_ne2,
|
| 217 |
+
const int mask_ne3
|
| 218 |
+
) {
|
| 219 |
+
const int tid = get_local_id(0);
|
| 220 |
+
const int head_batch_idx = get_global_id(1);
|
| 221 |
+
|
| 222 |
+
const int batch_idx = head_batch_idx / n_head;
|
| 223 |
+
const int head_idx = head_batch_idx % n_head;
|
| 224 |
+
|
| 225 |
+
const int gqa_ratio = n_head / n_head_kv;
|
| 226 |
+
const int head_kv_idx = head_idx / gqa_ratio;
|
| 227 |
+
|
| 228 |
+
const global char* q_base = (const global char*)q_void + q_offset;
|
| 229 |
+
const global char* k_base = (const global char*)k_void + k_offset;
|
| 230 |
+
const global char* v_base = (const global char*)v_void + v_offset;
|
| 231 |
+
global char* o_base = (global char*)o_void + o_offset;
|
| 232 |
+
|
| 233 |
+
const global char* mask_base = NULL;
|
| 234 |
+
if (mask_void != NULL) {
|
| 235 |
+
const int mask_head_idx = head_idx % mask_ne2;
|
| 236 |
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
| 237 |
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
ACC_TYPE4 q_priv[DK_VEC];
|
| 241 |
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
| 242 |
+
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
| 243 |
+
#pragma unroll
|
| 244 |
+
for (int i = 0; i < DK_VEC; ++i) {
|
| 245 |
+
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
| 249 |
+
|
| 250 |
+
ACC_TYPE m_i = -INFINITY;
|
| 251 |
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
| 252 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
| 253 |
+
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
| 254 |
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
| 255 |
+
#pragma unroll
|
| 256 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 257 |
+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
| 258 |
+
}
|
| 259 |
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
| 260 |
+
if (mask_base != NULL) {
|
| 261 |
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
| 262 |
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
| 263 |
+
}
|
| 264 |
+
if (logit_softcap > 0.0f) {
|
| 265 |
+
score = logit_softcap * tanh(score / logit_softcap);
|
| 266 |
+
}
|
| 267 |
+
m_i = max(m_i, score);
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
| 271 |
+
local_m[tid] = m_i;
|
| 272 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 273 |
+
#pragma unroll
|
| 274 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 275 |
+
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
| 276 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 277 |
+
}
|
| 278 |
+
const ACC_TYPE m_final = local_m[0];
|
| 279 |
+
|
| 280 |
+
ACC_TYPE4 o_acc[DV_VEC];
|
| 281 |
+
#pragma unroll
|
| 282 |
+
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
| 283 |
+
ACC_TYPE l_i = 0.0f;
|
| 284 |
+
|
| 285 |
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
| 286 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
| 287 |
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
|
| 288 |
+
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
| 289 |
+
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
| 290 |
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
| 291 |
+
#pragma unroll
|
| 292 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 293 |
+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
| 294 |
+
}
|
| 295 |
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
| 296 |
+
if (mask_base != NULL) {
|
| 297 |
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
| 298 |
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
| 299 |
+
}
|
| 300 |
+
if (logit_softcap > 0.0f) {
|
| 301 |
+
score = logit_softcap * tanh(score / logit_softcap);
|
| 302 |
+
}
|
| 303 |
+
const ACC_TYPE p = exp(score - m_final);
|
| 304 |
+
l_i += p;
|
| 305 |
+
#pragma unroll
|
| 306 |
+
for (int i = 0; i < DV_VEC; i++) {
|
| 307 |
+
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
__local ACC_TYPE local_l[Q1_WG_SIZE];
|
| 312 |
+
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
| 313 |
+
local_l[tid] = l_i;
|
| 314 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 315 |
+
#pragma unroll
|
| 316 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 317 |
+
if (tid < s) local_l[tid] += local_l[tid + s];
|
| 318 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
| 322 |
+
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
| 323 |
+
const ACC_TYPE l_final = local_l[0];
|
| 324 |
+
|
| 325 |
+
if (l_final > 0.0f) {
|
| 326 |
+
const ACC_TYPE l_inv = 1.0f / l_final;
|
| 327 |
+
for (int i = 0; i < DV_VEC; i++) {
|
| 328 |
+
local_o_comp[tid] = o_acc[i];
|
| 329 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 330 |
+
#pragma unroll
|
| 331 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 332 |
+
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
| 333 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 334 |
+
}
|
| 335 |
+
if (tid == 0) {
|
| 336 |
+
o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
} else if (tid == 0) {
|
| 340 |
+
#pragma unroll
|
| 341 |
+
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
| 342 |
+
}
|
| 343 |
+
}
|
ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
| 2 |
+
|
| 3 |
+
#define ACC_TYPE float
|
| 4 |
+
#define ACC_TYPE4 float4
|
| 5 |
+
#define DATA_TYPE float
|
| 6 |
+
#define DATA_TYPE4 float4
|
| 7 |
+
#define CONVERT_ACC4(x) (x)
|
| 8 |
+
#define CONVERT_DATA4(x) (x)
|
| 9 |
+
|
| 10 |
+
#define DK_VEC (DK/4)
|
| 11 |
+
#define DV_VEC (DV/4)
|
| 12 |
+
#define WG_SIZE (BLOCK_M)
|
| 13 |
+
#define Q1_WG_SIZE 64
|
| 14 |
+
|
| 15 |
+
inline float get_alibi_slope(
|
| 16 |
+
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
| 17 |
+
) {
|
| 18 |
+
if (max_bias <= 0.0f) {
|
| 19 |
+
return 1.0f;
|
| 20 |
+
}
|
| 21 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 22 |
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 23 |
+
|
| 24 |
+
return pow(base, exph);
|
| 25 |
+
}
|
| 26 |
+
__kernel void flash_attn_f32(
|
| 27 |
+
const global void * q_void, ulong q_offset,
|
| 28 |
+
const global void * k_void, ulong k_offset,
|
| 29 |
+
const global void * v_void, ulong v_offset,
|
| 30 |
+
global void * o_void, ulong o_offset,
|
| 31 |
+
const float scale,
|
| 32 |
+
const int n_q,
|
| 33 |
+
const int n_kv,
|
| 34 |
+
const int is_causal,
|
| 35 |
+
const int n_head,
|
| 36 |
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
| 37 |
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
| 38 |
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
| 39 |
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
| 40 |
+
const float max_bias,
|
| 41 |
+
const float m0,
|
| 42 |
+
const float m1,
|
| 43 |
+
const int n_head_log2,
|
| 44 |
+
const float logit_softcap,
|
| 45 |
+
const int n_head_kv,
|
| 46 |
+
const global void* mask_void,
|
| 47 |
+
const ulong mask_offset,
|
| 48 |
+
const ulong mask_nb1,
|
| 49 |
+
const ulong mask_nb2,
|
| 50 |
+
const ulong mask_nb3,
|
| 51 |
+
const int mask_ne2,
|
| 52 |
+
const int mask_ne3
|
| 53 |
+
) {
|
| 54 |
+
const int tid = get_local_id(0);
|
| 55 |
+
const int block_q_idx = get_group_id(0);
|
| 56 |
+
const int head_batch_idx = get_global_id(1);
|
| 57 |
+
|
| 58 |
+
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
| 59 |
+
|
| 60 |
+
const int batch_idx = head_batch_idx / n_head;
|
| 61 |
+
const int head_idx = head_batch_idx % n_head;
|
| 62 |
+
|
| 63 |
+
const int gqa_ratio = n_head / n_head_kv;
|
| 64 |
+
const int head_kv_idx = head_idx / gqa_ratio;
|
| 65 |
+
|
| 66 |
+
const global char* q_base = (const global char*)q_void + q_offset;
|
| 67 |
+
const global char* k_base = (const global char*)k_void + k_offset;
|
| 68 |
+
const global char* v_base = (const global char*)v_void + v_offset;
|
| 69 |
+
global char* o_base = (global char*)o_void + o_offset;
|
| 70 |
+
|
| 71 |
+
const global char* mask_base = NULL;
|
| 72 |
+
if (mask_void != NULL) {
|
| 73 |
+
const int mask_head_idx = head_idx % mask_ne2;
|
| 74 |
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
| 75 |
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
ACC_TYPE4 q_priv[DK_VEC];
|
| 79 |
+
if (my_query_row < n_q) {
|
| 80 |
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
| 81 |
+
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
| 82 |
+
#pragma unroll
|
| 83 |
+
for (int i = 0; i < DK_VEC; ++i) {
|
| 84 |
+
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
ACC_TYPE4 o_acc[DV_VEC];
|
| 89 |
+
#pragma unroll
|
| 90 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 91 |
+
o_acc[i] = (ACC_TYPE4)(0.0f);
|
| 92 |
+
}
|
| 93 |
+
ACC_TYPE m_i = -INFINITY;
|
| 94 |
+
ACC_TYPE l_i = 0.0f;
|
| 95 |
+
|
| 96 |
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
| 97 |
+
|
| 98 |
+
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
| 99 |
+
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
| 100 |
+
|
| 101 |
+
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
| 102 |
+
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
| 103 |
+
const int row = i / DK_VEC;
|
| 104 |
+
const int col = i % DK_VEC;
|
| 105 |
+
const int k_row_idx = k_start + row;
|
| 106 |
+
if (k_row_idx < n_kv) {
|
| 107 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
| 108 |
+
l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
| 112 |
+
const int row = i / DV_VEC;
|
| 113 |
+
const int col = i % DV_VEC;
|
| 114 |
+
const int v_row_idx = k_start + row;
|
| 115 |
+
if (v_row_idx < n_kv) {
|
| 116 |
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
| 117 |
+
l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 121 |
+
|
| 122 |
+
if (my_query_row >= n_q) {
|
| 123 |
+
continue;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
for (int j = 0; j < BLOCK_N; j += 2) {
|
| 127 |
+
const int k_row0 = k_start + j;
|
| 128 |
+
const int k_row1 = k_start + j + 1;
|
| 129 |
+
|
| 130 |
+
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
| 131 |
+
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
| 132 |
+
#pragma unroll
|
| 133 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 134 |
+
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
| 135 |
+
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
| 136 |
+
}
|
| 137 |
+
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
| 138 |
+
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
| 139 |
+
|
| 140 |
+
if (is_causal) {
|
| 141 |
+
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
| 142 |
+
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
if (k_row0 >= n_kv) score0 = -INFINITY;
|
| 146 |
+
if (k_row1 >= n_kv) score1 = -INFINITY;
|
| 147 |
+
|
| 148 |
+
if (mask_base != NULL) {
|
| 149 |
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
| 150 |
+
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
| 151 |
+
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
if (logit_softcap > 0.0f) {
|
| 155 |
+
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
| 156 |
+
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
| 160 |
+
const ACC_TYPE p0 = exp(score0 - m_new);
|
| 161 |
+
const ACC_TYPE p1 = exp(score1 - m_new);
|
| 162 |
+
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
| 163 |
+
|
| 164 |
+
#pragma unroll
|
| 165 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 166 |
+
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
| 167 |
+
}
|
| 168 |
+
l_i = l_i * scale_prev + p0 + p1;
|
| 169 |
+
m_i = m_new;
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if (my_query_row < n_q) {
|
| 174 |
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
| 175 |
+
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
| 176 |
+
if (l_i > 0.0f) {
|
| 177 |
+
const ACC_TYPE l_inv = 1.0f / l_i;
|
| 178 |
+
#pragma unroll
|
| 179 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 180 |
+
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
| 181 |
+
}
|
| 182 |
+
} else {
|
| 183 |
+
#pragma unroll
|
| 184 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 185 |
+
o_row[i] = (DATA_TYPE4)(0.0f);
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
__kernel void flash_attn_f32_q1(
|
| 192 |
+
const global void * q_void, ulong q_offset,
|
| 193 |
+
const global void * k_void, ulong k_offset,
|
| 194 |
+
const global void * v_void, ulong v_offset,
|
| 195 |
+
global void * o_void, ulong o_offset,
|
| 196 |
+
const float scale,
|
| 197 |
+
const int n_q,
|
| 198 |
+
const int n_kv,
|
| 199 |
+
const int is_causal,
|
| 200 |
+
const int n_head,
|
| 201 |
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
| 202 |
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
| 203 |
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
| 204 |
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
| 205 |
+
const float max_bias,
|
| 206 |
+
const float m0,
|
| 207 |
+
const float m1,
|
| 208 |
+
const int n_head_log2,
|
| 209 |
+
const float logit_softcap,
|
| 210 |
+
const int n_head_kv,
|
| 211 |
+
const global void* mask_void,
|
| 212 |
+
const ulong mask_offset,
|
| 213 |
+
const ulong mask_nb1,
|
| 214 |
+
const ulong mask_nb2,
|
| 215 |
+
const ulong mask_nb3,
|
| 216 |
+
const int mask_ne2,
|
| 217 |
+
const int mask_ne3
|
| 218 |
+
) {
|
| 219 |
+
const int tid = get_local_id(0);
|
| 220 |
+
const int head_batch_idx = get_global_id(1);
|
| 221 |
+
|
| 222 |
+
const int batch_idx = head_batch_idx / n_head;
|
| 223 |
+
const int head_idx = head_batch_idx % n_head;
|
| 224 |
+
|
| 225 |
+
const int gqa_ratio = n_head / n_head_kv;
|
| 226 |
+
const int head_kv_idx = head_idx / gqa_ratio;
|
| 227 |
+
|
| 228 |
+
const global char* q_base = (const global char*)q_void + q_offset;
|
| 229 |
+
const global char* k_base = (const global char*)k_void + k_offset;
|
| 230 |
+
const global char* v_base = (const global char*)v_void + v_offset;
|
| 231 |
+
global char* o_base = (global char*)o_void + o_offset;
|
| 232 |
+
|
| 233 |
+
const global char* mask_base = NULL;
|
| 234 |
+
if (mask_void != NULL) {
|
| 235 |
+
const int mask_head_idx = head_idx % mask_ne2;
|
| 236 |
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
| 237 |
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
ACC_TYPE4 q_priv[DK_VEC];
|
| 241 |
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
| 242 |
+
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
| 243 |
+
#pragma unroll
|
| 244 |
+
for (int i = 0; i < DK_VEC; ++i) {
|
| 245 |
+
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
| 249 |
+
|
| 250 |
+
ACC_TYPE m_i = -INFINITY;
|
| 251 |
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
| 252 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
| 253 |
+
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
| 254 |
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
| 255 |
+
#pragma unroll
|
| 256 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 257 |
+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
| 258 |
+
}
|
| 259 |
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
| 260 |
+
if (mask_base != NULL) {
|
| 261 |
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
| 262 |
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
| 263 |
+
}
|
| 264 |
+
if (logit_softcap > 0.0f) {
|
| 265 |
+
score = logit_softcap * tanh(score / logit_softcap);
|
| 266 |
+
}
|
| 267 |
+
m_i = max(m_i, score);
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
| 271 |
+
local_m[tid] = m_i;
|
| 272 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 273 |
+
#pragma unroll
|
| 274 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 275 |
+
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
| 276 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 277 |
+
}
|
| 278 |
+
const ACC_TYPE m_final = local_m[0];
|
| 279 |
+
|
| 280 |
+
ACC_TYPE4 o_acc[DV_VEC];
|
| 281 |
+
#pragma unroll
|
| 282 |
+
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
| 283 |
+
ACC_TYPE l_i = 0.0f;
|
| 284 |
+
|
| 285 |
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
| 286 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
| 287 |
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
|
| 288 |
+
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
| 289 |
+
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
| 290 |
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
| 291 |
+
#pragma unroll
|
| 292 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 293 |
+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
| 294 |
+
}
|
| 295 |
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
| 296 |
+
if (mask_base != NULL) {
|
| 297 |
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
| 298 |
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
| 299 |
+
}
|
| 300 |
+
if (logit_softcap > 0.0f) {
|
| 301 |
+
score = logit_softcap * tanh(score / logit_softcap);
|
| 302 |
+
}
|
| 303 |
+
const ACC_TYPE p = exp(score - m_final);
|
| 304 |
+
l_i += p;
|
| 305 |
+
#pragma unroll
|
| 306 |
+
for (int i = 0; i < DV_VEC; i++) {
|
| 307 |
+
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
__local ACC_TYPE local_l[Q1_WG_SIZE];
|
| 312 |
+
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
| 313 |
+
local_l[tid] = l_i;
|
| 314 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 315 |
+
#pragma unroll
|
| 316 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 317 |
+
if (tid < s) local_l[tid] += local_l[tid + s];
|
| 318 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
| 322 |
+
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
| 323 |
+
const ACC_TYPE l_final = local_l[0];
|
| 324 |
+
|
| 325 |
+
if (l_final > 0.0f) {
|
| 326 |
+
const ACC_TYPE l_inv = 1.0f / l_final;
|
| 327 |
+
for (int i = 0; i < DV_VEC; i++) {
|
| 328 |
+
local_o_comp[tid] = o_acc[i];
|
| 329 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 330 |
+
#pragma unroll
|
| 331 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 332 |
+
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
| 333 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 334 |
+
}
|
| 335 |
+
if (tid == 0) {
|
| 336 |
+
o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
} else if (tid == 0) {
|
| 340 |
+
#pragma unroll
|
| 341 |
+
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
| 342 |
+
}
|
| 343 |
+
}
|
ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
| 2 |
+
|
| 3 |
+
#define ACC_TYPE float
|
| 4 |
+
#define ACC_TYPE4 float4
|
| 5 |
+
#define Q_DATA_TYPE4 float4
|
| 6 |
+
#define KV_DATA_TYPE4 half4
|
| 7 |
+
#define O_DATA_TYPE4 float4
|
| 8 |
+
#define MASK_DATA_TYPE half
|
| 9 |
+
#define CONVERT_Q_ACC4(x) (x)
|
| 10 |
+
#define CONVERT_KV_ACC4(x) convert_float4(x)
|
| 11 |
+
#define CONVERT_O_DATA4(x) (x)
|
| 12 |
+
|
| 13 |
+
#define DK_VEC (DK/4)
|
| 14 |
+
#define DV_VEC (DV/4)
|
| 15 |
+
#define WG_SIZE (BLOCK_M)
|
| 16 |
+
#define Q1_WG_SIZE 64
|
| 17 |
+
|
| 18 |
+
inline float get_alibi_slope(
|
| 19 |
+
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
| 20 |
+
) {
|
| 21 |
+
if (max_bias <= 0.0f) {
|
| 22 |
+
return 1.0f;
|
| 23 |
+
}
|
| 24 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 25 |
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 26 |
+
|
| 27 |
+
return pow(base, exph);
|
| 28 |
+
}
|
| 29 |
+
__kernel void flash_attn_f32_f16(
|
| 30 |
+
const global void * q_void, ulong q_offset,
|
| 31 |
+
const global void * k_void, ulong k_offset,
|
| 32 |
+
const global void * v_void, ulong v_offset,
|
| 33 |
+
global void * o_void, ulong o_offset,
|
| 34 |
+
const float scale,
|
| 35 |
+
const int n_q,
|
| 36 |
+
const int n_kv,
|
| 37 |
+
const int is_causal,
|
| 38 |
+
const int n_head,
|
| 39 |
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
| 40 |
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
| 41 |
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
| 42 |
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
| 43 |
+
const float max_bias,
|
| 44 |
+
const float m0,
|
| 45 |
+
const float m1,
|
| 46 |
+
const int n_head_log2,
|
| 47 |
+
const float logit_softcap,
|
| 48 |
+
const int n_head_kv,
|
| 49 |
+
const global void* mask_void,
|
| 50 |
+
const ulong mask_offset,
|
| 51 |
+
const ulong mask_nb1,
|
| 52 |
+
const ulong mask_nb2,
|
| 53 |
+
const ulong mask_nb3,
|
| 54 |
+
const int mask_ne2,
|
| 55 |
+
const int mask_ne3
|
| 56 |
+
) {
|
| 57 |
+
const int tid = get_local_id(0);
|
| 58 |
+
const int block_q_idx = get_group_id(0);
|
| 59 |
+
const int head_batch_idx = get_global_id(1);
|
| 60 |
+
|
| 61 |
+
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
| 62 |
+
|
| 63 |
+
const int batch_idx = head_batch_idx / n_head;
|
| 64 |
+
const int head_idx = head_batch_idx % n_head;
|
| 65 |
+
|
| 66 |
+
const int gqa_ratio = n_head / n_head_kv;
|
| 67 |
+
const int head_kv_idx = head_idx / gqa_ratio;
|
| 68 |
+
|
| 69 |
+
const global char* q_base = (const global char*)q_void + q_offset;
|
| 70 |
+
const global char* k_base = (const global char*)k_void + k_offset;
|
| 71 |
+
const global char* v_base = (const global char*)v_void + v_offset;
|
| 72 |
+
global char* o_base = (global char*)o_void + o_offset;
|
| 73 |
+
|
| 74 |
+
const global char* mask_base = NULL;
|
| 75 |
+
if (mask_void != NULL) {
|
| 76 |
+
const int mask_head_idx = head_idx % mask_ne2;
|
| 77 |
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
| 78 |
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
ACC_TYPE4 q_priv[DK_VEC];
|
| 82 |
+
if (my_query_row < n_q) {
|
| 83 |
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
| 84 |
+
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
| 85 |
+
#pragma unroll
|
| 86 |
+
for (int i = 0; i < DK_VEC; ++i) {
|
| 87 |
+
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
ACC_TYPE4 o_acc[DV_VEC];
|
| 92 |
+
#pragma unroll
|
| 93 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 94 |
+
o_acc[i] = (ACC_TYPE4)(0.0f);
|
| 95 |
+
}
|
| 96 |
+
ACC_TYPE m_i = -INFINITY;
|
| 97 |
+
ACC_TYPE l_i = 0.0f;
|
| 98 |
+
|
| 99 |
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
| 100 |
+
|
| 101 |
+
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
| 102 |
+
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
| 103 |
+
|
| 104 |
+
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
| 105 |
+
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
| 106 |
+
const int row = i / DK_VEC;
|
| 107 |
+
const int col = i % DK_VEC;
|
| 108 |
+
const int k_row_idx = k_start + row;
|
| 109 |
+
if (k_row_idx < n_kv) {
|
| 110 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
| 111 |
+
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
| 115 |
+
const int row = i / DV_VEC;
|
| 116 |
+
const int col = i % DV_VEC;
|
| 117 |
+
const int v_row_idx = k_start + row;
|
| 118 |
+
if (v_row_idx < n_kv) {
|
| 119 |
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
| 120 |
+
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 124 |
+
|
| 125 |
+
if (my_query_row >= n_q) {
|
| 126 |
+
continue;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
for (int j = 0; j < BLOCK_N; j += 2) {
|
| 130 |
+
const int k_row0 = k_start + j;
|
| 131 |
+
const int k_row1 = k_start + j + 1;
|
| 132 |
+
|
| 133 |
+
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
| 134 |
+
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
| 135 |
+
#pragma unroll
|
| 136 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 137 |
+
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
| 138 |
+
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
| 139 |
+
}
|
| 140 |
+
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
| 141 |
+
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
| 142 |
+
|
| 143 |
+
if (is_causal) {
|
| 144 |
+
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
| 145 |
+
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
if (k_row0 >= n_kv) score0 = -INFINITY;
|
| 149 |
+
if (k_row1 >= n_kv) score1 = -INFINITY;
|
| 150 |
+
|
| 151 |
+
if (mask_base != NULL) {
|
| 152 |
+
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
| 153 |
+
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
| 154 |
+
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
if (logit_softcap > 0.0f) {
|
| 158 |
+
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
| 159 |
+
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
| 163 |
+
const ACC_TYPE p0 = exp(score0 - m_new);
|
| 164 |
+
const ACC_TYPE p1 = exp(score1 - m_new);
|
| 165 |
+
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
| 166 |
+
|
| 167 |
+
#pragma unroll
|
| 168 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 169 |
+
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
|
| 170 |
+
}
|
| 171 |
+
l_i = l_i * scale_prev + p0 + p1;
|
| 172 |
+
m_i = m_new;
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
if (my_query_row < n_q) {
|
| 177 |
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
| 178 |
+
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
| 179 |
+
if (l_i > 0.0f) {
|
| 180 |
+
const ACC_TYPE l_inv = 1.0f / l_i;
|
| 181 |
+
#pragma unroll
|
| 182 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 183 |
+
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
|
| 184 |
+
}
|
| 185 |
+
} else {
|
| 186 |
+
#pragma unroll
|
| 187 |
+
for (int i = 0; i < DV_VEC; ++i) {
|
| 188 |
+
o_row[i] = (O_DATA_TYPE4)(0.0f);
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
__kernel void flash_attn_f32_f16_q1(
|
| 195 |
+
const global void * q_void, ulong q_offset,
|
| 196 |
+
const global void * k_void, ulong k_offset,
|
| 197 |
+
const global void * v_void, ulong v_offset,
|
| 198 |
+
global void * o_void, ulong o_offset,
|
| 199 |
+
const float scale,
|
| 200 |
+
const int n_q,
|
| 201 |
+
const int n_kv,
|
| 202 |
+
const int is_causal,
|
| 203 |
+
const int n_head,
|
| 204 |
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
| 205 |
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
| 206 |
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
| 207 |
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
| 208 |
+
const float max_bias,
|
| 209 |
+
const float m0,
|
| 210 |
+
const float m1,
|
| 211 |
+
const int n_head_log2,
|
| 212 |
+
const float logit_softcap,
|
| 213 |
+
const int n_head_kv,
|
| 214 |
+
const global void* mask_void,
|
| 215 |
+
const ulong mask_offset,
|
| 216 |
+
const ulong mask_nb1,
|
| 217 |
+
const ulong mask_nb2,
|
| 218 |
+
const ulong mask_nb3,
|
| 219 |
+
const int mask_ne2,
|
| 220 |
+
const int mask_ne3
|
| 221 |
+
) {
|
| 222 |
+
const int tid = get_local_id(0);
|
| 223 |
+
const int head_batch_idx = get_global_id(1);
|
| 224 |
+
|
| 225 |
+
const int batch_idx = head_batch_idx / n_head;
|
| 226 |
+
const int head_idx = head_batch_idx % n_head;
|
| 227 |
+
|
| 228 |
+
const int gqa_ratio = n_head / n_head_kv;
|
| 229 |
+
const int head_kv_idx = head_idx / gqa_ratio;
|
| 230 |
+
|
| 231 |
+
const global char* q_base = (const global char*)q_void + q_offset;
|
| 232 |
+
const global char* k_base = (const global char*)k_void + k_offset;
|
| 233 |
+
const global char* v_base = (const global char*)v_void + v_offset;
|
| 234 |
+
global char* o_base = (global char*)o_void + o_offset;
|
| 235 |
+
|
| 236 |
+
const global char* mask_base = NULL;
|
| 237 |
+
if (mask_void != NULL) {
|
| 238 |
+
const int mask_head_idx = head_idx % mask_ne2;
|
| 239 |
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
| 240 |
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
ACC_TYPE4 q_priv[DK_VEC];
|
| 244 |
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
| 245 |
+
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
| 246 |
+
#pragma unroll
|
| 247 |
+
for (int i = 0; i < DK_VEC; ++i) {
|
| 248 |
+
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
| 252 |
+
|
| 253 |
+
ACC_TYPE m_i = -INFINITY;
|
| 254 |
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
| 255 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
| 256 |
+
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
| 257 |
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
| 258 |
+
#pragma unroll
|
| 259 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 260 |
+
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
| 261 |
+
}
|
| 262 |
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
| 263 |
+
if (mask_base != NULL) {
|
| 264 |
+
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
|
| 265 |
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
| 266 |
+
}
|
| 267 |
+
if (logit_softcap > 0.0f) {
|
| 268 |
+
score = logit_softcap * tanh(score / logit_softcap);
|
| 269 |
+
}
|
| 270 |
+
m_i = max(m_i, score);
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
| 274 |
+
local_m[tid] = m_i;
|
| 275 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 276 |
+
#pragma unroll
|
| 277 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 278 |
+
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
| 279 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 280 |
+
}
|
| 281 |
+
const ACC_TYPE m_final = local_m[0];
|
| 282 |
+
|
| 283 |
+
ACC_TYPE4 o_acc[DV_VEC];
|
| 284 |
+
#pragma unroll
|
| 285 |
+
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
| 286 |
+
ACC_TYPE l_i = 0.0f;
|
| 287 |
+
|
| 288 |
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
| 289 |
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
| 290 |
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
|
| 291 |
+
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
| 292 |
+
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
|
| 293 |
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
| 294 |
+
#pragma unroll
|
| 295 |
+
for (int k = 0; k < DK_VEC; k++) {
|
| 296 |
+
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
| 297 |
+
}
|
| 298 |
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
| 299 |
+
if (mask_base != NULL) {
|
| 300 |
+
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
|
| 301 |
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
| 302 |
+
}
|
| 303 |
+
if (logit_softcap > 0.0f) {
|
| 304 |
+
score = logit_softcap * tanh(score / logit_softcap);
|
| 305 |
+
}
|
| 306 |
+
const ACC_TYPE p = exp(score - m_final);
|
| 307 |
+
l_i += p;
|
| 308 |
+
#pragma unroll
|
| 309 |
+
for (int i = 0; i < DV_VEC; i++) {
|
| 310 |
+
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
__local ACC_TYPE local_l[Q1_WG_SIZE];
|
| 315 |
+
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
| 316 |
+
local_l[tid] = l_i;
|
| 317 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 320 |
+
if (tid < s) local_l[tid] += local_l[tid + s];
|
| 321 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
| 325 |
+
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
| 326 |
+
const ACC_TYPE l_final = local_l[0];
|
| 327 |
+
|
| 328 |
+
if (l_final > 0.0f) {
|
| 329 |
+
const ACC_TYPE l_inv = 1.0f / l_final;
|
| 330 |
+
for (int i = 0; i < DV_VEC; i++) {
|
| 331 |
+
local_o_comp[tid] = o_acc[i];
|
| 332 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 333 |
+
#pragma unroll
|
| 334 |
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
| 335 |
+
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
| 336 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 337 |
+
}
|
| 338 |
+
if (tid == 0) {
|
| 339 |
+
o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv);
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
} else if (tid == 0) {
|
| 343 |
+
#pragma unroll
|
| 344 |
+
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
|
| 345 |
+
}
|
| 346 |
+
}
|