Spaces:
Running
Running
// Workgroup size: x comes from specialization constant 0 (WorkGroupSize),
// y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Specialization constants. The initializers are defaults only; the host can
// override them per pipeline through the constant_id values.
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br            = 1;   // row block size (Tr = CEIL_DIV(N, Br))
layout (constant_id = 2) const uint32_t Bc            = 32;  // KV column block size
layout (constant_id = 3) const uint32_t HSK           = 32;  // presumably the K head size — confirm
layout (constant_id = 4) const uint32_t HSV           = 32;  // presumably the V head size — confirm
layout (constant_id = 5) const uint32_t Clamp         = 0;
layout (constant_id = 6) const uint32_t D_split       = 16;
// Push constants for the dispatch, accessed as `p`.
// The member order and types define the block's memory layout, which
// presumably mirrors a host-side struct — never reorder or retype members.
layout (push_constant) uniform parameter {
    // Row count (tiled by Br) and K/V length.
    uint32_t N;
    uint32_t KV;

    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;

    // Q / K / V extents in dims 2 and 3; their quotients are the broadcast ratios.
    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;

    uint32_t nem1;
    uint32_t nem2;
    uint32_t nem3;

    // Q (nb0x), K (nb1x) and V (nb2x) strides. nb?1 are already in elements;
    // nb02 is in bytes.
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;

    float scale;
    float max_bias;
    float logit_softcap;

    // n_head_log2 lives in the bits selected by N_LOG2_MASK; m0 / m1 are the
    // per-head slope bases.
    uint32_t mask_n_head_log2;
    float m0;
    float m1;

    // Number of consecutive Q heads handled by one workgroup.
    uint32_t gqa_ratio;

    // split_k: KV elements per split; k_num > 1 enables split mode.
    uint32_t split_kv;
    uint32_t k_num;
} p;
// Packed 16-bit quantized K/V blocks. dequantize4 picks the array element via
// its binding_idx argument (presumably 0 = K, 1 = V — confirm at call sites).
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];

// Per-head sink values, read by perElemOpGetSink.
layout (binding = 4) readonly buffer S {float data_s[];};

// Output buffer; perElemOpStoreCol0 also writes per-row values here.
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
// dequantize4: return four consecutive dequantized K/V values, starting at
// element iqs of block (a_offset + ib) in kv_packed[binding_idx].
//
// Both implementations below have the same signature. Exactly one may be
// compiled per shader variant; without the guards the second definition is a
// redefinition error.
// NOTE(review): assumes the build defines DATA_A_Q4_0 / DATA_A_Q8_0 for the
// matching KV quantization type — confirm against the shader generator.
#if defined(DATA_A_Q4_0)
// 4-bit values. (iqs & 0xF) / 2 selects the pair of 16-bit qs words holding
// the four values. Bit 4 of iqs switches from the low nibble of each byte to
// the high nibble. Result is d * (q - 8).
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk   = a_offset + ib;
    const uint qidx  = (iqs & 0xF) / 2;
    const uint shift = (iqs & 0x10) >> 2;   // 0 (low nibble) or 4 (high nibble)

    const uint vui_lo = uint(kv_packed[binding_idx].data_packed16[blk].qs[qidx + 0]) >> shift;
    const uint vui_hi = uint(kv_packed[binding_idx].data_packed16[blk].qs[qidx + 1]) >> shift;

    return float(kv_packed[binding_idx].data_packed16[blk].d) *
           (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#elif defined(DATA_A_Q8_0)
// 8-bit signed values. The four bytes come from qs words iqs/2 and iqs/2 + 1
// and are scaled by d.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk = a_offset + ib;

    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[blk].qs[iqs / 2])).xy; // vec4 used due to #12147
    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[blk].qs[iqs / 2 + 1])).xy;

    return float(kv_packed[binding_idx].data_packed16[blk].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
// Per-element op that writes column-zero elements of valid rows (r < N) to
// data_o[o_offset + iq2 + r]. split_k uses this to save the per-row m and L
// values. The element is always returned unchanged.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    const bool in_col0_of_valid_row = (c == 0) && (r < N);
    if (in_col0_of_valid_row) {
        data_o[o_offset + iq2 + r] = D_TYPE(elem);
    }
    return elem;
}
// Per-element op returning the slope for the head that owns row r, indexed by
// Q's dimension 2: head = iq2 + r % gqa_ratio.
// - Heads below n_head_log2 use m0 raised to (head + 1).
// - All other heads use m1 raised to (2 * (head - n_head_log2) + 1).
// c and elem are unused (presumably kept to match the per-element op shape).
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    const uint32_t head        = iq2 + (r % p.gqa_ratio);
    const uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;

    ACC_TYPE base;
    int exponent;
    if (head < n_head_log2) {
        base     = ACC_TYPE(p.m0);
        exponent = int(head + 1);
    } else {
        base     = ACC_TYPE(p.m1);
        exponent = int(2*(head - n_head_log2) + 1);
    }
    return ACC_TYPE(pow(base, ACC_TYPE(exponent)));
}
// Per-element op returning data_s for the head that owns row r
// (head = iq2 + r % gqa_ratio, i.e. indexed by Q's dimension 2).
// c and elem are unused.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    return ACC_TYPE(data_s[iq2 + (r % p.gqa_ratio)]);
}
// Per-invocation indexing state. init_indices() assigns every variable below.
uint32_t i, split_k_index;                        // from workgroup x: row-block index or KV split index
uint32_t N, KV;                                   // copies of p.N and p.KV
uint32_t Tr, start_j, end_j;                      // row-block count; this split's KV column-block bounds
uint32_t iq2, iq3;                                // Q indices in dims 2 and 3
uint32_t rk2, rk3, rv2, rv3;                      // Q-over-K and Q-over-V broadcast ratios
uint32_t ik2, ik3, iv2, iv3;                      // K and V indices derived from iq2 / iq3
uint32_t q_stride, k_stride, v_stride, m_stride;  // row strides for Q, K, V and the mask
// Fill the per-invocation index globals from the push constants and the
// workgroup ID:
// - x: the row-block index i, or the KV split index when k_num > 1;
// - y: the first of gqa_ratio consecutive Q heads;
// - z: iq3.
void init_indices()
{
    N  = p.N;
    KV = p.KV;

    // With split_k, workgroup x selects the KV split and i stays 0.
    const bool use_split_k = p.k_num > 1;
    i             = use_split_k ? 0u : gl_WorkGroupID.x;
    split_k_index = use_split_k ? gl_WorkGroupID.x : 0u;

    // Row blocks of Br rows.
    Tr = CEIL_DIV(N, Br);

    // KV column-block bounds for this split: start rounds down, end rounds up
    // and is clamped to KV.
    start_j = split_k_index * p.split_kv / Bc;
    end_j   = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

    // Without grouped query attention every row shares iq2 == gl_WorkGroupID.y.
    // With it, each workgroup covers gqa_ratio consecutive iq2 values starting here.
    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    iq3 = gl_WorkGroupID.z;

    // Broadcast factors of Q over K and V...
    rk2 = p.neq2 / p.nek2;
    rk3 = p.neq3 / p.nek3;
    rv2 = p.neq2 / p.nev2;
    rv3 = p.neq3 / p.nev3;

    // ...and the K / V indices they map iq2 / iq3 to.
    ik2 = iq2 / rk2;
    ik3 = iq3 / rk3;
    iv2 = iq2 / rv2;
    iv3 = iq3 / rv3;

    // nb?1 are already divided by the type size (units of elements). With
    // grouped query attention Q is indexed by iq2, so the stride is nb02,
    // which is in bytes and is divided by 4 here.
    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
    k_stride = p.nb11;
    v_stride = p.nb21;

    // With grouped query attention all rows use the same mask (stride 0).
    // "p.gqa_ratio >> 16" is a roundabout zero (for any gqa_ratio < 65536).
    // It stops the compiler from folding the "&" through the select, which
    // would break alignment detection. Keep this exact form.
    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}