Spaces:
Running
Running
| REQD_SUBGROUP_SIZE_64 | |
| kernel void kernel_mul_mat_f16_f16( | |
| global char * src0, | |
| ulong offset0, | |
| global char * src1, | |
| ulong offset1, | |
| global float * dst, | |
| ulong offsetd, | |
| int ne00, | |
| int ne01, | |
| int ne02, | |
| ulong nb00, | |
| ulong nb01, | |
| ulong nb02, | |
| ulong nb03, | |
| int ne10, | |
| int ne11, | |
| int ne12, | |
| ulong nb10, | |
| ulong nb11, | |
| ulong nb12, | |
| ulong nb13, | |
| int ne0, | |
| int ne1, | |
| int r2, | |
| int r3) | |
| { | |
| src0 = (global char*)((global char*)src0 + offset0); | |
| src1 = (global char*)((global char*)src1 + offset1); | |
| dst = (global float*)((global char*)dst + offsetd); | |
| int r0 = get_group_id(0); | |
| int rb = get_group_id(1)*N_F16_F16; | |
| int im = get_group_id(2); | |
| int i12 = im%ne12; | |
| int i13 = im/ne12; | |
| ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; | |
| global half * x = (global half *) (src0 + offset_src0); | |
| if (ne00 < 128) { | |
| for (int row = 0; row < N_F16_F16; ++row) { | |
| int r1 = rb + row; | |
| if (r1 >= ne11) { | |
| break; | |
| } | |
| ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; | |
| global half * y = (global half *) (src1 + offset_src1); | |
| float sumf = 0; | |
| for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { | |
| sumf += (half) x[i] * (half) y[i]; | |
| } | |
| float all_sum = sub_group_reduce_add(sumf); | |
| if (get_sub_group_local_id() == 0) { | |
| dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | |
| } | |
| } | |
| } else { | |
| global half4 * x4 = (global half4 *)x; | |
| for (int row = 0; row < N_F16_F16; ++row) { | |
| int r1 = rb + row; | |
| if (r1 >= ne11) { | |
| break; | |
| } | |
| ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; | |
| global half * y = (global half *) (src1 + offset_src1); | |
| global half4 * y4 = (global half4 *) y; | |
| float sumf = 0; | |
| for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { | |
| sumf += (half) x4[i].s0 * y4[i].s0; | |
| sumf += (half) x4[i].s1 * y4[i].s1; | |
| sumf += (half) x4[i].s2 * y4[i].s2; | |
| sumf += (half) x4[i].s3 * y4[i].s3; | |
| } | |
| float all_sum = sub_group_reduce_add(sumf); | |
| if (get_sub_group_local_id() == 0) { | |
| for (int i = 4*(ne00/4); i < ne00; ++i) { | |
| all_sum += (half) x[i] * y[i]; | |
| } | |
| dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | |
| } | |
| } | |
| } | |
| } | |