Spaces:
Running
Running
| REQD_SUBGROUP_SIZE_128 | |
| __kernel void mul_mat_f16_f32( | |
| const int M, const int N, const int K, | |
| __global const void* A_void, ulong A_offset, | |
| __global const void* B_void, ulong B_offset, | |
| __global void* C_void, ulong C_offset) { | |
| __global const half* A = (__global const half* )((__global const char*)A_void + A_offset); | |
| __global const float* B = (__global const float*)((__global const char*)B_void + B_offset); | |
| __global float* C = (__global float*)((__global char*)C_void + C_offset); | |
| const int lidm = get_local_id(0); | |
| const int lidn = get_local_id(1); | |
| const int lid = lidn * WG_M + lidm; | |
| const int offsetM = get_group_id(0) * OPWM; | |
| const int offsetN = get_group_id(1) * OPWN; | |
| __local half4 Alocal[OPWM][VEC_K]; | |
| __local float4 Blocal[OPWN][VEC_K]; | |
| float sum[OPTM][OPTN]; | |
| for (int wm = 0; wm < OPTM; wm++) { | |
| for (int wn = 0; wn < OPTN; wn++) { | |
| sum[wm][wn] = 0.0f; | |
| } | |
| } | |
| const int numTiles = (K + CPWK - 1) / CPWK; | |
| const int load_row_a = lid % OPWM; | |
| const int load_vec_k_a = lid / OPWM; | |
| const int global_row_a = offsetM + load_row_a; | |
| const int load_row_b = lid % OPWN; | |
| const int load_vec_k_b = lid / OPWN; | |
| const int global_row_b = offsetN + load_row_b; | |
| for (int t = 0; t < numTiles; t++) { | |
| const int k_start = t * CPWK; | |
| const int k_vec_start_a = k_start + load_vec_k_a * 4; | |
| const int k_vec_start_b = k_start + load_vec_k_b * 4; | |
| if (global_row_a < M && k_vec_start_a < K) { | |
| if (k_vec_start_a + 3 < K) { | |
| Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a); | |
| } else { | |
| half4 tempA = (half4)(0.0h); | |
| if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a]; | |
| if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1]; | |
| if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2]; | |
| Alocal[load_row_a][load_vec_k_a] = tempA; | |
| } | |
| } else { | |
| Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h); | |
| } | |
| if (global_row_b < N && k_vec_start_b < K) { | |
| if (k_vec_start_b + 3 < K) { | |
| Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b); | |
| } else { | |
| float4 tempB = (float4)(0.0f); | |
| if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b]; | |
| if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1]; | |
| if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2]; | |
| Blocal[load_row_b][load_vec_k_b] = tempB; | |
| } | |
| } else { | |
| Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f); | |
| } | |
| barrier(CLK_LOCAL_MEM_FENCE); | |
| for (int k_vec = 0; k_vec < VEC_K; k_vec++) { | |
| float4 a_fvecs[OPTM]; | |
| int current_row_a = lidm; | |
| for (int wm = 0; wm < OPTM; wm++) { | |
| a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]); | |
| current_row_a += WG_M; | |
| } | |
| float4 b_fvecs[OPTN]; | |
| int current_row_b = lidn; | |
| for (int wn = 0; wn < OPTN; wn++) { | |
| b_fvecs[wn] = Blocal[current_row_b][k_vec]; | |
| current_row_b += WG_N; | |
| } | |
| for (int wm = 0; wm < OPTM; wm++) { | |
| for (int wn = 0; wn < OPTN; wn++) { | |
| sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]); | |
| } | |
| } | |
| } | |
| barrier(CLK_LOCAL_MEM_FENCE); | |
| } | |
| for (int wm = 0; wm < OPTM; wm++) { | |
| int globalRow = offsetM + lidm + wm * WG_M; | |
| if (globalRow < M) { | |
| for (int wn = 0; wn < OPTN; wn++) { | |
| int globalCol = offsetN + lidn + wn * WG_N; | |
| if (globalCol < N) { | |
| C[globalCol * M + globalRow] = sum[wm][wn]; | |
| } | |
| } | |
| } | |
| } | |
| } | |