lhez committed on
Commit
d8664e4
·
1 Parent(s): d7e9115

opencl: support sink in `soft_max` (attn sinks) (llama/15152)

Browse files
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2520
  case GGML_OP_CLAMP:
2521
  return op->src[0]->type == GGML_TYPE_F32;
2522
  case GGML_OP_SOFT_MAX:
2523
- // TODO: support attention sinks [TAG_ATTN_SINKS]
2524
- return op->src[2] == nullptr;
2525
  case GGML_OP_NORM:
2526
  case GGML_OP_RMS_NORM:
2527
  return true;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
6594
  GGML_ASSERT(src1->extra);
6595
  }
6596
 
 
 
 
 
 
6597
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6598
 
6599
  ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6600
  ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6601
 
6602
  ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
 
6603
 
6604
  cl_ulong offset0 = extra0->offset + src0->view_offs;
6605
  cl_ulong offsetd = extrad->offset + dst->view_offs;
6606
 
6607
  cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
 
6608
 
6609
  const int ne00 = src0->ne[0];
6610
  const int ne01 = src0->ne[1];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
6672
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6673
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
6674
  CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6675
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6676
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6677
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
6678
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
6679
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
6680
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
6681
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
6682
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
6683
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
6684
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
6685
- CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
6686
- CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
6687
- CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
6688
- CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
6689
- CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
6690
- CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
6691
- CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
6692
- CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
6693
- CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
 
 
6694
 
6695
  size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
6696
  size_t local_work_size[] = {(size_t)nth, 1, 1};
 
2520
  case GGML_OP_CLAMP:
2521
  return op->src[0]->type == GGML_TYPE_F32;
2522
  case GGML_OP_SOFT_MAX:
 
 
2523
  case GGML_OP_NORM:
2524
  case GGML_OP_RMS_NORM:
2525
  return true;
 
6592
  GGML_ASSERT(src1->extra);
6593
  }
6594
 
6595
+ const ggml_tensor * src2 = dst->src[2];
6596
+ if (src2) {
6597
+ GGML_ASSERT(src2->extra);
6598
+ }
6599
+
6600
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6601
 
6602
  ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6603
  ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6604
 
6605
  ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
6606
+ ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
6607
 
6608
  cl_ulong offset0 = extra0->offset + src0->view_offs;
6609
  cl_ulong offsetd = extrad->offset + dst->view_offs;
6610
 
6611
  cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
6612
+ cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
6613
 
6614
  const int ne00 = src0->ne[0];
6615
  const int ne01 = src0->ne[1];
 
6677
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6678
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
6679
  CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6680
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
6681
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
6682
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
6683
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
6684
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
6685
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
6686
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
6687
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
6688
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
6689
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
6690
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
6691
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
6692
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
6693
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
6694
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
6695
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
6696
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
6697
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
6698
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
6699
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
6700
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
6701
 
6702
  size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
6703
  size_t local_work_size[] = {(size_t)nth, 1, 1};
ggml/src/ggml-opencl/kernels/softmax_4_f16.cl CHANGED
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
 
 
29
  global char * dst,
30
  ulong offsetd,
31
  int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
48
  ) {
49
  src0 = src0 + offset0;
50
  src1 = src1 + offset1;
 
51
  dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
60
 
61
  global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
  global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
 
63
  global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
75
  }
76
 
77
  // parallel max
78
- float4 lmax4 = -INFINITY;
79
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
80
  lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
81
  }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
92
  }
93
  float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
94
 
95
- const float sum = sub_group_reduce_add(lsum);
 
 
 
 
96
 
97
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
98
  pdst4[i00] /= sum;
 
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
29
+ global char * src2,
30
+ ulong offset2,
31
  global char * dst,
32
  ulong offsetd,
33
  int ne00,
 
50
  ) {
51
  src0 = src0 + offset0;
52
  src1 = src1 + offset1;
53
+ src2 = src2 + offset2;
54
  dst = dst + offsetd;
55
 
56
  int i03 = get_group_id(2);
 
63
 
64
  global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65
  global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66
+ global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
67
  global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
68
 
69
  float slope = 1.0f;
 
79
  }
80
 
81
  // parallel max
82
+ float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
83
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
84
  lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
85
  }
 
96
  }
97
  float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
98
 
99
+ float sum = sub_group_reduce_add(lsum);
100
+
101
+ if (psrc2) {
102
+ sum += exp(psrc2[i02] - max);
103
+ }
104
 
105
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
106
  pdst4[i00] /= sum;
ggml/src/ggml-opencl/kernels/softmax_4_f32.cl CHANGED
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
 
 
29
  global char * dst,
30
  ulong offsetd,
31
  int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
48
  ) {
49
  src0 = src0 + offset0;
50
  src1 = src1 + offset1;
 
51
  dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
60
 
61
  global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
  global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
 
63
  global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
75
  }
76
 
77
  // parallel max
78
- float4 lmax4 = -INFINITY;
79
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
80
  lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
81
  }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
92
  }
93
  float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
94
 
95
- const float sum = sub_group_reduce_add(lsum);
 
 
 
 
96
 
97
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
98
  pdst4[i00] /= sum;
 
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
29
+ global char * src2,
30
+ ulong offset2,
31
  global char * dst,
32
  ulong offsetd,
33
  int ne00,
 
50
  ) {
51
  src0 = src0 + offset0;
52
  src1 = src1 + offset1;
53
+ src2 = src2 + offset2;
54
  dst = dst + offsetd;
55
 
56
  int i03 = get_group_id(2);
 
63
 
64
  global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65
  global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66
+ global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
67
  global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
68
 
69
  float slope = 1.0f;
 
79
  }
80
 
81
  // parallel max
82
+ float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
83
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
84
  lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
85
  }
 
96
  }
97
  float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
98
 
99
+ float sum = sub_group_reduce_add(lsum);
100
+
101
+ if (psrc2) {
102
+ sum += exp(psrc2[i02] - max);
103
+ }
104
 
105
  for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
106
  pdst4[i00] /= sum;
ggml/src/ggml-opencl/kernels/softmax_f16.cl CHANGED
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
 
 
29
  global char * dst,
30
  ulong offsetd,
31
  int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
48
  ) {
49
  src0 = src0 + offset0;
50
  src1 = src1 + offset1;
 
51
  dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
60
 
61
  global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
  global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
 
63
  global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
75
  }
76
 
77
  // parallel max
78
- float lmax = -INFINITY;
79
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
80
  lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
81
  }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
91
  pdst[i00] = exp_psrc0;
92
  }
93
 
94
- const float sum = sub_group_reduce_add(lsum);
 
 
 
 
95
 
96
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
97
  pdst[i00] /= sum;
 
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
29
+ global char * src2,
30
+ ulong offset2,
31
  global char * dst,
32
  ulong offsetd,
33
  int ne00,
 
50
  ) {
51
  src0 = src0 + offset0;
52
  src1 = src1 + offset1;
53
+ src2 = src2 + offset2;
54
  dst = dst + offsetd;
55
 
56
  int i03 = get_group_id(2);
 
63
 
64
  global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65
  global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66
+ global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
67
  global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
68
 
69
  float slope = 1.0f;
 
79
  }
80
 
81
  // parallel max
82
+ float lmax = psrc2 ? psrc2[i02] : -INFINITY;
83
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
84
  lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
85
  }
 
95
  pdst[i00] = exp_psrc0;
96
  }
97
 
98
+ float sum = sub_group_reduce_add(lsum);
99
+
100
+ if (psrc2) {
101
+ sum += exp(psrc2[i02] - max);
102
+ }
103
 
104
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
105
  pdst[i00] /= sum;
ggml/src/ggml-opencl/kernels/softmax_f32.cl CHANGED
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
 
 
29
  global char * dst,
30
  ulong offsetd,
31
  int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
48
  ) {
49
  src0 = src0 + offset0;
50
  src1 = src1 + offset1;
 
51
  dst = dst + offsetd;
52
 
53
  int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
60
 
61
  global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62
  global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
 
63
  global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
64
 
65
  float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
75
  }
76
 
77
  // parallel max
78
- float lmax = -INFINITY;
79
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
80
  lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
81
  }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
91
  pdst[i00] = exp_psrc0;
92
  }
93
 
94
- const float sum = sub_group_reduce_add(lsum);
 
 
 
 
95
 
96
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
97
  pdst[i00] /= sum;
 
26
  ulong offset0,
27
  global char * src1,
28
  ulong offset1,
29
+ global char * src2,
30
+ ulong offset2,
31
  global char * dst,
32
  ulong offsetd,
33
  int ne00,
 
50
  ) {
51
  src0 = src0 + offset0;
52
  src1 = src1 + offset1;
53
+ src2 = src2 + offset2;
54
  dst = dst + offsetd;
55
 
56
  int i03 = get_group_id(2);
 
63
 
64
  global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65
  global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66
+ global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
67
  global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
68
 
69
  float slope = 1.0f;
 
79
  }
80
 
81
  // parallel max
82
+ float lmax = psrc2 ? psrc2[i02] : -INFINITY;
83
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
84
  lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
85
  }
 
95
  pdst[i00] = exp_psrc0;
96
  }
97
 
98
+ float sum = sub_group_reduce_add(lsum);
99
+
100
+ if (psrc2) {
101
+ sum += exp(psrc2[i02] - max);
102
+ }
103
 
104
  for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
105
  pdst[i00] /= sum;