Spaces:
Running
Running
//------------------------------------------------------------------------------
// rms_norm
//------------------------------------------------------------------------------
// RMS-normalizes one row of src0 per work-group:
//
//     dst[i] = src0[i] / sqrt(sum(src0[k]^2, k in [0, ne00)) / ne00 + eps)
//
// The row is selected by the work-group id: get_group_id(0) -> i01,
// get_group_id(1) -> i02, get_group_id(2) -> i03.
//
// Parameters:
//   src0, offset0    source buffer and a byte offset applied to it
//   dst,  offsetd    destination buffer and a byte offset applied to it
//   ne00             number of floats in a row
//   ne01, ne02       extents used to compute the (densely packed) dst row index
//   ne03             not read by the kernel body
//   nb01, nb02, nb03 byte strides of src0 for dimensions 1, 2 and 3
//   eps              added to the mean square before the square root
//   sum              local scratch, one float per subgroup
//
// This kernel depends on subgroup size.
// NOTE(review): both subgroup-size attribute macros appear back to back here;
// presumably only one is meant to be active per target - confirm that the
// preprocessor guards around them were not lost.
REQD_SUBGROUP_SIZE_32
REQD_SUBGROUP_SIZE_64
kernel void kernel_rms_norm(
        global void * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        float eps,
        local float * sum // Note, the size depends on number of subgroups
) {
    src0 = (global void*)((global char*)src0 + offset0);
    dst = (global float*)((global char*)dst + offsetd);

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    // The same row is viewed as float4 for the vectorized body and as float
    // for the scalar tail (ne00 % 4 trailing elements).
    global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float * x_scalar = (global float *) x;

    float4 sumf = 0;
    float all_sum = 0;

    // Each work-item accumulates the squares of its strided share of the row.
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += x[i00] * x[i00];
    }
    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;

    // Reduce within each subgroup; lane 0 publishes the subgroup's partial sum.
    all_sum = sub_group_reduce_add(all_sum);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = all_sum;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree-reduce the per-subgroup partial sums into sum[0].
    // The trip count is uniform across the work-group, so the barrier inside
    // the loop is reached by every work-item. It is required because each step
    // reads values that a different work-item wrote in the previous step.
    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
        if (get_local_id(0) < i) {
            sum[get_local_id(0)] += sum[get_local_id(0) + i];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (get_local_id(0) == 0) {
        // Fold in the scalar tail. These must be squared like the vectorized
        // part, otherwise the mean square is wrong whenever ne00 % 4 != 0.
        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
            sum[0] += x_scalar[i] * x_scalar[i];
        }
        sum[0] /= ne00;
    }

    // Make the final mean square visible to every work-item.
    barrier(CLK_LOCAL_MEM_FENCE);

    const float mean = sum[0];
    const float scale = 1.0f/sqrt(mean + eps);

    // dst rows are addressed by element counts (ne0x), not byte strides.
    global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
    global float * y_scalar = (global float *) y;

    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        y[i00] = x[i00] * scale;
    }

    // Work-item 0 writes the scalar tail.
    if (get_local_id(0) == 0) {
        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
            y_scalar[i00] = x_scalar[i00] * scale;
        }
    }
}
//------------------------------------------------------------------------------
// rms_norm_mul
//------------------------------------------------------------------------------
// RMS-normalizes one row of src0 per work-group and multiplies the result
// element-wise by a row of src1:
//
//     dst[i] = src0[i] / sqrt(sum(src0[k]^2) / ne00 + eps) * src1[i % ne10]
//
// The row is selected by the work-group id: get_group_id(0) -> i01,
// get_group_id(1) -> i02, get_group_id(2) -> i03. The src1 row index is taken
// modulo ne11/ne12/ne13, and the src1 column index modulo ne10/4 (in float4
// units), so a smaller src1 is repeated across src0.
//
// Parameters:
//   src0, offset0    source buffer and a byte offset applied to it
//   src1, offset1    multiplier buffer and a byte offset applied to it
//   dst,  offsetd    destination buffer and a byte offset applied to it
//   ne00             number of floats in a src0 row
//   ne01, ne02, ne03 not read by the kernel body
//   nb01, nb02, nb03 byte strides of src0 for dimensions 1, 2 and 3
//   ne10..ne13       extents of src1 used for the modulo indexing
//   nb11, nb12, nb13 byte strides of src1 for dimensions 1, 2 and 3
//   nb1, nb2, nb3    byte strides of dst for dimensions 1, 2 and 3
//   eps              added to the mean square before the square root
//   sum              local scratch, one float per subgroup
//
// NOTE(review): only ne00/4 float4 elements are summed and written; there is
// no scalar tail here, so this presumably relies on the caller guaranteeing
// ne00 % 4 == 0 (and ne10 >= 4, since ne10/4 is used as a modulus) - confirm.
// NOTE(review): both subgroup-size attribute macros appear back to back here;
// presumably only one is meant to be active per target - confirm that the
// preprocessor guards around them were not lost.
REQD_SUBGROUP_SIZE_32
REQD_SUBGROUP_SIZE_64
kernel void kernel_rms_norm_mul(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float eps,
        local float * sum
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst = dst + offsetd;

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);

    float sumf = 0;

    // Each work-item accumulates the squares of its strided share of the row.
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += dot(x[i00], x[i00]);
    }

    // Reduce within each subgroup; lane 0 publishes the subgroup's partial sum.
    sumf = sub_group_reduce_add(sumf);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = sumf;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree-reduce the per-subgroup partial sums into sum[0].
    // The trip count is uniform across the work-group, so the barrier inside
    // the loop is reached by every work-item. It is required because each step
    // reads values that a different work-item wrote in the previous step.
    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
        if (get_local_id(0) < i) {
            sum[get_local_id(0)] += sum[get_local_id(0) + i];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (get_local_id(0) == 0) {
        sum[0] /= ne00;
    }

    // Make the final mean square visible to every work-item.
    barrier(CLK_LOCAL_MEM_FENCE);

    float mean = sum[0];
    float scale = 1.0f/sqrt(mean + eps);

    // dst rows are addressed with byte strides (nb1..nb3).
    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
    }
}