JohannesGaessler committed
Commit aef1b4b · 1 Parent(s): b7f6691

ggml: implement quantized KV cache for FA (llama/7372)

Files changed (1)
  1. ggml.c +71 -42
ggml.c CHANGED
@@ -15882,9 +15882,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
@@ -15938,6 +15939,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -15945,17 +15951,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
-        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
-        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-        memset(V16, 0, D*sizeof(ggml_fp16_t));
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -15967,6 +15978,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15990,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
+            float s; // KQ value
 
-            // convert Q to F16 in V32
-            {
-                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-                for (int64_t d = 0; d < D; ++d) {
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
+            s = s*scale + mv; // scale KQ value and apply mask
 
-            ggml_vec_dot_f16(D,
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
+            const float Mold = M;
 
-            s = s*scale + mv;
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-            const float Mold = M;
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-            float ms = 1.0f;
-            float vs = 1.0f;
+            if (v->type== GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-            if (s > M) {
-                M = s;
-                ms = expf(Mold - M);
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-                // V = V*expf(Mold - M)
-                ggml_vec_scale_f16(D, V16, ms);
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-                vs = expf(s - M);
-            }
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-            // V += v*expf(s - M)
-            ggml_vec_mad_f16(D, V16, v16, vs);
+                v_to_float(v_data, V32, D);
 
-            S = S*ms + vs;
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-        // V /= S
-        for (int64_t d = 0; d < D; ++d) {
-            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
@@ -16031,7 +16060,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
@@ -19972,7 +20001,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
-                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_FF:
                 {
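For reference, the numerical core of the updated kernel is a one-pass (online) softmax: each K·Q score is folded into a running maximum M and running sum S, so the V accumulator never has to be revisited after the loop. Below is a minimal standalone sketch of that update rule in plain C; the function name, array layout, and FP32-only inputs are illustrative assumptions, not ggml's internal API (in the patch the dot product and the V dequantization instead go through the type_traits dispatch shown above).

// Minimal standalone sketch of the online-softmax accumulation (FP32 path).
// Plain float arrays stand in for the KQ scores and the dequantized V rows.
#include <math.h>
#include <stdio.h>

// out[0..D-1] = sum_i softmax(s)_i * v[i][0..D-1], computed in one pass.
static void online_softmax_accumulate(int n_kv, int D,
                                      const float *s,   // n_kv KQ scores (already scaled/masked)
                                      const float *v,   // n_kv x D value rows, row-major
                                      float *out) {     // D-element accumulator
    float S = 0.0f;      // running sum of exp(s_i - M)
    float M = -INFINITY; // running maximum score

    for (int d = 0; d < D; ++d) {
        out[d] = 0.0f;
    }

    for (int i = 0; i < n_kv; ++i) {
        float ms = 1.0f; // rescale factor applied when a new maximum appears
        float vs = 1.0f; // exp(s_i - M) for the current row

        if (s[i] > M) {
            ms = expf(M - s[i]);
            M  = s[i];
            for (int d = 0; d < D; ++d) {
                out[d] *= ms;        // V = V*expf(Mold - M)
            }
        } else {
            vs = expf(s[i] - M);
        }

        for (int d = 0; d < D; ++d) {
            out[d] += v[i*D + d]*vs; // V += v*expf(s - M)
        }

        S = S*ms + vs;               // running softmax denominator
    }

    for (int d = 0; d < D; ++d) {
        out[d] /= S;                 // V /= S
    }
}

int main(void) {
    const float s[3]   = {0.5f, 2.0f, 1.0f};
    const float v[3*2] = {1.0f, 0.0f,  0.0f, 1.0f,  1.0f, 1.0f};
    float out[2];
    online_softmax_accumulate(3, 2, s, v, out);
    printf("%f %f\n", out[0], out[1]); // softmax-weighted average of the V rows
    return 0;
}

Compiled with e.g. `cc -O2 online_softmax.c -lm` (hypothetical file name), the program prints the softmax-weighted average of the V rows, matching what a conventional two-pass softmax would produce.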