Spaces:
Running
Running
lhez
committed on
Commit
·
d8664e4
1
Parent(s):
d7e9115
opencl: support sink in `soft_max` (attn sinks) (llama/15152)
Browse files
ggml/src/ggml-opencl/ggml-opencl.cpp
CHANGED
|
@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
| 2520 |
case GGML_OP_CLAMP:
|
| 2521 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 2522 |
case GGML_OP_SOFT_MAX:
|
| 2523 |
-
// TODO: support attention sinks [TAG_ATTN_SINKS]
|
| 2524 |
-
return op->src[2] == nullptr;
|
| 2525 |
case GGML_OP_NORM:
|
| 2526 |
case GGML_OP_RMS_NORM:
|
| 2527 |
return true;
|
|
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
|
| 6594 |
GGML_ASSERT(src1->extra);
|
| 6595 |
}
|
| 6596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6597 |
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 6598 |
|
| 6599 |
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
| 6600 |
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
| 6601 |
|
| 6602 |
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
|
|
|
|
| 6603 |
|
| 6604 |
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
| 6605 |
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
| 6606 |
|
| 6607 |
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
|
|
|
|
| 6608 |
|
| 6609 |
const int ne00 = src0->ne[0];
|
| 6610 |
const int ne01 = src0->ne[1];
|
|
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
|
| 6672 |
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 6673 |
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
|
| 6674 |
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 6675 |
-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &
|
| 6676 |
-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &
|
| 6677 |
-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(
|
| 6678 |
-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &
|
| 6679 |
-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(
|
| 6680 |
-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &
|
| 6681 |
-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(
|
| 6682 |
-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(
|
| 6683 |
-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(
|
| 6684 |
-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(
|
| 6685 |
-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &
|
| 6686 |
-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &
|
| 6687 |
-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &
|
| 6688 |
-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &
|
| 6689 |
-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(
|
| 6690 |
-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(
|
| 6691 |
-
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &
|
| 6692 |
-
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &
|
| 6693 |
-
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(
|
|
|
|
|
|
|
| 6694 |
|
| 6695 |
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
| 6696 |
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
|
|
|
| 2520 |
case GGML_OP_CLAMP:
|
| 2521 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 2522 |
case GGML_OP_SOFT_MAX:
|
|
|
|
|
|
|
| 2523 |
case GGML_OP_NORM:
|
| 2524 |
case GGML_OP_RMS_NORM:
|
| 2525 |
return true;
|
|
|
|
| 6592 |
GGML_ASSERT(src1->extra);
|
| 6593 |
}
|
| 6594 |
|
| 6595 |
+
const ggml_tensor * src2 = dst->src[2];
|
| 6596 |
+
if (src2) {
|
| 6597 |
+
GGML_ASSERT(src2->extra);
|
| 6598 |
+
}
|
| 6599 |
+
|
| 6600 |
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 6601 |
|
| 6602 |
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
| 6603 |
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
| 6604 |
|
| 6605 |
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
|
| 6606 |
+
ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
|
| 6607 |
|
| 6608 |
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
| 6609 |
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
| 6610 |
|
| 6611 |
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
|
| 6612 |
+
cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
|
| 6613 |
|
| 6614 |
const int ne00 = src0->ne[0];
|
| 6615 |
const int ne01 = src0->ne[1];
|
|
|
|
| 6677 |
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 6678 |
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
|
| 6679 |
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 6680 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
|
| 6681 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
|
| 6682 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
| 6683 |
+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
| 6684 |
+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
|
| 6685 |
+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
| 6686 |
+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
| 6687 |
+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
| 6688 |
+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
|
| 6689 |
+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
|
| 6690 |
+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
|
| 6691 |
+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
|
| 6692 |
+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
|
| 6693 |
+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
|
| 6694 |
+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
|
| 6695 |
+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
|
| 6696 |
+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
|
| 6697 |
+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
|
| 6698 |
+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
|
| 6699 |
+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
|
| 6700 |
+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
|
| 6701 |
|
| 6702 |
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
| 6703 |
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
CHANGED
|
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
|
|
|
|
|
|
| 29 |
global char * dst,
|
| 30 |
ulong offsetd,
|
| 31 |
int ne00,
|
|
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
|
|
| 48 |
) {
|
| 49 |
src0 = src0 + offset0;
|
| 50 |
src1 = src1 + offset1;
|
|
|
|
| 51 |
dst = dst + offsetd;
|
| 52 |
|
| 53 |
int i03 = get_group_id(2);
|
|
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
|
|
| 60 |
|
| 61 |
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 62 |
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
|
|
|
| 63 |
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 64 |
|
| 65 |
float slope = 1.0f;
|
|
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
|
|
| 75 |
}
|
| 76 |
|
| 77 |
// parallel max
|
| 78 |
-
float4 lmax4 = -INFINITY;
|
| 79 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 80 |
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
|
| 81 |
}
|
|
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
|
|
| 92 |
}
|
| 93 |
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
| 94 |
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 98 |
pdst4[i00] /= sum;
|
|
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
| 29 |
+
global char * src2,
|
| 30 |
+
ulong offset2,
|
| 31 |
global char * dst,
|
| 32 |
ulong offsetd,
|
| 33 |
int ne00,
|
|
|
|
| 50 |
) {
|
| 51 |
src0 = src0 + offset0;
|
| 52 |
src1 = src1 + offset1;
|
| 53 |
+
src2 = src2 + offset2;
|
| 54 |
dst = dst + offsetd;
|
| 55 |
|
| 56 |
int i03 = get_group_id(2);
|
|
|
|
| 63 |
|
| 64 |
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 65 |
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
| 66 |
+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
| 67 |
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 68 |
|
| 69 |
float slope = 1.0f;
|
|
|
|
| 79 |
}
|
| 80 |
|
| 81 |
// parallel max
|
| 82 |
+
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
|
| 83 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 84 |
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
|
| 85 |
}
|
|
|
|
| 96 |
}
|
| 97 |
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
| 98 |
|
| 99 |
+
float sum = sub_group_reduce_add(lsum);
|
| 100 |
+
|
| 101 |
+
if (psrc2) {
|
| 102 |
+
sum += exp(psrc2[i02] - max);
|
| 103 |
+
}
|
| 104 |
|
| 105 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 106 |
pdst4[i00] /= sum;
|
ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
CHANGED
|
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
|
|
|
|
|
|
| 29 |
global char * dst,
|
| 30 |
ulong offsetd,
|
| 31 |
int ne00,
|
|
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
|
|
| 48 |
) {
|
| 49 |
src0 = src0 + offset0;
|
| 50 |
src1 = src1 + offset1;
|
|
|
|
| 51 |
dst = dst + offsetd;
|
| 52 |
|
| 53 |
int i03 = get_group_id(2);
|
|
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
|
|
| 60 |
|
| 61 |
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 62 |
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
|
|
|
| 63 |
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 64 |
|
| 65 |
float slope = 1.0f;
|
|
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
|
|
| 75 |
}
|
| 76 |
|
| 77 |
// parallel max
|
| 78 |
-
float4 lmax4 = -INFINITY;
|
| 79 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 80 |
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 81 |
}
|
|
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
|
|
| 92 |
}
|
| 93 |
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
| 94 |
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 98 |
pdst4[i00] /= sum;
|
|
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
| 29 |
+
global char * src2,
|
| 30 |
+
ulong offset2,
|
| 31 |
global char * dst,
|
| 32 |
ulong offsetd,
|
| 33 |
int ne00,
|
|
|
|
| 50 |
) {
|
| 51 |
src0 = src0 + offset0;
|
| 52 |
src1 = src1 + offset1;
|
| 53 |
+
src2 = src2 + offset2;
|
| 54 |
dst = dst + offsetd;
|
| 55 |
|
| 56 |
int i03 = get_group_id(2);
|
|
|
|
| 63 |
|
| 64 |
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 65 |
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
| 66 |
+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
| 67 |
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 68 |
|
| 69 |
float slope = 1.0f;
|
|
|
|
| 79 |
}
|
| 80 |
|
| 81 |
// parallel max
|
| 82 |
+
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
|
| 83 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 84 |
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 85 |
}
|
|
|
|
| 96 |
}
|
| 97 |
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
| 98 |
|
| 99 |
+
float sum = sub_group_reduce_add(lsum);
|
| 100 |
+
|
| 101 |
+
if (psrc2) {
|
| 102 |
+
sum += exp(psrc2[i02] - max);
|
| 103 |
+
}
|
| 104 |
|
| 105 |
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
| 106 |
pdst4[i00] /= sum;
|
ggml/src/ggml-opencl/kernels/softmax_f16.cl
CHANGED
|
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
|
|
|
|
|
|
| 29 |
global char * dst,
|
| 30 |
ulong offsetd,
|
| 31 |
int ne00,
|
|
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
|
|
| 48 |
) {
|
| 49 |
src0 = src0 + offset0;
|
| 50 |
src1 = src1 + offset1;
|
|
|
|
| 51 |
dst = dst + offsetd;
|
| 52 |
|
| 53 |
int i03 = get_group_id(2);
|
|
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
|
|
| 60 |
|
| 61 |
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 62 |
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
|
|
|
| 63 |
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 64 |
|
| 65 |
float slope = 1.0f;
|
|
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
|
|
| 75 |
}
|
| 76 |
|
| 77 |
// parallel max
|
| 78 |
-
float lmax = -INFINITY;
|
| 79 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 80 |
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 81 |
}
|
|
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
|
|
| 91 |
pdst[i00] = exp_psrc0;
|
| 92 |
}
|
| 93 |
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 97 |
pdst[i00] /= sum;
|
|
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
| 29 |
+
global char * src2,
|
| 30 |
+
ulong offset2,
|
| 31 |
global char * dst,
|
| 32 |
ulong offsetd,
|
| 33 |
int ne00,
|
|
|
|
| 50 |
) {
|
| 51 |
src0 = src0 + offset0;
|
| 52 |
src1 = src1 + offset1;
|
| 53 |
+
src2 = src2 + offset2;
|
| 54 |
dst = dst + offsetd;
|
| 55 |
|
| 56 |
int i03 = get_group_id(2);
|
|
|
|
| 63 |
|
| 64 |
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 65 |
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
| 66 |
+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
| 67 |
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 68 |
|
| 69 |
float slope = 1.0f;
|
|
|
|
| 79 |
}
|
| 80 |
|
| 81 |
// parallel max
|
| 82 |
+
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
|
| 83 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 84 |
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 85 |
}
|
|
|
|
| 95 |
pdst[i00] = exp_psrc0;
|
| 96 |
}
|
| 97 |
|
| 98 |
+
float sum = sub_group_reduce_add(lsum);
|
| 99 |
+
|
| 100 |
+
if (psrc2) {
|
| 101 |
+
sum += exp(psrc2[i02] - max);
|
| 102 |
+
}
|
| 103 |
|
| 104 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 105 |
pdst[i00] /= sum;
|
ggml/src/ggml-opencl/kernels/softmax_f32.cl
CHANGED
|
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
|
|
|
|
|
|
| 29 |
global char * dst,
|
| 30 |
ulong offsetd,
|
| 31 |
int ne00,
|
|
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
|
|
| 48 |
) {
|
| 49 |
src0 = src0 + offset0;
|
| 50 |
src1 = src1 + offset1;
|
|
|
|
| 51 |
dst = dst + offsetd;
|
| 52 |
|
| 53 |
int i03 = get_group_id(2);
|
|
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
|
|
| 60 |
|
| 61 |
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 62 |
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
|
|
|
| 63 |
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 64 |
|
| 65 |
float slope = 1.0f;
|
|
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
|
|
| 75 |
}
|
| 76 |
|
| 77 |
// parallel max
|
| 78 |
-
float lmax = -INFINITY;
|
| 79 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 80 |
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 81 |
}
|
|
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
|
|
| 91 |
pdst[i00] = exp_psrc0;
|
| 92 |
}
|
| 93 |
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 97 |
pdst[i00] /= sum;
|
|
|
|
| 26 |
ulong offset0,
|
| 27 |
global char * src1,
|
| 28 |
ulong offset1,
|
| 29 |
+
global char * src2,
|
| 30 |
+
ulong offset2,
|
| 31 |
global char * dst,
|
| 32 |
ulong offsetd,
|
| 33 |
int ne00,
|
|
|
|
| 50 |
) {
|
| 51 |
src0 = src0 + offset0;
|
| 52 |
src1 = src1 + offset1;
|
| 53 |
+
src2 = src2 + offset2;
|
| 54 |
dst = dst + offsetd;
|
| 55 |
|
| 56 |
int i03 = get_group_id(2);
|
|
|
|
| 63 |
|
| 64 |
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 65 |
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
| 66 |
+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
| 67 |
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
| 68 |
|
| 69 |
float slope = 1.0f;
|
|
|
|
| 79 |
}
|
| 80 |
|
| 81 |
// parallel max
|
| 82 |
+
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
|
| 83 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 84 |
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 85 |
}
|
|
|
|
| 95 |
pdst[i00] = exp_psrc0;
|
| 96 |
}
|
| 97 |
|
| 98 |
+
float sum = sub_group_reduce_add(lsum);
|
| 99 |
+
|
| 100 |
+
if (psrc2) {
|
| 101 |
+
sum += exp(psrc2[i02] - max);
|
| 102 |
+
}
|
| 103 |
|
| 104 |
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
| 105 |
pdst[i00] /= sum;
|