jeffbolznv committed
Commit d10b47b · 1 Parent(s): 91607b6

vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (llama/10206)

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -167,6 +167,7 @@ struct vk_device_struct {
167
  uint32_t subgroup_size;
168
  uint32_t shader_core_count;
169
  bool uma;
 
170
 
171
  size_t idx;
172
 
@@ -176,6 +177,7 @@ struct vk_device_struct {
176
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
177
  vk_pipeline pipeline_matmul_split_k_reduce;
178
 
 
179
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
180
 
181
  vk_matmul_pipeline pipeline_matmul_id_f32;
@@ -229,6 +231,14 @@ struct vk_device_struct {
229
  vk_pipeline pipeline_timestep_embedding_f32;
230
  vk_pipeline pipeline_pool2d_f32;
231
 
232
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
233
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
234
 
@@ -340,6 +350,40 @@ struct vk_mat_vec_id_push_constants {
340
  uint32_t nei0; uint32_t ne11;
341
  };
342
 
343
  struct vk_op_push_constants {
344
  uint32_t KX;
345
  uint32_t KY;
@@ -1265,6 +1309,23 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
1265
  );
1266
  }
1267
 
1268
  static void ggml_vk_load_shaders(vk_device& device) {
1269
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
1270
 
@@ -1275,59 +1336,98 @@ static void ggml_vk_load_shaders(vk_device& device) {
1275
 
1276
  // mulmat
1277
  std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
1278
- l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1279
  std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
1280
- l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
1281
- uint32_t l_align, m_align, s_align;
 
1282
 
1283
- l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1284
- m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1285
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
1286
-
1287
- l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1288
- m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1289
- s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
1290
-
1291
- l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1292
- m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
1293
- s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
1294
-
1295
- l_align = 128;
1296
- m_align = 64;
1297
- s_align = 32;
1298
-
1299
- // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1300
- // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1301
- // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1302
- // But the numbers happen to work out for 32KB shared memory size that when using the medium
1303
- // size there's enough room for everything, and we assert for this.
1304
- uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1305
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1306
- l_warptile = m_warptile;
1307
- l_wg_denoms = m_wg_denoms;
1308
- shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1309
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1310
- }
1311
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1312
- // assert mul_mat_mat_id shaders will fit.
1313
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1314
- }
1315
-
1316
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1317
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1318
- if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
1319
- l_warptile_mmq = m_warptile_mmq;
1320
- l_mmq_wg_denoms = m_mmq_wg_denoms;
1321
- } else {
1322
- l_warptile_mmq = s_warptile_mmq;
1323
- l_mmq_wg_denoms = s_mmq_wg_denoms;
1324
  }
 
1325
  shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1326
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1327
- }
1328
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1329
- // assert mul_mat_mat_id shaders will fit.
1330
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1331
  }
1332
 
1333
  device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
@@ -1362,6 +1462,105 @@ static void ggml_vk_load_shaders(vk_device& device) {
1362
  compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
1363
  };
1364
 
 
1365
  if (device->fp16) {
1366
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1367
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -1648,15 +1847,28 @@ static vk_device ggml_vk_get_device(size_t idx) {
1648
  device->physical_device = physical_devices[dev_num];
1649
  const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
1650
 
1651
  bool maintenance4_support = false;
1652
  bool sm_builtins = false;
1653
 
1654
  // Check if maintenance4 is supported
1655
  for (const auto& properties : ext_props) {
1656
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1657
  maintenance4_support = true;
1658
  } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
1659
  sm_builtins = true;
1660
  }
1661
  }
1662
 
@@ -1679,6 +1891,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
1679
  last_struct = (VkBaseOutStructure *)&sm_props;
1680
  }
1681
 
1682
  device->physical_device.getProperties2(&props2);
1683
  device->properties = props2.properties;
1684
 
@@ -1701,20 +1921,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
1701
  device->shader_core_count = 0;
1702
  }
1703
 
1704
- bool fp16_storage = false;
1705
- bool fp16_compute = false;
1706
- bool pipeline_robustness = false;
1707
-
1708
- for (const auto& properties : ext_props) {
1709
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1710
- fp16_storage = true;
1711
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1712
- fp16_compute = true;
1713
- } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
1714
- pipeline_robustness = true;
1715
- }
1716
- }
1717
-
1718
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1719
  const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1720
 
@@ -1757,22 +1963,112 @@ static vk_device ggml_vk_get_device(size_t idx) {
1757
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1758
  vk11_features.pNext = &vk12_features;
1759
 
1760
  VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
1761
  pl_robustness_features.pNext = nullptr;
1762
  pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
1763
  pl_robustness_features.pipelineRobustness = VK_FALSE;
1764
 
1765
  if (pipeline_robustness) {
1766
- vk12_features.pNext = &pl_robustness_features;
 
1767
  device_extensions.push_back("VK_EXT_pipeline_robustness");
1768
  }
1769
 
 
1770
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
1771
 
1772
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
1773
 
1774
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
1775
 
1776
  if (!vk11_features.storageBuffer16BitAccess) {
1777
  std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1778
  throw std::runtime_error("Unsupported device");
@@ -2124,7 +2420,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
2124
  return ctx->device->pipeline_dequant[type];
2125
  }
2126
 
2127
- static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
2128
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
2129
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2130
  return ctx->device->pipeline_matmul_f32;
@@ -2132,14 +2428,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2132
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2133
  return ctx->device->pipeline_matmul_f32_f16;
2134
  }
2135
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2136
- return ctx->device->pipeline_matmul_f16_f32.f32acc;
2137
- }
2138
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2139
- return ctx->device->pipeline_matmul_f16.f32acc;
2140
  }
2141
 
2142
- if (src1_type != GGML_TYPE_F32) {
2143
  return nullptr;
2144
  }
2145
 
@@ -2160,6 +2465,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2160
  return nullptr;
2161
  }
2162
 
2163
  return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2164
  }
2165
 
@@ -2844,6 +3153,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
2844
  break;
2845
  }
2846
 
2847
  if (m <= 32 || n <= 32) {
2848
  return aligned ? mmp->a_s : mmp->s;
2849
  }
@@ -3008,18 +3327,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3008
  }
3009
 
3010
  const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
3011
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
3012
 
3013
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
3014
 
3015
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
3016
 
3017
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3018
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3019
 
3020
  if (qx_needs_dequant) {
3021
  // Fall back to dequant + f16 mulmat
3022
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
3023
  }
3024
 
3025
  // Not implemented
@@ -3930,6 +4251,167 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
3930
  }
3931
  }
3932
 
3933
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
3934
  switch (op) {
3935
  case GGML_OP_GET_ROWS:
@@ -5044,16 +5526,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5044
  ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
5045
 
5046
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
 
5047
  for (size_t i = 0; i < num_it; i++) {
5048
- ggml_vk_ctx_begin(ctx->device, subctx);
5049
  ggml_vk_matmul(
5050
  ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
5051
  m, n, k,
5052
  k, k, m, k*m, k*n, m*n,
5053
  split_k, batch, batch, batch, 1, 1
5054
  );
5055
- ggml_vk_ctx_end(subctx);
5056
  }
 
5057
 
5058
  auto begin = std::chrono::high_resolution_clock::now();
5059
  ggml_vk_submit(subctx, ctx->fence);
@@ -5391,16 +5873,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5391
  ggml_vk_buffer_write(y_buf, 0, y, y_sz);
5392
 
5393
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
 
5394
  for (size_t i = 0; i < num_it; i++) {
5395
- ggml_vk_ctx_begin(ctx->device, subctx);
5396
  ggml_vk_matmul(
5397
  ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
5398
  m, n, k,
5399
  k, k, m, k*m, k*n, m*n,
5400
  split_k, batch, batch, batch, 1, 1
5401
  );
5402
- ggml_vk_ctx_end(subctx);
5403
  }
 
5404
 
5405
  auto begin = std::chrono::high_resolution_clock::now();
5406
 
@@ -5621,7 +6103,8 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
5621
  4096, 512, 11008,
5622
  32000, 512, 4096,
5623
  };
5624
- const size_t num_it = 1;
 
5625
  for (size_t i = 0; i < vals.size(); i += 3) {
5626
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
5627
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
@@ -5676,6 +6159,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5676
  const ggml_tensor * src0 = node->src[0];
5677
  const ggml_tensor * src1 = node->src[1];
5678
  const ggml_tensor * src2 = node->src[2];
 
5679
 
5680
  switch (node->op) {
5681
  // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
@@ -5728,6 +6212,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5728
  case GGML_OP_TIMESTEP_EMBEDDING:
5729
  case GGML_OP_POOL_2D:
5730
  case GGML_OP_LEAKY_RELU:
 
5731
  break;
5732
  default:
5733
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -5920,6 +6405,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5920
  case GGML_OP_MUL_MAT_ID:
5921
  ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
5922
 
5923
  break;
5924
  default:
5925
  return false;
@@ -6020,6 +6510,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
6020
  break;
6021
  case GGML_OP_MUL_MAT:
6022
  case GGML_OP_MUL_MAT_ID:
 
6023
  buf = tensor->buffer;
6024
 
6025
  break;
@@ -6751,6 +7242,57 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6751
 
6752
  return true;
6753
  } break;
6754
  case GGML_OP_GET_ROWS:
6755
  {
6756
  switch (op->src[0]->type) {
@@ -7065,6 +7607,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7065
  ggml_tensor * src0 = tensor->src[0];
7066
  ggml_tensor * src1 = tensor->src[1];
7067
  ggml_tensor * src2 = tensor->src[2];
 
7068
 
7069
  struct ggml_init_params iparams = {
7070
  /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -7077,15 +7620,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7077
  struct ggml_tensor * src0_clone = nullptr;
7078
  struct ggml_tensor * src1_clone = nullptr;
7079
  struct ggml_tensor * src2_clone = nullptr;
 
7080
  struct ggml_tensor * tensor_clone = nullptr;
7081
 
7082
  size_t src0_size;
7083
  size_t src1_size;
7084
  size_t src2_size;
 
7085
 
7086
  void * src0_buffer = nullptr;
7087
  void * src1_buffer = nullptr;
7088
  void * src2_buffer = nullptr;
 
7089
 
7090
  if (src0 != nullptr) {
7091
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -7213,8 +7759,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7213
  ggml_vk_print_tensor(src2, "src2");
7214
  }
7215
  }
7216
 
7217
- if (tensor->op == GGML_OP_MUL_MAT) {
7218
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
7219
  } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
7220
  tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
 
167
  uint32_t subgroup_size;
168
  uint32_t shader_core_count;
169
  bool uma;
170
+ bool coopmat2;
171
 
172
  size_t idx;
173
 
 
177
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
178
  vk_pipeline pipeline_matmul_split_k_reduce;
179
 
180
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
181
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
182
 
183
  vk_matmul_pipeline pipeline_matmul_id_f32;
 
231
  vk_pipeline pipeline_timestep_embedding_f32;
232
  vk_pipeline pipeline_pool2d_f32;
233
 
234
+ // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
235
+ vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
236
+ vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
237
+ vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
238
+ vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
239
+ vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
240
+ vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
241
+
242
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
243
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
244
 
 
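Note: per the comment above, the four indices into the new flash-attention pipeline arrays are [K/V type][f32acc][small_rows][aligned]. A minimal, illustrative lookup (the real selection in ggml_vk_flash_attn below takes a pointer to the [small_rows][0] slot and then indexes by alignment; the flag values here are made up for the example):

    // Illustrative only: D=128, f16 K/V, fp32-accumulator, large-rows, aligned pipeline.
    bool f32acc     = true;   // dst->op_params[3] == GGML_PREC_F32
    bool small_rows = false;  // N <= flash_attention_num_small_rows selects the small-rows variant
    bool aligned    = true;   // KV is a multiple of the pipeline's alignment
    vk_pipeline p   = device->pipeline_flash_attn_f32_f16_D128[GGML_TYPE_F16][f32acc][small_rows][aligned];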
350
  uint32_t nei0; uint32_t ne11;
351
  };
352
 
353
+ struct vk_flash_attn_push_constants {
354
+ uint32_t N;
355
+ uint32_t KV;
356
+
357
+ uint32_t ne1;
358
+ uint32_t ne2;
359
+ uint32_t ne3;
360
+
361
+ uint32_t neq2;
362
+ uint32_t neq3;
363
+ uint32_t nek2;
364
+ uint32_t nek3;
365
+ uint32_t nev2;
366
+ uint32_t nev3;
367
+ uint32_t nem1;
368
+
369
+ uint32_t nb02;
370
+ uint32_t nb03;
371
+ uint32_t nb12;
372
+ uint32_t nb13;
373
+ uint32_t nb22;
374
+ uint32_t nb23;
375
+ uint32_t nb31;
376
+
377
+ float scale;
378
+ float max_bias;
379
+ float logit_softcap;
380
+
381
+ uint32_t mask;
382
+ uint32_t n_head_log2;
383
+ float m0;
384
+ float m1;
385
+ };
386
+
387
  struct vk_op_push_constants {
388
  uint32_t KX;
389
  uint32_t KY;
 
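Note: these push constants are consumed by the new flash-attention shaders and are filled host-side in ggml_vk_flash_attn further down. The FA pipelines are created with parameter_count = 5 in ggml_vk_load_shaders, and the dispatch binds the buffers in a fixed order; a rough sketch of that call (buffer names here are illustrative, and when no mask is given the Q buffer is re-bound at the mask slot while the push-constant mask flag is set to 0):

    // Sketch of the binding order used by ggml_vk_flash_attn:
    //   0: Q   1: K   2: V   3: mask (or Q again as a dummy)   4: dst
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        { Q_buf, K_buf, V_buf, M_buf, D_buf },
        sizeof(vk_flash_attn_push_constants), &pc,
        { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); // grid spans query rows x heads x batch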
1309
  );
1310
  }
1311
 
1312
+ // number of rows/cols for flash attention shader
1313
+ static constexpr uint32_t flash_attention_num_small_rows = 32;
1314
+ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1315
+ GGML_UNUSED(clamp);
1316
+
1317
+ // small rows, large cols
1318
+ if (small_rows) {
1319
+ return {flash_attention_num_small_rows, 128};
1320
+ }
1321
+ // small cols to reduce register count
1322
+ if (ggml_is_quantized(type) || D == 256) {
1323
+ return {64, 32};
1324
+ }
1325
+ return {64, 64};
1326
+ };
1327
+
1328
+
1329
  static void ggml_vk_load_shaders(vk_device& device) {
1330
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
1331
 
 
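Note: for reference, the tile shapes that fa_rows_cols above resolves to (first value = rows, second = columns, as used for the workgroup denominators and spec constants):

    // Worked examples of fa_rows_cols(D, clamp, type, small_rows):
    //   small_rows == true                 -> {32, 128}   (32 == flash_attention_num_small_rows)
    //   quantized K/V type or D == 256     -> {64,  32}   (smaller cols to reduce register count)
    //   otherwise (f16 K/V, D < 256)       -> {64,  64}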
1336
 
1337
  // mulmat
1338
  std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
1339
+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
1340
+ l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
1341
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
1342
  std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
1343
+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
1344
+ l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
1345
+ l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
1346
 
1347
+ uint32_t l_align, m_align, s_align;
1348
+ if (device->coopmat2) {
1349
+ // spec constants and tile sizes for non-quant matmul/matmul_id
1350
+ l_warptile = { 256, 128, 256, 64 };
1351
+ m_warptile = { 256, 128, 128, 64 };
1352
+ s_warptile = { 128, 32, 16, 64 };
1353
+ l_wg_denoms = {128, 256, 1 };
1354
+ m_wg_denoms = {128, 128, 1 };
1355
+ s_wg_denoms = { 32, 16, 1 };
1356
+
1357
+ // spec constants and tile sizes for quant matmul (non-Qi_K)
1358
+ l_warptile_mmq = { 256, 128, 256, 64 };
1359
+ m_warptile_mmq = { 256, 128, 128, 64 };
1360
+ s_warptile_mmq = { 256, 128, 128, 64 };
1361
+ l_mmq_wg_denoms = { 128, 256, 1 };
1362
+ m_mmq_wg_denoms = { 128, 128, 1 };
1363
+ s_mmq_wg_denoms = { 128, 128, 1 };
1364
+
1365
+ // spec constants and tile sizes for quant matmul (Qi_K)
1366
+ l_warptile_mmq_k = { 256, 128, 512, 16 };
1367
+ m_warptile_mmq_k = { 256, 128, 256, 16 };
1368
+ s_warptile_mmq_k = { 256, 32, 128, 64 };
1369
+ l_mmq_wg_denoms_k = { 128, 512, 1 };
1370
+ m_mmq_wg_denoms_k = { 128, 256, 1 };
1371
+ s_mmq_wg_denoms_k = { 32, 128, 1 };
1372
+
1373
+ // spec constants and tile sizes for quant matmul_id
1374
+ l_warptile_mmqid = { 256, 128, 128, 16 };
1375
+ m_warptile_mmqid = { 256, 128, 64, 16 };
1376
+ s_warptile_mmqid = { 256, 64, 64, 16 };
1377
+ l_mmqid_wg_denoms = { 128, 128, 1 };
1378
+ m_mmqid_wg_denoms = { 128, 64, 1 };
1379
+ s_mmqid_wg_denoms = { 64, 64, 1 };
1380
+
1381
+ l_align = 128;
1382
+ m_align = 64;
1383
+ s_align = 32;
1384
+ } else {
1385
+ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1386
+ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1387
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
1388
+ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1389
+ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1390
+ s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
1391
+ l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1392
+ m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
1393
+ s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
1394
+ l_align = 128;
1395
+ m_align = 64;
1396
+ s_align = 32;
1397
+
1398
+ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1399
+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1400
+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1401
+ // But the numbers happen to work out for 32KB shared memory size that when using the medium
1402
+ // size there's enough room for everything, and we assert for this.
1403
+ uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1404
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1405
+ l_warptile = m_warptile;
1406
+ l_wg_denoms = m_wg_denoms;
1407
+ shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1408
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1409
+ }
1410
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1411
+ // assert mul_mat_mat_id shaders will fit.
1412
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1413
  }
1414
+
1415
  shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1416
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1417
+ if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
1418
+ l_warptile_mmq = m_warptile_mmq;
1419
+ l_mmq_wg_denoms = m_mmq_wg_denoms;
1420
+ } else {
1421
+ l_warptile_mmq = s_warptile_mmq;
1422
+ l_mmq_wg_denoms = s_mmq_wg_denoms;
1423
+ }
1424
+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1425
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1426
+ }
1427
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1428
+ // assert mul_mat_mat_id shaders will fit.
1429
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1430
+ }
1431
  }
1432
 
1433
  device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
 
1462
  compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
1463
  };
1464
 
1465
+ #if defined(VK_NV_cooperative_matrix2)
1466
+ if (device->coopmat2) {
1467
+
1468
+ auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
1469
+ return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
1470
+ };
1471
+
1472
+ auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
1473
+ // For large number of rows, 128 invocations seems to work best.
1474
+ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
1475
+ // can't use 256 for D==80.
1476
+ uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
1477
+ auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
1478
+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
1479
+ };
1480
+
1481
+ #define CREATE_FA2(TYPE, NAMELC, D) \
1482
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1483
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1484
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1485
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1486
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1487
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1488
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1489
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1490
+
1491
+ #define CREATE_FA(TYPE, NAMELC) \
1492
+ CREATE_FA2(TYPE, NAMELC, 64) \
1493
+ CREATE_FA2(TYPE, NAMELC, 80) \
1494
+ CREATE_FA2(TYPE, NAMELC, 96) \
1495
+ CREATE_FA2(TYPE, NAMELC, 112) \
1496
+ CREATE_FA2(TYPE, NAMELC, 128) \
1497
+ CREATE_FA2(TYPE, NAMELC, 256)
1498
+
1499
+ CREATE_FA(GGML_TYPE_F16, f16)
1500
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0)
1501
+ CREATE_FA(GGML_TYPE_Q4_1, q4_1)
1502
+ CREATE_FA(GGML_TYPE_Q5_0, q5_0)
1503
+ CREATE_FA(GGML_TYPE_Q5_1, q5_1)
1504
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0)
1505
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
1506
+ //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
1507
+ //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
1508
+ //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
1509
+ //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
1510
+ //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
1511
+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
1512
+ #undef CREATE_FA
1513
+
1514
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1515
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1516
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1517
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1518
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1519
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1520
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1521
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1522
+
1523
+ // Create 2 variants, {f16,f32} accumulator
1524
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1525
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1526
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1527
+
1528
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1529
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1530
+
1531
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1532
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1533
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1534
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1535
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1536
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1537
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1538
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1539
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1540
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1541
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1542
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1543
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1544
+
1545
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1546
+ CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1547
+ CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1548
+
1549
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1550
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1551
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1552
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1553
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1554
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1555
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1556
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1557
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1558
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1559
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1560
+ #undef CREATE_MM
1561
+ #undef CREATE_MM2
1562
+ } else
1563
+ #endif
1564
  if (device->fp16) {
1565
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1566
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
 
1847
  device->physical_device = physical_devices[dev_num];
1848
  const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
1849
 
1850
+ bool fp16_storage = false;
1851
+ bool fp16_compute = false;
1852
  bool maintenance4_support = false;
1853
  bool sm_builtins = false;
1854
+ bool pipeline_robustness = false;
1855
+ bool coopmat2_support = false;
1856
 
1857
  // Check if maintenance4 is supported
1858
  for (const auto& properties : ext_props) {
1859
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1860
  maintenance4_support = true;
1861
+ } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1862
+ fp16_storage = true;
1863
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1864
+ fp16_compute = true;
1865
  } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
1866
  sm_builtins = true;
1867
+ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
1868
+ pipeline_robustness = true;
1869
+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
1870
+ !getenv("GGML_VULKAN_DISABLE_COOPMAT2")) {
1871
+ coopmat2_support = true;
1872
  }
1873
  }
1874
 
 
1891
  last_struct = (VkBaseOutStructure *)&sm_props;
1892
  }
1893
 
1894
+ #if defined(VK_NV_cooperative_matrix2)
1895
+ vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
1896
+ if (coopmat2_support) {
1897
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
1898
+ last_struct = (VkBaseOutStructure *)&coopmat2_props;
1899
+ }
1900
+ #endif
1901
+
1902
  device->physical_device.getProperties2(&props2);
1903
  device->properties = props2.properties;
1904
 
 
1921
  device->shader_core_count = 0;
1922
  }
1923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1924
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1925
  const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1926
 
 
1963
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1964
  vk11_features.pNext = &vk12_features;
1965
 
1966
+ last_struct = (VkBaseOutStructure *)&vk12_features;
1967
+
1968
  VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
1969
  pl_robustness_features.pNext = nullptr;
1970
  pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
1971
  pl_robustness_features.pipelineRobustness = VK_FALSE;
1972
 
1973
  if (pipeline_robustness) {
1974
+ last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
1975
+ last_struct = (VkBaseOutStructure *)&pl_robustness_features;
1976
  device_extensions.push_back("VK_EXT_pipeline_robustness");
1977
  }
1978
 
1979
+ #if defined(VK_NV_cooperative_matrix2)
1980
+ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
1981
+ coopmat2_features.pNext = nullptr;
1982
+ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
1983
+ if (coopmat2_support) {
1984
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
1985
+ last_struct = (VkBaseOutStructure *)&coopmat2_features;
1986
+ device_extensions.push_back("VK_NV_cooperative_matrix2");
1987
+ }
1988
+ #endif
1989
+
1990
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
1991
 
1992
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
1993
 
1994
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
1995
 
1996
+ if (coopmat2_support) {
1997
+ #if defined(VK_NV_cooperative_matrix2)
1998
+ if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
1999
+ coopmat2_features.cooperativeMatrixFlexibleDimensions &&
2000
+ coopmat2_features.cooperativeMatrixReductions &&
2001
+ coopmat2_features.cooperativeMatrixConversions &&
2002
+ coopmat2_features.cooperativeMatrixPerElementOperations &&
2003
+ coopmat2_features.cooperativeMatrixTensorAddressing &&
2004
+ coopmat2_features.cooperativeMatrixBlockLoads &&
2005
+ vk12_features.bufferDeviceAddress) {
2006
+
2007
+ std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
2008
+ uint32_t count = 0;
2009
+
2010
+ PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
2011
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
2012
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
2013
+ vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
2014
+
2015
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
2016
+
2017
+ VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
2018
+ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
2019
+ flexible_dimensions.resize(count, empty_prop);
2020
+
2021
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
2022
+
2023
+ bool found_fp16_128 = false,
2024
+ found_fp16_256 = false,
2025
+ found_fp32_128 = false,
2026
+ found_fp32_256 = false;
2027
+ // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
2028
+ // with 32x16x16 and 256 with 32x32x16.
2029
+ for (auto &prop : flexible_dimensions) {
2030
+ if (prop.saturatingAccumulation == VK_FALSE &&
2031
+ prop.scope == VK_SCOPE_WORKGROUP_KHR &&
2032
+ prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2033
+ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2034
+
2035
+ if (prop.workgroupInvocations == 128 &&
2036
+ prop.MGranularity <= 32 &&
2037
+ prop.NGranularity <= 16 &&
2038
+ prop.KGranularity <= 16) {
2039
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2040
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2041
+ found_fp16_128 = true;
2042
+ }
2043
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
2044
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
2045
+ found_fp32_128 = true;
2046
+ }
2047
+ }
2048
+ if (prop.workgroupInvocations == 256 &&
2049
+ prop.MGranularity <= 32 &&
2050
+ prop.NGranularity <= 32 &&
2051
+ prop.KGranularity <= 16) {
2052
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2053
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2054
+ found_fp16_256 = true;
2055
+ }
2056
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
2057
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
2058
+ found_fp32_256 = true;
2059
+ }
2060
+ }
2061
+ }
2062
+ }
2063
+ if (found_fp16_128 && found_fp16_256 &&
2064
+ found_fp32_128 && found_fp32_256 &&
2065
+ coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
2066
+ device->coopmat2 = true;
2067
+ }
2068
+ }
2069
+ #endif
2070
+ }
2071
+
2072
  if (!vk11_features.storageBuffer16BitAccess) {
2073
  std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
2074
  throw std::runtime_error("Unsupported device");
 
2420
  return ctx->device->pipeline_dequant[type];
2421
  }
2422
 
2423
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
2424
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
2425
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2426
  return ctx->device->pipeline_matmul_f32;
 
2428
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2429
  return ctx->device->pipeline_matmul_f32_f16;
2430
  }
2431
+ if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) {
2432
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2433
+ return ctx->device->pipeline_matmul_f16_f32.f16acc;
2434
+ }
2435
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2436
+ return ctx->device->pipeline_matmul_f16.f16acc;
2437
+ }
2438
+ } else {
2439
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2440
+ return ctx->device->pipeline_matmul_f16_f32.f32acc;
2441
+ }
2442
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2443
+ return ctx->device->pipeline_matmul_f16.f32acc;
2444
+ }
2445
  }
2446
 
2447
+ if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
2448
  return nullptr;
2449
  }
2450
 
 
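Note: the practical effect of the new prec parameter is that on the coopmat2 path, F16xF32 and F16xF16 matmuls default to fp16 accumulation, while GGML_PREC_F32 (or the non-coopmat2 path) keeps fp32 accumulation. A hedged usage sketch; the tensor and context names are illustrative, and ggml_mul_mat_set_prec stores the precision in op_params[0], which is what this function receives via (ggml_prec)dst->op_params[0]:

    // Request fp32 accumulation for one matmul even when the coopmat2 path is active.
    struct ggml_tensor * dst = ggml_mul_mat(ggml_ctx, src0_f16, src1_f32);
    ggml_mul_mat_set_prec(dst, GGML_PREC_F32);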
2465
  return nullptr;
2466
  }
2467
 
2468
+ if (ctx->device->coopmat2) {
2469
+ assert(src1_type == GGML_TYPE_F16);
2470
+ return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
2471
+ }
2472
  return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2473
  }
2474
 
 
3153
  break;
3154
  }
3155
 
3156
+ if (ctx->device->coopmat2) {
3157
+ if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
3158
+ return aligned ? mmp->a_l : mmp->l;
3159
+ }
3160
+ if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
3161
+ return aligned ? mmp->a_m : mmp->m;
3162
+ }
3163
+ return aligned ? mmp->a_s : mmp->s;
3164
+ }
3165
+
3166
  if (m <= 32 || n <= 32) {
3167
  return aligned ? mmp->a_s : mmp->s;
3168
  }
 
3327
  }
3328
 
3329
  const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
3330
+ // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3331
+ const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
3332
+ !ggml_vk_dim01_contiguous(src1);
3333
 
3334
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
3335
 
3336
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
3337
 
3338
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3339
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3340
 
3341
  if (qx_needs_dequant) {
3342
  // Fall back to dequant + f16 mulmat
3343
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
3344
  }
3345
 
3346
  // Not implemented
 
4251
  }
4252
  }
4253
 
4254
+ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
4255
+ VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
4256
+ std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
4257
+ std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
4258
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4259
+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4260
+
4261
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
4262
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
4263
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
4264
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
4265
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
4266
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
4267
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
4268
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
4269
+
4270
+ const uint32_t nem1 = mask ? mask->ne[1] : 0;
4271
+ const uint32_t nbm1 = mask ? mask->nb[1] : 0;
4272
+
4273
+ const uint32_t D = neq0;
4274
+ const uint32_t N = neq1;
4275
+ const uint32_t KV = nek1;
4276
+
4277
+ GGML_ASSERT(ne0 == D);
4278
+ GGML_ASSERT(ne2 == N);
4279
+
4280
+ // input tensor rows must be contiguous
4281
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
4282
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
4283
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
4284
+
4285
+ GGML_ASSERT(neq0 == D);
4286
+ GGML_ASSERT(nek0 == D);
4287
+ GGML_ASSERT(nev0 == D);
4288
+
4289
+ GGML_ASSERT(neq1 == N);
4290
+ GGML_ASSERT(nev0 == D);
4291
+
4292
+ GGML_ASSERT(nev1 == nek1);
4293
+
4294
+ // dst cannot be transposed or permuted
4295
+ GGML_ASSERT(nb0 == sizeof(float));
4296
+ GGML_ASSERT(nb0 <= nb1);
4297
+ GGML_ASSERT(nb1 <= nb2);
4298
+ GGML_ASSERT(nb2 <= nb3);
4299
+
4300
+ assert(dst->type == GGML_TYPE_F32);
4301
+ assert(q->type == GGML_TYPE_F32);
4302
+ assert(k->type == v->type);
4303
+
4304
+ vk_pipeline *pipelines;
4305
+ // XXX TODO other backends may be changing accumulator precision to default to f32 soon
4306
+ bool f32acc = dst->op_params[3] == GGML_PREC_F32;
4307
+ bool small_rows = N <= flash_attention_num_small_rows;
4308
+ switch (D) {
4309
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
4310
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
4311
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
4312
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
4313
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
4314
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
4315
+ default:
4316
+ assert(!"unsupported D value");
4317
+ return;
4318
+ }
4319
+ assert(pipelines);
4320
+
4321
+ bool aligned = (KV % pipelines[1]->align) == 0;
4322
+ vk_pipeline pipeline = pipelines[aligned];
4323
+ assert(pipeline);
4324
+
4325
+ if (dryrun) {
4326
+ // Request descriptor sets
4327
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
4328
+ return;
4329
+ }
4330
+
4331
+ float scale = 1.0f;
4332
+ float max_bias = 0.0f;
4333
+ float logit_softcap = 0.0f;
4334
+
4335
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
4336
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
4337
+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
4338
+
4339
+ if (logit_softcap != 0) {
4340
+ scale /= logit_softcap;
4341
+ }
4342
+
4343
+ const uint32_t n_head_kv = neq2;
4344
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
4345
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4346
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4347
+
4348
+ ggml_vk_sync_buffers(subctx);
4349
+
4350
+ vk_buffer d_Q, d_K, d_V, d_D, d_M;
4351
+ uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
4352
+
4353
+ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
4354
+
4355
+ if (ctx->device->uma) {
4356
+ ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
4357
+ ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
4358
+ ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
4359
+ ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
4360
+ Q_uma = d_Q != nullptr;
4361
+ K_uma = d_K != nullptr;
4362
+ V_uma = d_V != nullptr;
4363
+ D_uma = d_D != nullptr;
4364
+ if (mask) {
4365
+ ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
4366
+ M_uma = d_M != nullptr;
4367
+ }
4368
+ }
4369
+
4370
+
4371
+ ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
4372
+ ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
4373
+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
4374
+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
4375
+
4376
+ if (!Q_uma) {
4377
+ d_Q = q_buf_ctx->dev_buffer;
4378
+ q_buf_offset = vk_tensor_offset(q) + q->view_offs;
4379
+ }
4380
+ if (!K_uma) {
4381
+ d_K = k_buf_ctx->dev_buffer;
4382
+ k_buf_offset = vk_tensor_offset(k) + k->view_offs;
4383
+ }
4384
+ if (!V_uma) {
4385
+ d_V = v_buf_ctx->dev_buffer;
4386
+ v_buf_offset = vk_tensor_offset(v) + v->view_offs;
4387
+ }
4388
+ if (!D_uma) {
4389
+ d_D = d_buf_ctx->dev_buffer;
4390
+ d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
4391
+ }
4392
+
4393
+ if (!M_uma) {
4394
+ d_M = d_Q;
4395
+ m_buf_offset = q_buf_offset;
4396
+ if (mask) {
4397
+ ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
4398
+ d_M = m_buf_ctx->dev_buffer;
4399
+ m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
4400
+ }
4401
+ }
4402
+
4403
+ const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
4404
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
4405
+ {
4406
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
4407
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
4408
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
4409
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
4410
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
4411
+ },
4412
+ sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
4413
+ }
4414
+
4415
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
4416
  switch (op) {
4417
  case GGML_OP_GET_ROWS:
 
5526
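Note: two details of the push-constant setup in ggml_vk_flash_attn above are easy to miss. First, when logit_softcap is non-zero the scale is pre-divided by it (the shader is presumably expected to multiply it back inside the softcap). Second, the ALiBi slope bases m0/m1 follow the usual half-power construction; a small worked example with assumed values n_head_kv = 12 and max_bias = 8.0:

    // Same formulas as above, evaluated for n_head_kv = 12, max_bias = 8.0f:
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f(12.0f));   // = 8
    const float m0 = powf(2.0f, -8.0f / n_head_log2);                     // = 2^-1   = 0.5f
    const float m1 = powf(2.0f, -(8.0f / 2.0f) / n_head_log2);            // = 2^-0.5 ~= 0.7071f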
  ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
5527
 
5528
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
5529
+ ggml_vk_ctx_begin(ctx->device, subctx);
5530
  for (size_t i = 0; i < num_it; i++) {
 
5531
  ggml_vk_matmul(
5532
  ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
5533
  m, n, k,
5534
  k, k, m, k*m, k*n, m*n,
5535
  split_k, batch, batch, batch, 1, 1
5536
  );
 
5537
  }
5538
+ ggml_vk_ctx_end(subctx);
5539
 
5540
  auto begin = std::chrono::high_resolution_clock::now();
5541
  ggml_vk_submit(subctx, ctx->fence);
 
5873
  ggml_vk_buffer_write(y_buf, 0, y, y_sz);
5874
 
5875
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
5876
+ ggml_vk_ctx_begin(ctx->device, subctx);
5877
  for (size_t i = 0; i < num_it; i++) {
 
5878
  ggml_vk_matmul(
5879
  ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
5880
  m, n, k,
5881
  k, k, m, k*m, k*n, m*n,
5882
  split_k, batch, batch, batch, 1, 1
5883
  );
 
5884
  }
5885
+ ggml_vk_ctx_end(subctx);
5886
 
5887
  auto begin = std::chrono::high_resolution_clock::now();
5888
 
 
6103
  4096, 512, 11008,
6104
  32000, 512, 4096,
6105
  };
6106
+ const size_t num_it = 100;
6107
+
6108
  for (size_t i = 0; i < vals.size(); i += 3) {
6109
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
6110
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
 
6159
  const ggml_tensor * src0 = node->src[0];
6160
  const ggml_tensor * src1 = node->src[1];
6161
  const ggml_tensor * src2 = node->src[2];
6162
+ const ggml_tensor * src3 = node->src[3];
6163
 
6164
  switch (node->op) {
6165
  // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
 
6212
  case GGML_OP_TIMESTEP_EMBEDDING:
6213
  case GGML_OP_POOL_2D:
6214
  case GGML_OP_LEAKY_RELU:
6215
+ case GGML_OP_FLASH_ATTN_EXT:
6216
  break;
6217
  default:
6218
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
 
6405
  case GGML_OP_MUL_MAT_ID:
6406
  ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
6407
 
6408
+ break;
6409
+
6410
+ case GGML_OP_FLASH_ATTN_EXT:
6411
+ ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
6412
+
6413
  break;
6414
  default:
6415
  return false;
 
6510
  break;
6511
  case GGML_OP_MUL_MAT:
6512
  case GGML_OP_MUL_MAT_ID:
6513
+ case GGML_OP_FLASH_ATTN_EXT:
6514
  buf = tensor->buffer;
6515
 
6516
  break;
 
7242
 
7243
  return true;
7244
  } break;
7245
+ case GGML_OP_FLASH_ATTN_EXT:
7246
+ {
7247
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7248
+ if (!ggml_vk_get_device(ctx->device)->coopmat2) {
7249
+ return false;
7250
+ }
7251
+ switch (op->src[0]->ne[0]) {
7252
+ case 64:
7253
+ case 80:
7254
+ case 96:
7255
+ case 112:
7256
+ case 128:
7257
+ case 256:
7258
+ break;
7259
+ default:
7260
+ return false;
7261
+ }
7262
+ if (op->src[0]->type != GGML_TYPE_F32) {
7263
+ return false;
7264
+ }
7265
+ if (op->type != GGML_TYPE_F32) {
7266
+ return false;
7267
+ }
7268
+ if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
7269
+ return false;
7270
+ }
7271
+ // It's straightforward to support different K/V dequant, but would
7272
+ // significantly increase the number of pipelines
7273
+ if (op->src[1]->type != op->src[2]->type) {
7274
+ return false;
7275
+ }
7276
+ switch (op->src[1]->type) {
7277
+ case GGML_TYPE_F16:
7278
+ case GGML_TYPE_Q4_0:
7279
+ case GGML_TYPE_Q4_1:
7280
+ case GGML_TYPE_Q5_0:
7281
+ case GGML_TYPE_Q5_1:
7282
+ case GGML_TYPE_Q8_0:
7283
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
7284
+ //case GGML_TYPE_Q2_K:
7285
+ //case GGML_TYPE_Q3_K:
7286
+ //case GGML_TYPE_Q4_K:
7287
+ //case GGML_TYPE_Q5_K:
7288
+ //case GGML_TYPE_Q6_K:
7289
+ case GGML_TYPE_IQ4_NL:
7290
+ break;
7291
+ default:
7292
+ return false;
7293
+ }
7294
+ return true;
7295
+ }
7296
  case GGML_OP_GET_ROWS:
7297
  {
7298
  switch (op->src[0]->type) {
 
7607
  ggml_tensor * src0 = tensor->src[0];
7608
  ggml_tensor * src1 = tensor->src[1];
7609
  ggml_tensor * src2 = tensor->src[2];
7610
+ ggml_tensor * src3 = tensor->src[3];
7611
 
7612
  struct ggml_init_params iparams = {
7613
  /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
 
7620
  struct ggml_tensor * src0_clone = nullptr;
7621
  struct ggml_tensor * src1_clone = nullptr;
7622
  struct ggml_tensor * src2_clone = nullptr;
7623
+ struct ggml_tensor * src3_clone = nullptr;
7624
  struct ggml_tensor * tensor_clone = nullptr;
7625
 
7626
  size_t src0_size;
7627
  size_t src1_size;
7628
  size_t src2_size;
7629
+ size_t src3_size;
7630
 
7631
  void * src0_buffer = nullptr;
7632
  void * src1_buffer = nullptr;
7633
  void * src2_buffer = nullptr;
7634
+ void * src3_buffer = nullptr;
7635
 
7636
  if (src0 != nullptr) {
7637
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
 
7759
  ggml_vk_print_tensor(src2, "src2");
7760
  }
7761
  }
7762
+ if (src3 != nullptr) {
7763
+ src3_clone = ggml_dup_tensor(ggml_ctx, src3);
7764
+
7765
+ src3_size = ggml_nbytes(src3);
7766
+
7767
+ src3_buffer = malloc(src3_size);
7768
+ src3_clone->data = src3_buffer;
7769
+ if (ggml_backend_buffer_is_host(src3->buffer)) {
7770
+ memcpy(src3_clone->data, src3->data, src3_size);
7771
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
7772
+ } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
7773
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
7774
+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
7775
+ uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
7776
+ if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
7777
+ for (int i3 = 0; i3 < src3->ne[3]; i3++) {
7778
+ for (int i2 = 0; i2 < src3->ne[2]; i2++) {
7779
+ const int idx = i3*src3->ne[2] + i2;
7780
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
7781
+ }
7782
+ }
7783
+
7784
+ src3_clone->nb[0] = src3->nb[0];
7785
+ src3_clone->nb[1] = src3->nb[1];
7786
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
7787
+ src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
7788
+ }
7789
+ } else {
7790
+ if (offset + src3_size >= buffer_gpu->size) {
7791
+ src3_size = buffer_gpu->size - offset;
7792
+ }
7793
+ ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
7794
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
7795
+ }
7796
+ } else {
7797
+ GGML_ABORT("fatal error");
7798
+ }
7799
+
7800
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
7801
+ ggml_vk_print_tensor(src3, "src3");
7802
+ }
7803
+ }
7804
 
7805
+ if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
7806
+ const float *params = (const float *)tensor->op_params;
7807
+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
7808
+ } else if (tensor->op == GGML_OP_MUL_MAT) {
7809
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
7810
  } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
7811
  tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt CHANGED
@@ -1,7 +1,9 @@
1
  find_package (Threads REQUIRED)
 
2
 
3
  set(TARGET vulkan-shaders-gen)
4
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
5
  install(TARGETS ${TARGET} RUNTIME)
6
  target_compile_features(${TARGET} PRIVATE cxx_std_17)
7
  target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
 
 
1
  find_package (Threads REQUIRED)
2
+ find_package(Vulkan COMPONENTS glslc REQUIRED)
3
 
4
  set(TARGET vulkan-shaders-gen)
5
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
6
  install(TARGETS ${TARGET} RUNTIME)
7
  target_compile_features(${TARGET} PRIVATE cxx_std_17)
8
  target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
9
+ target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp ADDED
@@ -0,0 +1,305 @@
1
+
2
+ #include "types.comp"
3
+
4
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
5
+ block_q4_0_packed16 block;
6
+ };
7
+
8
+ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
9
+ {
10
+ const float16_t d = bl.block.d;
11
+ const uint idx = coordInBlock[1];
12
+ const uint shift = (idx & 0x10) >> 2;
13
+ uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
14
+ qs >>= shift;
15
+ qs &= 0xF;
16
+ float16_t ret = (float16_t(qs) - float16_t(8)) * d;
17
+ return ret;
18
+ }
19
+
20
+ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
21
+ block_q4_1 block;
22
+ };
23
+
24
+ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
25
+ {
26
+ const float16_t d = bl.block.d;
27
+ const float16_t m = bl.block.m;
28
+ const uint idx = coordInBlock[1];
29
+ const uint iqs = idx & 0xF;
30
+ const uint shift = (idx & 0x10) >> 2;
31
+ uint32_t qs = bl.block.qs[iqs];
32
+ qs >>= shift;
33
+ qs &= 0xF;
34
+ float16_t ret = float16_t(qs) * d + m;
35
+ return ret;
36
+ }
37
+
38
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
39
+ block_q5_0 block;
40
+ };
41
+
42
+ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
43
+ {
44
+ const float16_t d = bl.block.d;
45
+ const uint idx = coordInBlock[1];
46
+ const uint iqs = idx & 0xF;
47
+
48
+ const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
49
+ const uint qh = ((uint_qh >> idx) << 4) & 0x10;
50
+
51
+ const uint shift = (idx & 0x10) >> 2;
52
+ uint32_t qs = bl.block.qs[iqs];
53
+ qs >>= shift;
54
+ qs &= 0xF;
55
+
56
+ float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
57
+ return ret;
58
+ }
59
+
60
+ layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
61
+ block_q5_1 block;
62
+ };
63
+
64
+ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
65
+ {
66
+ const float16_t d = bl.block.d;
67
+ const float16_t m = bl.block.m;
68
+ const uint idx = coordInBlock[1];
69
+ const uint iqs = idx & 0xF;
70
+
71
+ const uint uint_qh = bl.block.qh;
72
+ const uint qh = ((uint_qh >> idx) << 4) & 0x10;
73
+
74
+ const uint shift = (idx & 0x10) >> 2;
75
+ uint32_t qs = bl.block.qs[iqs];
76
+ qs >>= shift;
77
+ qs &= 0xF;
78
+
79
+ float16_t ret = float16_t(qs | qh) * d + m;
80
+ return ret;
81
+ }
82
+
83
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
84
+ block_q8_0_packed16 block;
85
+ };
86
+
87
+ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
88
+ {
89
+ const float16_t d = bl.block.d;
90
+ const uint idx = coordInBlock[1];
91
+ const uint iqs = idx;
92
+
93
+ // Load 16b and select the byte for this element
94
+ int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
95
+ float16_t ret = float16_t(qs) * d;
96
+ return ret;
97
+ }
98
+
99
+ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
100
+ block_q2_K block;
101
+ };
102
+
103
+ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
104
+ {
105
+ const f16vec2 d = bl.block.d;
106
+ const uint idx = coordInBlock[1];
107
+ const uint iqs = idx;
108
+
109
+ const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
110
+ const uint scalesi = iqs / 16; // 0..15
111
+ const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
112
+
113
+ uint32_t qs = bl.block.qs[qsi];
114
+ const uint scales = bl.block.scales[scalesi];
115
+ float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
116
+ return ret;
117
+ }
118
+
119
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
120
+ block_q3_K block;
121
+ };
122
+
123
+ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
124
+ {
125
+ const uint idx = coordInBlock[1];
126
+ const uint iqs = idx;
127
+
128
+ const uint n = iqs / 128; // 0,1
129
+ const uint qsi = n * 32 + (iqs % 32); // 0..63
130
+ const uint hmi = (iqs % 32); // 0..31
131
+ const uint j = (iqs % 128) / 8; // 0..15
132
+ const uint is = iqs / 16; // 0..15
133
+ const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
134
+ const uint qsshift = halfsplit * 2; // 0,2,4,6
135
+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
136
+
137
+ uint32_t scaleidx0 = (is < 8) ? is : (is-8);
138
+ uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
139
+ uint32_t scaleidx1 = is + 8 - (is/4)*4;
140
+ uint32_t scaleidx1shift = (is/4)*2;
141
+
142
+ const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
143
+
144
+ const float16_t dl = bl.block.d * float16_t(us - 32);
145
+
146
+ float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
147
+
148
+ return ret;
149
+ }
150
+
151
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
152
+ block_q4_K block;
153
+ };
154
+
155
+ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
156
+ {
157
+ const uint idx = coordInBlock[1];
158
+ const uint iqs = idx;
159
+
160
+ const uint n = iqs / 64; // 0,1,2,3
161
+ const uint b = (iqs % 64) / 32; // 0,1
162
+ const uint is = (idx & 0xE0) >> 5; // 0..7
163
+ const uint qsi = n * 32 + (iqs % 32); // 0..127
164
+
165
+ const f16vec2 loadd = bl.block.d;
166
+
167
+ uint32_t sc;
168
+ uint32_t mbyte;
169
+
170
+ uint32_t scidx0 = (is < 4) ? is : (is + 4);
171
+ uint32_t scidx1 = (is < 4) ? is : (is - 4);
172
+ uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
173
+ uint32_t scidxshift1 = (is < 4) ? 0 : 2;
174
+ uint32_t mbidx0 = is + 4;
175
+ uint32_t mbidx1 = (is < 4) ? is + 4 : is;
176
+ uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
177
+ uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
178
+ uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
179
+ uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
180
+
181
+ sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
182
+ mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
183
+
184
+ const float16_t d = loadd.x * float16_t(sc);
185
+ const float16_t m = loadd.y * float16_t(mbyte);
186
+
187
+ uint32_t dmask = 0xF << (b * 4);
188
+
189
+ float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
190
+
191
+ return ret;
192
+ }
193
+
194
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
195
+ block_q5_K block;
196
+ };
197
+
198
+ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
199
+ {
200
+ const uint idx = coordInBlock[1];
201
+ const uint iqs = idx;
202
+
203
+ const uint n = iqs / 64; // 0,1,2,3
204
+ const uint b = (iqs % 64) / 32; // 0,1
205
+ const uint is = (idx & 0xE0) >> 5; // 0..7
206
+ const uint qsi = n * 32 + (iqs % 32); // 0..127
207
+ const uint qhi = (iqs % 32); // 0..31
208
+
209
+ const uint8_t hm = uint8_t(1 << (iqs / 32));
210
+
211
+ const f16vec2 loadd = bl.block.d;
212
+
213
+ uint32_t sc;
214
+ uint32_t mbyte;
215
+
216
+ uint32_t scidx0 = (is < 4) ? is : (is + 4);
217
+ uint32_t scidx1 = (is < 4) ? is : (is - 4);
218
+ uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
219
+ uint32_t scidxshift1 = (is < 4) ? 0 : 2;
220
+ uint32_t mbidx0 = is + 4;
221
+ uint32_t mbidx1 = (is < 4) ? is + 4 : is;
222
+ uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
223
+ uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
224
+ uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
225
+ uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
226
+
227
+ sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
228
+ mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
229
+
230
+ const float16_t d = loadd.x * float16_t(sc);
231
+ const float16_t m = loadd.y * float16_t(mbyte);
232
+
233
+ uint32_t dmask = 0xF << (b * 4);
234
+
235
+ float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
236
+
237
+ return ret;
238
+ }
239
+
240
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
241
+ block_q6_K block;
242
+ };
243
+
244
+ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
245
+ {
246
+ const uint idx = coordInBlock[1];
247
+ const uint iqs = idx;
248
+
249
+ const uint n = iqs / 128; // 0,1
250
+ const uint b = (iqs % 128) / 64; // 0,1
251
+ const uint is_b = (iqs % 32) / 16; // 0,1
252
+ const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
253
+ const uint is = 8 * n + qhshift + is_b; // 0..15
254
+ const uint qsi = n * 64 + (iqs % 64); // 0..127
255
+ const uint qhi = n * 32 + (iqs % 32); // 0..63
256
+
257
+ const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
258
+
259
+ float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
260
+
261
+ return ret;
262
+ }
263
+
264
+ #if defined(DATA_A_IQ4_NL)
265
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
266
+ block_iq4_nl block;
267
+ };
268
+
269
+ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
270
+ {
271
+ const float16_t d = bl.block.d;
272
+ const uint idx = coordInBlock[1];
273
+ const uint iqs = idx & 0xF;
274
+ const uint shift = (idx & 0x10) >> 2;
275
+ uint32_t qs = bl.block.qs[iqs];
276
+ qs >>= shift;
277
+ qs &= 0xF;
278
+ float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
279
+ return ret;
280
+ }
281
+ #endif
282
+
283
+ #if defined(DATA_A_Q4_0)
284
+ #define dequantFuncA dequantFuncQ4_0
285
+ #elif defined(DATA_A_Q4_1)
286
+ #define dequantFuncA dequantFuncQ4_1
287
+ #elif defined(DATA_A_Q5_0)
288
+ #define dequantFuncA dequantFuncQ5_0
289
+ #elif defined(DATA_A_Q5_1)
290
+ #define dequantFuncA dequantFuncQ5_1
291
+ #elif defined(DATA_A_Q8_0)
292
+ #define dequantFuncA dequantFuncQ8_0
293
+ #elif defined(DATA_A_Q2_K)
294
+ #define dequantFuncA dequantFuncQ2_K
295
+ #elif defined(DATA_A_Q3_K)
296
+ #define dequantFuncA dequantFuncQ3_K
297
+ #elif defined(DATA_A_Q4_K)
298
+ #define dequantFuncA dequantFuncQ4_K
299
+ #elif defined(DATA_A_Q5_K)
300
+ #define dequantFuncA dequantFuncQ5_K
301
+ #elif defined(DATA_A_Q6_K)
302
+ #define dequantFuncA dequantFuncQ6_K
303
+ #elif defined(DATA_A_IQ4_NL)
304
+ #define dequantFuncA dequantFuncIQ4_NL
305
+ #endif
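Each of these callbacks is invoked per element by coopMatLoadTensorNV with the block coordinate and the coordinate inside the block, so it only has to turn (block, element index) into one dequantized value. As a plain C++ illustration of the Q4_0 case above (scalar floats instead of float16_t; the 32-element Q4_0 block layout with one scale d and 16 packed bytes is assumed from ggml):

    #include <cstdint>

    // Scalar sketch of dequantFuncQ4_0. The shader's packed16 indexing
    // qs[(idx & 0xE) >> 1], byte (idx & 1), is the same as byte index (idx & 0xF)
    // in a plain byte array: elements i and i+16 share a byte (low/high nibble).
    static float dequant_q4_0_elem(const uint8_t qs[16], float d, uint32_t idx /* 0..31 */) {
        const uint32_t byte_idx = idx & 0xF;
        const uint32_t shift    = (idx & 0x10) >> 2;  // 0 = low nibble, 4 = high nibble
        const uint32_t q        = (uint32_t(qs[byte_idx]) >> shift) & 0xF;
        return (float(q) - 8.0f) * d;                 // Q4_0: (value - 8) * block scale
    }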
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp ADDED
@@ -0,0 +1,289 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : enable
4
+ #extension GL_EXT_shader_16bit_storage : require
5
+
6
+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
7
+ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
8
+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
9
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
10
+
11
+ #extension GL_KHR_memory_scope_semantics : enable
12
+ #extension GL_KHR_cooperative_matrix : enable
13
+ #extension GL_NV_cooperative_matrix2 : enable
14
+ #extension GL_EXT_buffer_reference : enable
15
+ #extension GL_KHR_shader_subgroup_ballot : enable
16
+ #extension GL_KHR_shader_subgroup_vote : enable
17
+ #extension GL_EXT_null_initializer : enable
18
+
19
+ #include "types.comp"
20
+ #include "dequant_funcs_cm2.comp"
21
+
22
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
23
+
24
+ layout (constant_id = 1) const uint32_t Br = 32;
25
+ layout (constant_id = 2) const uint32_t Bc = 32;
26
+ layout (constant_id = 3) const uint32_t D = 32;
27
+ layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
28
+
29
+ layout (push_constant) uniform parameter {
30
+ uint32_t N;
31
+ uint32_t KV;
32
+
33
+ uint32_t ne1;
34
+ uint32_t ne2;
35
+ uint32_t ne3;
36
+
37
+ uint32_t neq2;
38
+ uint32_t neq3;
39
+ uint32_t nek2;
40
+ uint32_t nek3;
41
+ uint32_t nev2;
42
+ uint32_t nev3;
43
+ uint32_t nem1;
44
+
45
+ uint32_t nb02;
46
+ uint32_t nb03;
47
+ uint32_t nb12;
48
+ uint32_t nb13;
49
+ uint32_t nb22;
50
+ uint32_t nb23;
51
+ uint32_t nb31;
52
+
53
+ float scale;
54
+ float max_bias;
55
+ float logit_softcap;
56
+
57
+ uint32_t mask;
58
+ uint32_t n_head_log2;
59
+ float m0;
60
+ float m1;
61
+ } p;
62
+
63
+ layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
64
+ layout (binding = 1) readonly buffer K {uint8_t data_k[];};
65
+ layout (binding = 2) readonly buffer V {uint8_t data_v[];};
66
+ layout (binding = 3) readonly buffer M {uint8_t data_m[];};
67
+ layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
68
+
69
+ #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
70
+
71
+ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
72
+ return max(x, y);
73
+ }
74
+
75
+ ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
76
+ return x;
77
+ }
78
+
79
+ // Replace matrix elements >= numRows or numCols with 'replace'
80
+ ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
81
+ if (row >= numRows || col >= numCols) {
82
+ return replace;
83
+ }
84
+ return elem;
85
+ }
86
+
87
+ ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
88
+ {
89
+ return exp(elem);
90
+ }
91
+
92
+ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
93
+ {
94
+ return max(elem0, elem1);
95
+ }
96
+
97
+ #if defined(BLOCK_SIZE)
98
+ #define DECODEFUNC , DEQUANTFUNC
99
+ #else
100
+ #define DECODEFUNC
101
+ #endif
102
+
103
+ void main() {
104
+ #if defined(DATA_A_IQ4_NL)
105
+ init_iq4nl_shmem();
106
+ #endif
107
+
108
+ const uint32_t N = p.N;
109
+ const uint32_t KV = p.KV;
110
+
111
+ const uint32_t Tr = CEIL_DIV(N, Br);
112
+ const uint32_t Tc = CEIL_DIV(KV, Bc);
113
+
114
+ const uint32_t i = gl_WorkGroupID.x;
115
+
116
+ const uint32_t iq2 = gl_WorkGroupID.y;
117
+ const uint32_t iq3 = gl_WorkGroupID.z;
118
+
119
+ // broadcast factors
120
+ const uint32_t rk2 = p.neq2/p.nek2;
121
+ const uint32_t rk3 = p.neq3/p.nek3;
122
+
123
+ const uint32_t rv2 = p.neq2/p.nev2;
124
+ const uint32_t rv3 = p.neq3/p.nev3;
125
+
126
+ // k indices
127
+ const uint32_t ik3 = iq3 / rk3;
128
+ const uint32_t ik2 = iq2 / rk2;
129
+
130
+ // v indices
131
+ const uint32_t iv3 = iq3 / rv3;
132
+ const uint32_t iv2 = iq2 / rv2;
133
+
134
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
135
+ tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
136
+ tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
137
+
138
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
139
+
140
+ #if defined(BLOCK_SIZE)
141
+ tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
142
+ tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
143
+ #endif
144
+
145
+ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
146
+ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
147
+ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
148
+
149
+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
150
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
151
+
152
+ uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
153
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
154
+
155
+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
156
+ Qf16 *= float16_t(p.scale);
157
+
158
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
159
+
160
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
161
+
162
+ L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
163
+ M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
164
+
165
+ ACC_TYPE slope = ACC_TYPE(1.0);
166
+
167
+ // ALiBi
168
+ if (p.max_bias > 0.0f) {
169
+ const uint32_t h = iq2;
170
+
171
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
172
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
173
+
174
+ slope = pow(base, ACC_TYPE(exph));
175
+ }
176
+
177
+ [[dont_unroll]]
178
+ for (uint32_t j = 0; j < Tc; ++j) {
179
+
180
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
181
+
182
+ coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
183
+
184
+ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
185
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
186
+ S = coopMatMulAdd(Qf16, K_T, S);
187
+
188
+ if (p.logit_softcap != 0.0f) {
189
+ [[unroll]]
190
+ for (int k = 0; k < S.length(); ++k) {
191
+ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
192
+ }
193
+ }
194
+
195
+ if (p.mask != 0) {
196
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
197
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
198
+
199
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
200
+
201
+ coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
202
+
203
+ S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
204
+ }
205
+
206
+ // Clear padding elements to -inf, so they don't contribute to rowmax
207
+ if (Clamp != 0 &&
208
+ ((j + 1) * Bc > KV ||
209
+ (i + 1) * Br > N)) {
210
+
211
+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
212
+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
213
+
214
+ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
215
+ }
216
+
217
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
218
+
219
+ coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
220
+
221
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
222
+
223
+ // M = max(rowmax, Mold)
224
+ // P = e^(S - M)
225
+ // eM = e^(Mold - M)
226
+ coopMatPerElementNV(M, rowmax, Max, Mold);
227
+ coopMatPerElementNV(P, S - M, Exp);
228
+ coopMatPerElementNV(eM, Mold - M, Exp);
229
+
230
+ // Clear padding elements to 0, so they don't contribute to rowsum
231
+ if (Clamp != 0 &&
232
+ ((j + 1) * Bc > KV ||
233
+ (i + 1) * Br > N)) {
234
+
235
+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
236
+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
237
+
238
+ coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
239
+ }
240
+
241
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
242
+
243
+ // compute rowsum by multiplying by matrix of all ones.
244
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
245
+
246
+ rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
247
+ rowsum = coopMatMulAdd(P_A, One, rowsum);
248
+
249
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
250
+ uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
251
+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
252
+
253
+ L = eM*L + rowsum;
254
+
255
+ // This is the "diagonal" matrix in the paper, but since we do componentwise
256
+ // multiply rather than matrix multiply it has the diagonal element smeared
257
+ // across the row
258
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
259
+
260
+ // resize eM by using smear/reduce
261
+ coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
262
+
263
+ O = eMdiag * O;
264
+
265
+ O = coopMatMulAdd(P_A, V, O);
266
+ }
267
+
268
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
269
+
270
+ // resize L by using smear/reduce
271
+ coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
272
+
273
+ [[unroll]]
274
+ for (int k = 0; k < Ldiag.length(); ++k) {
275
+ Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
276
+ }
277
+
278
+ O = Ldiag*O;
279
+
280
+ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
281
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
282
+
283
+ // permute dimensions
284
+ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
285
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
286
+
287
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
288
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
289
+ }
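The j-loop above is the streaming-softmax (flash attention) update expressed with cooperative-matrix reductions. Per query row r, with S the scaled (and optionally softcapped/masked) Q·Kᵀ tile, each iteration performs

\[
\begin{aligned}
M_r &\leftarrow \max\big(M_r^{\text{old}},\ \max_c S_{rc}\big), \qquad
P_{rc} = e^{S_{rc}-M_r}, \qquad
e_r = e^{M_r^{\text{old}}-M_r}, \\
L_r &\leftarrow e_r\,L_r + \sum_c P_{rc}, \qquad
O_{r,:} \leftarrow e_r\,O_{r,:} + \sum_c P_{rc}\,V_{c,:},
\end{aligned}
\]

and after the last tile each row is normalized by 1/L_r. The eMdiag/Ldiag matrices are just these per-row scalars smeared across the row (via smearReduce) so they can be applied with a componentwise multiply, and the row sums come from multiplying P by an all-ones matrix, as noted in the comments.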
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp ADDED
@@ -0,0 +1,328 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : enable
4
+ #extension GL_EXT_shader_16bit_storage : require
5
+
6
+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
7
+ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
8
+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
9
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
10
+
11
+ #extension GL_KHR_memory_scope_semantics : enable
12
+ #extension GL_KHR_cooperative_matrix : enable
13
+ #extension GL_NV_cooperative_matrix2 : enable
14
+ #extension GL_EXT_buffer_reference : enable
15
+ #extension GL_KHR_shader_subgroup_ballot : enable
16
+ #extension GL_KHR_shader_subgroup_vote : enable
17
+
18
+ #include "types.comp"
19
+
20
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
21
+
22
+ layout (constant_id = 1) const uint BM = 64;
23
+ layout (constant_id = 2) const uint BN = 64;
24
+ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
25
+
26
+ layout (push_constant) uniform parameter
27
+ {
28
+ uint M;
29
+ uint N;
30
+ uint K;
31
+ uint stride_a;
32
+ uint stride_b;
33
+ uint stride_d;
34
+
35
+ uint batch_stride_a;
36
+ uint batch_stride_b;
37
+ uint batch_stride_d;
38
+
39
+ #ifdef MUL_MAT_ID
40
+ uint nei0;
41
+ uint nei1;
42
+ uint nbi1;
43
+ uint ne11;
44
+ #else
45
+ uint k_split;
46
+ uint ne02;
47
+ uint ne12;
48
+ uint broadcast2;
49
+ uint broadcast3;
50
+ #endif
51
+ } p;
52
+
53
+
54
+ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
55
+ layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
56
+ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
57
+
58
+ #if QUANT_K > 1
59
+ #define DECODEFUNCA , dequantFuncA
60
+ #define MAT_A_TYPE float16_t
61
+
62
+ #include "dequant_funcs_cm2.comp"
63
+
64
+ #else
65
+ #define DECODEFUNCA
66
+ #define MAT_A_TYPE A_TYPE
67
+ #endif
68
+
69
+ #define MAT_B_TYPE B_TYPE
70
+
71
+ #ifdef MUL_MAT_ID
72
+ layout (binding = 3) readonly buffer IDS {int data_ids[];};
73
+
74
+ shared u16vec4 row_ids[3072];
75
+
76
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
77
+ B_TYPE b[];
78
+ };
79
+
80
+ uint _ne1;
81
+ shared uint _ne1_sh;
82
+
83
+ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
84
+ {
85
+ const uint row_i = blockCoords[0];
86
+
87
+ if (row_i >= _ne1) {
88
+ return B_TYPE(0.0);
89
+ }
90
+
91
+ const u16vec4 row_idx = row_ids[row_i];
92
+ B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
93
+
94
+ return ret;
95
+ }
96
+
97
+ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
98
+ {
99
+ uint dr = ir * BM + r;
100
+ uint dc = ic * BN + c;
101
+
102
+ if (dr < p.M && dc < _ne1) {
103
+ uint row_i = dc;
104
+ const u16vec4 row_idx = row_ids[row_i];
105
+ data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
106
+ }
107
+ return elem;
108
+ }
109
+
110
+ #endif
111
+
112
+ void main() {
113
+ #if defined(DATA_A_IQ4_NL)
114
+ init_iq4nl_shmem();
115
+ #endif
116
+
117
+ #ifdef MUL_MAT_ID
118
+ const uint expert_idx = gl_GlobalInvocationID.z;
119
+ #else
120
+ const uint batch_idx = gl_GlobalInvocationID.z;
121
+
122
+ const uint i13 = batch_idx / p.ne12;
123
+ const uint i12 = batch_idx % p.ne12;
124
+
125
+ const uint i03 = i13 / p.broadcast3;
126
+ const uint i02 = i12 / p.broadcast2;
127
+
128
+ const uint batch_idx_a = i03 * p.ne02 + i02;
129
+ #endif
130
+
131
+ const uint blocks_m = (p.M + BM - 1) / BM;
132
+ const uint ir = gl_WorkGroupID.x % blocks_m;
133
+ const uint ik = gl_WorkGroupID.x / blocks_m;
134
+ const uint ic = gl_WorkGroupID.y;
135
+
136
+ #ifdef MUL_MAT_ID
137
+ // Spread the search across all elements in the first subgroup
138
+ if (gl_SubgroupID == 0) {
139
+ _ne1 = 0;
140
+ uint num_elements = p.nei1 * p.nei0;
141
+
142
+ for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
143
+ bool in_range = i < num_elements;
144
+ uint ii0 = i % p.nei0;
145
+ uint ii1 = i / p.nei0;
146
+ uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
147
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
148
+ uint idx = subgroupBallotExclusiveBitCount(ballot);
149
+ if (in_range && id == expert_idx) {
150
+ row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
151
+ }
152
+ _ne1 += subgroupBallotBitCount(ballot);
153
+ }
154
+ _ne1_sh = _ne1;
155
+ }
156
+
157
+ barrier();
158
+
159
+ _ne1 = _ne1_sh;
160
+
161
+ // Workgroup has no work
162
+ if (ic * BN >= _ne1) return;
163
+ #endif
164
+
165
+ #ifdef MUL_MAT_ID
166
+ uint start_k = 0;
167
+ const uint end_k = p.K;
168
+ #else
169
+ uint start_k = ik * p.k_split;
170
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
171
+ #endif
172
+
173
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
174
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
175
+
176
+ #ifdef MUL_MAT_ID
177
+ uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
178
+ uint pos_b = 0;
179
+ #else
180
+ uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
181
+ uint pos_b = batch_idx * p.batch_stride_b;
182
+ #endif
183
+
184
+ uint stride_a = p.stride_a / QUANT_K;
185
+ uint stride_b = p.stride_b;
186
+
187
+ // Hint to the compiler that values are aligned (want 16B alignment).
188
+ // Quants are always block-aligned, no alignment needed.
189
+ #if ALIGNED
190
+ #if QUANT_K == 1
191
+ stride_a &= ~7;
192
+ #endif
193
+ stride_b &= ~7;
194
+ #endif
195
+
196
+ // Create layouts for both clamped and unclamped accesses
197
+ tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
198
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
199
+ tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
200
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
201
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
202
+
203
+ #if QUANT_K > 1
204
+ tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
205
+ tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
206
+ #endif
207
+
208
+ // Use end_k rather than p.K as the dimension because that's what
209
+ // we need to bound check against when using split_k
210
+ tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
211
+ tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
212
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
213
+ tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
214
+ tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
215
+
216
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
217
+
218
+ #if !defined(MUL_MAT_ID)
219
+ // Detect a fast path where all loads are entirely in bounds and no clamping is required
220
+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
221
+ #if QUANT_K == 1
222
+ (stride_a % 8) == 0 &&
223
+ #endif
224
+ (stride_b % 8) == 0 && (start_k % 8) == 0) {
225
+ // Hint to the compiler that values are aligned (want 16B alignment)
226
+ start_k &= ~7;
227
+ stride_b &= ~7;
228
+ #if QUANT_K == 1
229
+ stride_a &= ~7;
230
+ #endif
231
+
232
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
233
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
234
+
235
+ uint k_iters = (end_k - start_k + BK - 1) / BK;
236
+
237
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
238
+
239
+ coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
240
+ coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
241
+
242
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
243
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
244
+
245
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
246
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
247
+
248
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
249
+ }
250
+ } else
251
+ #endif // !defined(MUL_MAT_ID)
252
+ {
253
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
254
+
255
+ tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
256
+
257
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
258
+
259
+ tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
260
+
261
+ [[dont_unroll]]
262
+ for (uint block_k = start_k; block_k < end_k; block_k += BK) {
263
+
264
+ coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
265
+ coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
266
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
267
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
268
+
269
+ // Clamping is expensive, so detect different code paths for each combination
270
+ // of A and B needing clamping.
271
+ bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
272
+ #ifdef MUL_MAT_ID
273
+ bool unclampedB = true;
274
+ #else
275
+ bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
276
+ #endif
277
+ if (unclampedA && unclampedB) {
278
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
279
+ #ifdef MUL_MAT_ID
280
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
281
+ #else
282
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
283
+ #endif
284
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
285
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
286
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
287
+ } else if (unclampedA && !unclampedB) {
288
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
289
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
290
+
291
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
292
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
293
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
294
+ } else if (!unclampedA && unclampedB) {
295
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
296
+ #ifdef MUL_MAT_ID
297
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
298
+ #else
299
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
300
+ #endif
301
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
302
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
303
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
304
+ } else if (!unclampedA && !unclampedB) {
305
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
306
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
307
+
308
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
309
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
310
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
311
+ }
312
+ }
313
+ }
314
+
315
+ // Convert from ACC_TYPE to D_TYPE
316
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
317
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
318
+
319
+ #ifdef MUL_MAT_ID
320
+ // Call callback to store each element, remapping row through shared memory
321
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
322
+ #else
323
+ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
324
+
325
+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
326
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
327
+ #endif
328
+ }
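Before any math, the MUL_MAT_ID path builds the row_ids table with subgroup ballots: every (token, id-slot) pair routed to expert_idx gets an entry, and _ne1 counts them. A serial C++ equivalent of that gather (same push-constant names; purely illustrative):

    #include <cstdint>
    #include <vector>

    struct row_id { uint16_t x, y, z, w; };  // mirrors the shader's u16vec4:
                                             // x = row within this expert's B slice (ii0 % ne11),
                                             // y = batch index (ii1), z = original slot (ii0)

    static uint32_t gather_row_ids(const int32_t *data_ids, uint32_t nei0, uint32_t nei1,
                                   uint32_t nbi1, uint32_t ne11, uint32_t expert_idx,
                                   std::vector<row_id> &row_ids) {
        uint32_t ne1 = 0;  // _ne1 in the shader: rows this expert must process
        for (uint32_t ii1 = 0; ii1 < nei1; ++ii1) {
            for (uint32_t ii0 = 0; ii0 < nei0; ++ii0) {
                if ((uint32_t)data_ids[ii1 * nbi1 + ii0] == expert_idx) {
                    row_ids.push_back({ (uint16_t)(ii0 % ne11), (uint16_t)ii1, (uint16_t)ii0, 0 });
                    ++ne1;
                }
            }
        }
        return ne1;
    }

decodeFuncB then reads B rows through this table (row_idx.y selects the batch, row_idx.x the row) and perElemOpD scatters the output rows back through row_idx.z, so the matmul itself can run on a densely packed set of rows.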
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -30,6 +30,8 @@
30
  #include <fcntl.h>
31
  #endif
32
 
 
 
33
  #define ASYNCIO_CONCURRENCY 64
34
 
35
  std::mutex lock;
@@ -196,15 +198,17 @@ static uint32_t compile_count = 0;
196
  static std::mutex compile_count_mutex;
197
  static std::condition_variable compile_count_cond;
198
 
199
- void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
200
- std::string name = _name + (fp16 ? "" : "_fp32");
201
  std::string out_fname = join_paths(output_dir, name + ".spv");
202
  std::string in_path = join_paths(input_dir, in_fname);
203
 
 
 
204
  #ifdef _WIN32
205
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
206
  #else
207
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
208
  #endif
209
 
210
  #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
@@ -254,7 +258,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
254
  }
255
 
256
  static std::vector<std::future<void>> compiles;
257
- void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
258
  {
259
  // wait until fewer than N compiles are in progress.
260
  // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@@ -265,15 +269,15 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
265
  }
266
  compile_count++;
267
  }
268
- compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16));
269
  }
270
 
271
- void matmul_shaders(bool fp16, bool matmul_id) {
272
- std::string load_vec = fp16 ? "8" : "4";
273
- std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
274
- std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
275
 
276
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
277
  std::string shader_name = "matmul";
278
 
279
  if (matmul_id) {
@@ -285,21 +289,31 @@ void matmul_shaders(bool fp16, bool matmul_id) {
285
  base_dict["FLOAT16"] = "1";
286
  }
287
 
 
 
 
 
288
  // Shaders with f16 B_TYPE
289
- string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
290
- string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
291
 
292
- string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
293
- string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
294
 
295
  for (const auto& tname : type_names) {
296
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
297
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
298
- std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
299
  // For aligned matmul loads
300
- std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
301
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
302
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
 
 
 
 
 
 
303
  }
304
  }
305
 
@@ -307,11 +321,50 @@ void process_shaders() {
307
  std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
308
  std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
309
 
 
310
  for (const auto& fp16 : {false, true}) {
311
- matmul_shaders(fp16, false);
312
- matmul_shaders(fp16, true);
  }
314
 
315
  for (const auto& tname : type_names) {
316
  // mul mat vec
317
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
 
30
  #include <fcntl.h>
31
  #endif
32
 
33
+ #include <vulkan/vulkan_core.h>
34
+
35
  #define ASYNCIO_CONCURRENCY 64
36
 
37
  std::mutex lock;
 
198
  static std::mutex compile_count_mutex;
199
  static std::condition_variable compile_count_cond;
200
 
201
+ void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
202
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
203
  std::string out_fname = join_paths(output_dir, name + ".spv");
204
  std::string in_path = join_paths(input_dir, in_fname);
205
 
206
+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
207
+
208
  #ifdef _WIN32
209
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
210
  #else
211
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", in_path, "-o", out_fname};
212
  #endif
213
 
214
  #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
 
258
  }
259
 
260
  static std::vector<std::future<void>> compiles;
261
+ void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
262
  {
263
  // wait until fewer than N compiles are in progress.
264
  // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
 
269
  }
270
  compile_count++;
271
  }
272
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc));
273
  }
274
 
275
+ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
276
+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
277
+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
278
+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
279
 
280
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
281
  std::string shader_name = "matmul";
282
 
283
  if (matmul_id) {
 
289
  base_dict["FLOAT16"] = "1";
290
  }
291
 
292
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
293
+
294
+ std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
295
+
296
  // Shaders with f16 B_TYPE
297
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc);
298
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
299
 
300
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
301
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc);
302
 
303
  for (const auto& tname : type_names) {
304
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
305
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
306
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
307
  // For aligned matmul loads
308
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
309
+
310
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
311
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
312
+
313
+ if (tname != "f16" && tname != "f32") {
314
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
315
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
316
+ }
317
  }
318
  }
319
 
 
321
  std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
322
  std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
323
 
324
+ // matmul
325
  for (const auto& fp16 : {false, true}) {
326
+ for (const auto& matmul_id : {false, true}) {
327
+ for (const auto& coopmat2 : {false, true}) {
328
+ for (const auto& f16acc : {false, true}) {
329
+ #if !defined(VK_NV_cooperative_matrix2)
330
+ if (coopmat2) {
331
+ continue;
332
+ }
333
+ #endif
334
+ if (coopmat2 && !fp16) {
335
+ continue;
336
+ }
337
+ if (!coopmat2 && f16acc) {
338
+ continue;
339
+ }
340
+ matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
341
+ }
342
+ }
343
+ }
344
  }
345
 
346
+ #if defined(VK_NV_cooperative_matrix2)
347
+ // flash attention
348
+ for (const auto& f16acc : {false, true}) {
349
+ std::string acctype = f16acc ? "float16_t" : "float";
350
+
351
+ for (const auto& tname : type_names) {
352
+ if (tname == "f32") {
353
+ continue;
354
+ }
355
+
356
+ if (tname == "f16") {
357
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
358
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc);
359
+ } else {
360
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
361
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
362
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc);
363
+ }
364
+ }
365
+ }
366
+ #endif
367
+
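For reference, the suffix logic in string_to_spv_func gives each (fp16, coopmat2, f16acc) combination its own .spv file, with the _cm2 variants compiled against --target-env=vulkan1.3. A small restatement of that naming (illustrative, mirroring the code above):

    #include <string>

    // Mirrors the name composition in string_to_spv_func.
    static std::string shader_file_name(const std::string &base, bool fp16, bool coopmat2, bool f16acc) {
        return base + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + ".spv";
    }
    // e.g. shader_file_name("flash_attn_f32_f16_q8_0", true, true, true)
    //      -> "flash_attn_f32_f16_q8_0_f16acc_cm2.spv"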
368
  for (const auto& tname : type_names) {
369
  // mul mat vec
370
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);