Commit · aef1b4b
1 Parent(s): b7f6691
ggml: implement quantized KV cache for FA (llama/7372)
ggml.c CHANGED
@@ -15882,9 +15882,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
-    GGML_ASSERT(
-    GGML_ASSERT(
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
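Note on the contiguity asserts above: with a quantized K or V cache, the byte size of one step along dim 0 depends on the tensor type (a quantized block covers several values), so the kernel now only requires that each input row be laid out contiguously, i.e. that the first stride equal ggml_type_size of the tensor's type. A small illustrative sketch (not part of the commit) of the resulting per-row byte count, using the public ggml helpers and assuming ne0 is a multiple of the block size:

    #include "ggml.h"

    // bytes occupied by one contiguous row of ne0 elements of the given type
    // (illustrative; assumes ne0 % ggml_blck_size(type) == 0, as ggml requires)
    static size_t row_bytes(enum ggml_type type, int64_t ne0) {
        return (size_t)(ne0/ggml_blck_size(type))*ggml_type_size(type);
    }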
@@ -15938,6 +15939,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
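The four function pointers above are what make the kernel type-generic: q_to_vec_dot converts the FP32 Q row into whatever representation the K type's dot-product kernel expects (its vec_dot_type), kq_vec_dot computes the K·Q dot product in that representation, and v_to_float dequantizes one V row to FP32 for accumulation. For an F16 K cache this reduces to the previous behaviour (Q converted to FP16, FP16 dot product); for a quantized K cache such as Q8_0, Q is quantized once per row and the dot product runs block-wise in the quantized domain. A minimal sketch of the dispatch, assuming visibility of the type_traits table that is internal to ggml.c:

    // Illustrative only: produce one KQ score for an arbitrary K type, mirroring
    // the calls introduced by this commit (q_conv must be large enough to hold
    // D elements of the K type's vec_dot_type).
    static float one_kq_score(const struct ggml_tensor * k, const float * q_row,
                              void * q_conv, int64_t D, int64_t row) {
        const ggml_from_float_t to_dot = type_traits[type_traits[k->type].vec_dot_type].from_float;
        const ggml_vec_dot_t    dot    = type_traits[k->type].vec_dot;

        to_dot(q_row, q_conv, D);                              // convert Q once
        float s = 0.0f;
        const char * k_row = (const char *) k->data + row*k->nb[1];
        dot(D, &s, 0, k_row, 0, q_conv, 0, 1);                 // quantized/F16 dot product
        return s;
    }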
@@ -15945,17 +15951,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float       *
-        ggml_fp16_t *
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
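The per-thread scratch set up above lives in params->wdata and is split into three D-sized regions, with CACHE_LINE_SIZE_F32 floats of padding so threads do not share cache lines. A minimal sketch of the same layout (assumed helper, not in ggml):

    #include "ggml.h"

    // Per-thread scratch layout used by the kernel above (illustrative).
    static void scratch_layout(float * wdata, int ith, int64_t D, int64_t pad_f32,
                               float ** VKQ32, float ** V32,
                               ggml_fp16_t ** VKQ16, ggml_fp16_t ** Q_q) {
        float * base = wdata + ith*(3*D + pad_f32);
        *VKQ32 = base;                        // [0,  D): FP32 accumulator for softmax(QK^T)·V
        *V32   = base + 1*D;                  // [D, 2D): one dequantized V row (FP32 path)
        *VKQ16 = (ggml_fp16_t *)(base + 1*D); // F16 accumulator, reusing the V32 region
        *Q_q   = (ggml_fp16_t *)(base + 2*D); // [2D,3D): Q row converted to K's vec_dot_type
    }

On the F16 path the FP32 V buffer is not needed, so VKQ16 can reuse that region; the memset above zeroes whichever accumulator will actually be used.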
@@ -15967,6 +15978,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
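For reference, the loop below maintains the running quantities of the streaming ("online") softmax from the paper linked above. In the code's own names (M = running maximum, S = running denominator, VKQ = running numerator, s_i = i-th KQ score, v_i = i-th V row), the update can be sketched in LaTeX (notation mine) as:

    M_i   = \max(M_{i-1}, s_i), \qquad ms_i = e^{M_{i-1} - M_i}, \qquad vs_i = e^{s_i - M_i}
    S_i   = S_{i-1}\, ms_i + vs_i
    VKQ_i = VKQ_{i-1}\, ms_i + vs_i\, v_i

so that after the last key/value pair, VKQ_n / S_n = \sum_i \operatorname{softmax}(s)_i\, v_i, which is exactly the "V /= S" normalization performed at the end of the hunk below.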
@@ -15976,52 +15990,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
-
-            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-                Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-            }
-            }
-
-                &s, 0,
-                (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                Q16, 0, 1);
-
-            const
-
-            // V
-            ggml_vec_mad_f16(D, V16, v16, vs);
+            float s; // KQ value
+
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+
+            s = s*scale + mv; // scale KQ value and apply mask
+
+            const float Mold = M;
+
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                v_to_float(v_data, V32, D);
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
         }
 
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
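A standalone toy (not ggml code) showing that the streaming update above reproduces an ordinary two-pass softmax-weighted sum; the scores and values below are made-up example numbers:

    /* Toy check of the streaming-softmax accumulation, on scalar "values". */
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float s[4] = {0.5f, 2.0f, -1.0f, 3.0f};   // example KQ scores (assumed)
        const float v[4] = {1.0f, 2.0f,  3.0f, 4.0f};   // example values (assumed)

        // streaming pass: same roles as S, M, VKQ, ms, vs in the kernel
        float S = 0.0f, M = -INFINITY, VKQ = 0.0f;
        for (int i = 0; i < 4; ++i) {
            float ms = 1.0f, vs = 1.0f;
            if (s[i] > M) { ms = expf(M - s[i]); M = s[i]; VKQ *= ms; }
            else          { vs = expf(s[i] - M); }
            VKQ += vs*v[i];
            S    = S*ms + vs;
        }
        VKQ /= S; // "V /= S", as in the kernel

        // reference: ordinary two-pass softmax
        float max = s[0], sum = 0.0f, ref = 0.0f;
        for (int i = 1; i < 4; ++i) if (s[i] > max) max = s[i];
        for (int i = 0; i < 4; ++i) sum += expf(s[i] - max);
        for (int i = 0; i < 4; ++i) ref += expf(s[i] - max)/sum * v[i];

        printf("streaming = %f, two-pass = %f\n", VKQ, ref); // should agree to FP32 precision
        return 0;
    }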
@@ -16031,7 +16060,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1,
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
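The single memcpy above writes the finished D-float output row (nb1 bytes correspond to one full row of VKQ32). The row index (i3*ne2*ne1 + i2 + i1*ne1) scatters the result so that it lands directly in the layout the permute(0, 2, 1, 3) comment refers to, without a separate permute afterwards. A hypothetical helper (names assumed) restating that offset:

    // Illustrative: destination byte offset of the output row for query position i1,
    // head i2, batch i3, assuming ne1 counts heads and ne2 query rows for dst here.
    static size_t dst_row_offset(int64_t i1, int64_t i2, int64_t i3,
                                 int64_t ne1, int64_t ne2, size_t nb1) {
        return (size_t)(i3*ne2*ne1 + i1*ne1 + i2)*nb1;
    }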
@@ -19972,7 +20001,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
-                    cur =
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_FF:
                 {
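The work-buffer estimate matches the kernel's scratch layout: each of the n_tasks threads needs three head-sized regions (VKQ32, V32/VKQ16, and Q_q). D elements converted to the K vec_dot_type or to FP16 never occupy more space than D floats, so 3*sizeof(float)*ne00 per thread is a safe upper bound (the cache-line padding is presumably accounted for elsewhere in the plan). In short:

    // Illustrative restatement of the sizing above: ne00 is the head size D.
    size_t cur = 3*sizeof(float)*ne00*n_tasks; // VKQ32 + (V32 or VKQ16) + Q_q, per thread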