Commit d10b47b · Parent: 91607b6
vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (llama/10206)
Files changed:
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +671 -80
- ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +2 -0
- ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +305 -0
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +289 -0
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +328 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +74 -21
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
@@ -167,6 +167,7 @@ struct vk_device_struct {
     uint32_t subgroup_size;
     uint32_t shader_core_count;
     bool uma;
+    bool coopmat2;
 
     size_t idx;
 
@@ -176,6 +177,7 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
     vk_pipeline pipeline_matmul_split_k_reduce;
 
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32;
@@ -229,6 +231,14 @@ struct vk_device_struct {
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
 
+    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
 
@@ -340,6 +350,40 @@ struct vk_mat_vec_id_push_constants {
     uint32_t nei0; uint32_t ne11;
 };
 
+struct vk_flash_attn_push_constants {
+    uint32_t N;
+    uint32_t KV;
+
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+
+    uint32_t neq2;
+    uint32_t neq3;
+    uint32_t nek2;
+    uint32_t nek3;
+    uint32_t nev2;
+    uint32_t nev3;
+    uint32_t nem1;
+
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb12;
+    uint32_t nb13;
+    uint32_t nb22;
+    uint32_t nb23;
+    uint32_t nb31;
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+
+    uint32_t mask;
+    uint32_t n_head_log2;
+    float m0;
+    float m1;
+};
+
 struct vk_op_push_constants {
     uint32_t KX;
     uint32_t KY;
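The four trailing dimensions of the new flash-attention pipeline arrays follow the comment above: [k/v type][f16acc or f32acc][large or small_rows][unaligned or aligned]. An illustrative sketch of how such an array is indexed, mirroring the selection done later in the ggml_vk_flash_attn hunk (the helper name is hypothetical, not part of the diff):

    // Hypothetical helper, not part of the diff; mirrors the indexing used in ggml_vk_flash_attn.
    // f32acc:     dst->op_params[3] == GGML_PREC_F32
    // small_rows: N <= flash_attention_num_small_rows (32)
    // aligned:    (KV % pipeline->align) == 0
    static vk_pipeline pick_fa_pipeline_d64(vk_device & device, ggml_type kv_type,
                                            bool f32acc, bool small_rows, bool aligned) {
        return device->pipeline_flash_attn_f32_f16_D64[kv_type][f32acc][small_rows][aligned];
    }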
@@ -1265,6 +1309,23 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }
 
+// number of rows/cols for flash attention shader
+static constexpr uint32_t flash_attention_num_small_rows = 32;
+static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+    GGML_UNUSED(clamp);
+
+    // small rows, large cols
+    if (small_rows) {
+        return {flash_attention_num_small_rows, 128};
+    }
+    // small cols to reduce register count
+    if (ggml_is_quantized(type) || D == 256) {
+        return {64, 32};
+    }
+    return {64, 64};
+};
+
+
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
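For reference, a few concrete outputs of fa_rows_cols as defined above ({rows, cols}; illustrative only):

    // fa_rows_cols(128, 0, GGML_TYPE_F16,  false) -> {64, 64}   // large rows, f16 K/V
    // fa_rows_cols(128, 0, GGML_TYPE_Q4_0, false) -> {64, 32}   // quantized K/V: fewer cols
    // fa_rows_cols(256, 0, GGML_TYPE_F16,  false) -> {64, 32}   // D == 256: fewer cols
    // fa_rows_cols( 64, 0, GGML_TYPE_F16,  true)  -> {32, 128}  // small_rows path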
@@ -1275,59 +1336,98 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     // mulmat
     std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
-                          l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
+                          l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+                          l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
+                          l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
-                            l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
+                            l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
+                            l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
+                            l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
 
+    uint32_t l_align, m_align, s_align;
+    if (device->coopmat2) {
+        // spec constants and tile sizes for non-quant matmul/matmul_id
+        l_warptile = { 256, 128, 256, 64 };
+        m_warptile = { 256, 128, 128, 64 };
+        s_warptile = { 128, 32, 16, 64 };
+        l_wg_denoms = {128, 256, 1 };
+        m_wg_denoms = {128, 128, 1 };
+        s_wg_denoms = { 32, 16, 1 };
+
+        // spec constants and tile sizes for quant matmul (non-Qi_K)
+        l_warptile_mmq = { 256, 128, 256, 64 };
+        m_warptile_mmq = { 256, 128, 128, 64 };
+        s_warptile_mmq = { 256, 128, 128, 64 };
+        l_mmq_wg_denoms = { 128, 256, 1 };
+        m_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 128, 128, 1 };
+
+        // spec constants and tile sizes for quant matmul (Qi_K)
+        l_warptile_mmq_k = { 256, 128, 512, 16 };
+        m_warptile_mmq_k = { 256, 128, 256, 16 };
+        s_warptile_mmq_k = { 256, 32, 128, 64 };
+        l_mmq_wg_denoms_k = { 128, 512, 1 };
+        m_mmq_wg_denoms_k = { 128, 256, 1 };
+        s_mmq_wg_denoms_k = { 32, 128, 1 };
+
+        // spec constants and tile sizes for quant matmul_id
+        l_warptile_mmqid = { 256, 128, 128, 16 };
+        m_warptile_mmqid = { 256, 128, 64, 16 };
+        s_warptile_mmqid = { 256, 64, 64, 16 };
+        l_mmqid_wg_denoms = { 128, 128, 1 };
+        m_mmqid_wg_denoms = { 128, 64, 1 };
+        s_mmqid_wg_denoms = { 64, 64, 1 };
+
+        l_align = 128;
+        m_align = 64;
+        s_align = 32;
+    } else {
+        l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+        m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+        l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+        m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+        s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+        l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
+        m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
+        s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
+        l_align = 128;
+        m_align = 64;
+        s_align = 32;
+
+        // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
+        // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
+        // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
+        // But the numbers happen to work out for 32KB shared memory size that when using the medium
+        // size there's enough room for everything, and we assert for this.
+        uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+        if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+            l_warptile = m_warptile;
+            l_wg_denoms = m_wg_denoms;
+            shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+            GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+        }
+        if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+            // assert mul_mat_mat_id shaders will fit.
+            GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
         }
+
         shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+        if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+            if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
+                l_warptile_mmq = m_warptile_mmq;
+                l_mmq_wg_denoms = m_mmq_wg_denoms;
+            } else {
+                l_warptile_mmq = s_warptile_mmq;
+                l_mmq_wg_denoms = s_mmq_wg_denoms;
+            }
+            shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+            GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+        }
+        if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+            // assert mul_mat_mat_id shaders will fit.
+            GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+        }
     }
 
     device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
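A worked example of the shared-memory fallback above, assuming the non-coopmat2 tile sizes shown in this hunk:

    // shmem = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float)
    //       = (128 + 128) * (16 + 1) * 4 = 17408 bytes  -> exceeds a 16 KB limit
    // after falling back to l_warptile = m_warptile:
    //       = ( 64 +  64) * (16 + 1) * 4 =  8704 bytes  -> fits in 16 KB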
@@ -1362,6 +1462,105 @@ static void ggml_vk_load_shaders(vk_device& device) {
         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
     };
 
+#if defined(VK_NV_cooperative_matrix2)
+    if (device->coopmat2) {
+
+        auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+            return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
+        };
+
+        auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+            // For large number of rows, 128 invocations seems to work best.
+            // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
+            // can't use 256 for D==80.
+            uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
+            auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+            return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
+        };
+
+#define CREATE_FA2(TYPE, NAMELC, D) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
+
+#define CREATE_FA(TYPE, NAMELC) \
+        CREATE_FA2(TYPE, NAMELC, 64) \
+        CREATE_FA2(TYPE, NAMELC, 80) \
+        CREATE_FA2(TYPE, NAMELC, 96) \
+        CREATE_FA2(TYPE, NAMELC, 112) \
+        CREATE_FA2(TYPE, NAMELC, 128) \
+        CREATE_FA2(TYPE, NAMELC, 256)
+
+        CREATE_FA(GGML_TYPE_F16, f16)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0)
+        CREATE_FA(GGML_TYPE_Q4_1, q4_1)
+        CREATE_FA(GGML_TYPE_Q5_0, q5_0)
+        CREATE_FA(GGML_TYPE_Q5_1, q5_1)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0)
+        // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
+        //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
+        //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
+        //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
+        //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
+        //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
+        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
+#undef CREATE_FA
+
+        // Create 6 variants, {s,m,l}x{unaligned,aligned}
+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+
+        // Create 2 variants, {f16,f32} accumulator
+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+        CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+
+        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+
+        CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+        CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+
+        CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+#undef CREATE_MM
+#undef CREATE_MM2
+    } else
+#endif
     if (device->fp16) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -1648,15 +1847,28 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->physical_device = physical_devices[dev_num];
         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
+        bool fp16_storage = false;
+        bool fp16_compute = false;
         bool maintenance4_support = false;
         bool sm_builtins = false;
+        bool pipeline_robustness = false;
+        bool coopmat2_support = false;
 
         // Check if maintenance4 is supported
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                 maintenance4_support = true;
+            } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+                fp16_storage = true;
+            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+                fp16_compute = true;
             } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
                 sm_builtins = true;
+            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
+                pipeline_robustness = true;
+            } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
+                       !getenv("GGML_VULKAN_DISABLE_COOPMAT2")) {
+                coopmat2_support = true;
             }
         }
 
@@ -1679,6 +1891,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             last_struct = (VkBaseOutStructure *)&sm_props;
         }
 
+#if defined(VK_NV_cooperative_matrix2)
+        vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
+        if (coopmat2_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
+            last_struct = (VkBaseOutStructure *)&coopmat2_props;
+        }
+#endif
+
         device->physical_device.getProperties2(&props2);
         device->properties = props2.properties;
 
@@ -1701,20 +1921,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->shader_core_count = 0;
         }
 
-        bool fp16_storage = false;
-        bool fp16_compute = false;
-        bool pipeline_robustness = false;
-
-        for (const auto& properties : ext_props) {
-            if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
-                fp16_storage = true;
-            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
-                fp16_compute = true;
-            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
-                pipeline_robustness = true;
-            }
-        }
-
         const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
         const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
@@ -1757,22 +1963,112 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
         vk11_features.pNext = &vk12_features;
 
+        last_struct = (VkBaseOutStructure *)&vk12_features;
+
         VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
         pl_robustness_features.pNext = nullptr;
         pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
         pl_robustness_features.pipelineRobustness = VK_FALSE;
 
         if (pipeline_robustness) {
+            last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
+            last_struct = (VkBaseOutStructure *)&pl_robustness_features;
             device_extensions.push_back("VK_EXT_pipeline_robustness");
         }
 
+#if defined(VK_NV_cooperative_matrix2)
+        VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
+        coopmat2_features.pNext = nullptr;
+        coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
+        if (coopmat2_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
+            last_struct = (VkBaseOutStructure *)&coopmat2_features;
+            device_extensions.push_back("VK_NV_cooperative_matrix2");
+        }
+#endif
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
 
+        if (coopmat2_support) {
+#if defined(VK_NV_cooperative_matrix2)
+            if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
+                coopmat2_features.cooperativeMatrixFlexibleDimensions &&
+                coopmat2_features.cooperativeMatrixReductions &&
+                coopmat2_features.cooperativeMatrixConversions &&
+                coopmat2_features.cooperativeMatrixPerElementOperations &&
+                coopmat2_features.cooperativeMatrixTensorAddressing &&
+                coopmat2_features.cooperativeMatrixBlockLoads &&
+                vk12_features.bufferDeviceAddress) {
+
+                std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
+                uint32_t count = 0;
+
+                PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
+                    _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
+                        (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
+                        vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
+
+                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
+
+                VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
+                empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
+                flexible_dimensions.resize(count, empty_prop);
+
+                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
+
+                bool found_fp16_128 = false,
+                     found_fp16_256 = false,
+                     found_fp32_128 = false,
+                     found_fp32_256 = false;
+                // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
+                // with 32x16x16 and 256 with 32x32x16.
+                for (auto &prop : flexible_dimensions) {
+                    if (prop.saturatingAccumulation == VK_FALSE &&
+                        prop.scope == VK_SCOPE_WORKGROUP_KHR &&
+                        prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
+                        prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
+
+                        if (prop.workgroupInvocations == 128 &&
+                            prop.MGranularity <= 32 &&
+                            prop.NGranularity <= 16 &&
+                            prop.KGranularity <= 16) {
+                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
+                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
+                                found_fp16_128 = true;
+                            }
+                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
+                                found_fp32_128 = true;
+                            }
+                        }
+                        if (prop.workgroupInvocations == 256 &&
+                            prop.MGranularity <= 32 &&
+                            prop.NGranularity <= 32 &&
+                            prop.KGranularity <= 16) {
+                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
+                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
+                                found_fp16_256 = true;
+                            }
+                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
+                                found_fp32_256 = true;
+                            }
+                        }
+                    }
+                }
+                if (found_fp16_128 && found_fp16_256 &&
+                    found_fp32_128 && found_fp32_256 &&
+                    coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
+                    device->coopmat2 = true;
+                }
+            }
+#endif
+        }
+
         if (!vk11_features.storageBuffer16BitAccess) {
             std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
             throw std::runtime_error("Unsupported device");
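The hunk above grows the pNext query chain through a VkBaseOutStructure cursor before calling vkGetPhysicalDeviceFeatures2. A minimal sketch of that pattern (hypothetical helper name, not part of the diff):

    static void chain_feature_struct(VkBaseOutStructure *& last_struct, void * next) {
        last_struct->pNext = (VkBaseOutStructure *) next;  // append the struct to the query chain
        last_struct       = (VkBaseOutStructure *) next;   // advance the cursor
    }
    // e.g. chain_feature_struct(last_struct, &coopmat2_features); before vkGetPhysicalDeviceFeatures2().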
@@ -2124,7 +2420,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
     return ctx->device->pipeline_dequant[type];
 }
 
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
     VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_f32;
@@ -2132,14 +2428,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
         return ctx->device->pipeline_matmul_f32_f16;
     }
-    if (
+    if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) {
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_matmul_f16_f32.f16acc;
+        }
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_matmul_f16.f16acc;
+        }
+    } else {
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_matmul_f16_f32.f32acc;
+        }
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_matmul_f16.f32acc;
+        }
     }
 
-    if (src1_type != GGML_TYPE_F32) {
+    if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
         return nullptr;
     }
 
@@ -2160,6 +2465,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         return nullptr;
     }
 
+    if (ctx->device->coopmat2) {
+        assert(src1_type == GGML_TYPE_F16);
+        return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
+    }
     return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
 }
 
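Illustration of the accumulator choice added above (assumes the names from this hunk; not part of the diff):

    // ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT)
    //   -> pipeline_matmul_f16_f32.f16acc  when ctx->device->coopmat2 is set
    //   -> pipeline_matmul_f16_f32.f32acc  otherwise (or with GGML_PREC_F32)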
@@ -2844,6 +3153,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
             break;
     }
 
+    if (ctx->device->coopmat2) {
+        if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
+            return aligned ? mmp->a_l : mmp->l;
+        }
+        if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
+            return aligned ? mmp->a_m : mmp->m;
+        }
+        return aligned ? mmp->a_s : mmp->s;
+    }
+
     if (m <= 32 || n <= 32) {
         return aligned ? mmp->a_s : mmp->s;
     }
@@ -3008,18 +3327,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     }
 
     const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
+    // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
+    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src1);
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
+    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
     const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
@@ -3930,6 +4251,167 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+    VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
+    std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
+    std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    const uint32_t nem1 = mask ? mask->ne[1] : 0;
+    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+
+    const uint32_t D = neq0;
+    const uint32_t N = neq1;
+    const uint32_t KV = nek1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(nev1 == nek1);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    assert(dst->type == GGML_TYPE_F32);
+    assert(q->type == GGML_TYPE_F32);
+    assert(k->type == v->type);
+
+    vk_pipeline *pipelines;
+    // XXX TODO other backends may be changing accumulator precision to default to f32 soon
+    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
+    bool small_rows = N <= flash_attention_num_small_rows;
+    switch (D) {
+    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+    default:
+        assert(!"unsupported D value");
+        return;
+    }
+    assert(pipelines);
+
+    bool aligned = (KV % pipelines[1]->align) == 0;
+    vk_pipeline pipeline = pipelines[aligned];
+    assert(pipeline);
+
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head_kv = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_Q, d_K, d_V, d_D, d_M;
+    uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
+
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
+        ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
+        Q_uma = d_Q != nullptr;
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        D_uma = d_D != nullptr;
+        if (mask) {
+            ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
+            M_uma = d_M != nullptr;
+        }
+    }
+
+
+    ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+
+    if (!Q_uma) {
+        d_Q = q_buf_ctx->dev_buffer;
+        q_buf_offset = vk_tensor_offset(q) + q->view_offs;
+    }
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_buf_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_buf_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!D_uma) {
+        d_D = d_buf_ctx->dev_buffer;
+        d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    if (!M_uma) {
+        d_M = d_Q;
+        m_buf_offset = q_buf_offset;
+        if (mask) {
+            ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
+            d_M = m_buf_ctx->dev_buffer;
+            m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
+        }
+    }
+
+    const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+        },
+        sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+}
+
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
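A small numeric illustration of the scale and ALiBi slope setup in ggml_vk_flash_attn above (not part of the diff):

    // Example: n_head_kv = 32, max_bias = 8.0f, logit_softcap = 0
    //   n_head_log2 = 1u << (uint32_t) floorf(log2f(32.0f)) = 32
    //   m0 = powf(2.0f, -(8.0f)        / 32) = 2^-0.25  ~= 0.8409
    //   m1 = powf(2.0f, -(8.0f / 2.0f) / 32) = 2^-0.125 ~= 0.9170
    // When logit_softcap != 0, the scale uploaded in the push constants is divided by it first.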
@@ -5044,16 +5526,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
 
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    ggml_vk_ctx_begin(ctx->device, subctx);
     for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_ctx_begin(ctx->device, subctx);
         ggml_vk_matmul(
             ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
             split_k, batch, batch, batch, 1, 1
         );
-        ggml_vk_ctx_end(subctx);
     }
 
     auto begin = std::chrono::high_resolution_clock::now();
     ggml_vk_submit(subctx, ctx->fence);
@@ -5391,16 +5873,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     ggml_vk_buffer_write(y_buf, 0, y, y_sz);
 
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_ctx_begin(ctx->device, subctx);
         ggml_vk_matmul(
             ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
             split_k, batch, batch, batch, 1, 1
         );
-        ggml_vk_ctx_end(subctx);
     }
 
     auto begin = std::chrono::high_resolution_clock::now();
 
@@ -5621,7 +6103,8 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         4096, 512, 11008,
         32000, 512, 4096,
     };
-    const size_t num_it =
     for (size_t i = 0; i < vals.size(); i += 3) {
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
@@ -5676,6 +6159,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     const ggml_tensor * src0 = node->src[0];
     const ggml_tensor * src1 = node->src[1];
     const ggml_tensor * src2 = node->src[2];
 
     switch (node->op) {
     // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
@@ -5728,6 +6212,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -5920,6 +6405,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_MUL_MAT_ID:
         ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
        break;
    default:
        return false;
@@ -6020,6 +6510,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
        break;
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
        buf = tensor->buffer;

        break;
@@ -6751,6 +7242,57 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm

            return true;
        } break;
        case GGML_OP_GET_ROWS:
            {
                switch (op->src[0]->type) {
@@ -7065,6 +7607,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
    ggml_tensor * src0 = tensor->src[0];
    ggml_tensor * src1 = tensor->src[1];
    ggml_tensor * src2 = tensor->src[2];

    struct ggml_init_params iparams = {
        /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -7077,15 +7620,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
    struct ggml_tensor * src0_clone = nullptr;
    struct ggml_tensor * src1_clone = nullptr;
    struct ggml_tensor * src2_clone = nullptr;
    struct ggml_tensor * tensor_clone = nullptr;

    size_t src0_size;
    size_t src1_size;
    size_t src2_size;

    void * src0_buffer = nullptr;
    void * src1_buffer = nullptr;
    void * src2_buffer = nullptr;

    if (src0 != nullptr) {
        src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -7213,8 +7759,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
            ggml_vk_print_tensor(src2, "src2");
        }
    }

-    if (tensor->op ==
        tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
        tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
| 167 |
uint32_t subgroup_size;
|
| 168 |
uint32_t shader_core_count;
|
| 169 |
bool uma;
|
| 170 |
+
bool coopmat2;
|
| 171 |
|
| 172 |
size_t idx;
|
| 173 |
|
|
|
|
| 177 |
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
| 178 |
vk_pipeline pipeline_matmul_split_k_reduce;
|
| 179 |
|
| 180 |
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
| 181 |
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
| 182 |
|
| 183 |
vk_matmul_pipeline pipeline_matmul_id_f32;
|
|
|
|
| 231 |
vk_pipeline pipeline_timestep_embedding_f32;
|
| 232 |
vk_pipeline pipeline_pool2d_f32;
|
| 233 |
|
| 234 |
+
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
| 235 |
+
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
| 236 |
+
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
| 237 |
+
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
| 238 |
+
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
| 239 |
+
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
| 240 |
+
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
| 241 |
+
|
| 242 |
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
| 243 |
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
| 244 |
|
|
|
|
| 350 |
uint32_t nei0; uint32_t ne11;
|
| 351 |
};
|
| 352 |
|
| 353 |
+
struct vk_flash_attn_push_constants {
|
| 354 |
+
uint32_t N;
|
| 355 |
+
uint32_t KV;
|
| 356 |
+
|
| 357 |
+
uint32_t ne1;
|
| 358 |
+
uint32_t ne2;
|
| 359 |
+
uint32_t ne3;
|
| 360 |
+
|
| 361 |
+
uint32_t neq2;
|
| 362 |
+
uint32_t neq3;
|
| 363 |
+
uint32_t nek2;
|
| 364 |
+
uint32_t nek3;
|
| 365 |
+
uint32_t nev2;
|
| 366 |
+
uint32_t nev3;
|
| 367 |
+
uint32_t nem1;
|
| 368 |
+
|
| 369 |
+
uint32_t nb02;
|
| 370 |
+
uint32_t nb03;
|
| 371 |
+
uint32_t nb12;
|
| 372 |
+
uint32_t nb13;
|
| 373 |
+
uint32_t nb22;
|
| 374 |
+
uint32_t nb23;
|
| 375 |
+
uint32_t nb31;
|
| 376 |
+
|
| 377 |
+
float scale;
|
| 378 |
+
float max_bias;
|
| 379 |
+
float logit_softcap;
|
| 380 |
+
|
| 381 |
+
uint32_t mask;
|
| 382 |
+
uint32_t n_head_log2;
|
| 383 |
+
float m0;
|
| 384 |
+
float m1;
|
| 385 |
+
};
|
| 386 |
+
|
| 387 |
struct vk_op_push_constants {
|
| 388 |
uint32_t KX;
|
| 389 |
uint32_t KY;
|
|
|
|
| 1309 |
);
|
| 1310 |
}
|
| 1311 |
|
| 1312 |
+
// number of rows/cols for flash attention shader
|
| 1313 |
+
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
| 1314 |
+
static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
| 1315 |
+
GGML_UNUSED(clamp);
|
| 1316 |
+
|
| 1317 |
+
// small rows, large cols
|
| 1318 |
+
if (small_rows) {
|
| 1319 |
+
return {flash_attention_num_small_rows, 128};
|
| 1320 |
+
}
|
| 1321 |
+
// small cols to reduce register count
|
| 1322 |
+
if (ggml_is_quantized(type) || D == 256) {
|
| 1323 |
+
return {64, 32};
|
| 1324 |
+
}
|
| 1325 |
+
return {64, 64};
|
| 1326 |
+
};
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
static void ggml_vk_load_shaders(vk_device& device) {
|
| 1330 |
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
| 1331 |
|
|
|
|
| 1336 |
|
| 1337 |
// mulmat
|
| 1338 |
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
|
| 1339 |
+
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
| 1340 |
+
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
| 1341 |
+
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
| 1342 |
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
| 1343 |
+
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
|
| 1344 |
+
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
|
| 1345 |
+
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
|
| 1346 |
|
| 1347 |
+
uint32_t l_align, m_align, s_align;
|
| 1348 |
+
if (device->coopmat2) {
|
| 1349 |
+
// spec constants and tile sizes for non-quant matmul/matmul_id
|
| 1350 |
+
l_warptile = { 256, 128, 256, 64 };
|
| 1351 |
+
m_warptile = { 256, 128, 128, 64 };
|
| 1352 |
+
s_warptile = { 128, 32, 16, 64 };
|
| 1353 |
+
l_wg_denoms = {128, 256, 1 };
|
| 1354 |
+
m_wg_denoms = {128, 128, 1 };
|
| 1355 |
+
s_wg_denoms = { 32, 16, 1 };
|
| 1356 |
+
|
| 1357 |
+
// spec constants and tile sizes for quant matmul (non-Qi_K)
|
| 1358 |
+
l_warptile_mmq = { 256, 128, 256, 64 };
|
| 1359 |
+
m_warptile_mmq = { 256, 128, 128, 64 };
|
| 1360 |
+
s_warptile_mmq = { 256, 128, 128, 64 };
|
| 1361 |
+
l_mmq_wg_denoms = { 128, 256, 1 };
|
| 1362 |
+
m_mmq_wg_denoms = { 128, 128, 1 };
|
| 1363 |
+
s_mmq_wg_denoms = { 128, 128, 1 };
|
| 1364 |
+
|
| 1365 |
+
// spec constants and tile sizes for quant matmul (Qi_K)
|
| 1366 |
+
l_warptile_mmq_k = { 256, 128, 512, 16 };
|
| 1367 |
+
m_warptile_mmq_k = { 256, 128, 256, 16 };
|
| 1368 |
+
s_warptile_mmq_k = { 256, 32, 128, 64 };
|
| 1369 |
+
l_mmq_wg_denoms_k = { 128, 512, 1 };
|
| 1370 |
+
m_mmq_wg_denoms_k = { 128, 256, 1 };
|
| 1371 |
+
s_mmq_wg_denoms_k = { 32, 128, 1 };
|
| 1372 |
+
|
| 1373 |
+
// spec constants and tile sizes for quant matmul_id
|
| 1374 |
+
l_warptile_mmqid = { 256, 128, 128, 16 };
|
| 1375 |
+
m_warptile_mmqid = { 256, 128, 64, 16 };
|
| 1376 |
+
s_warptile_mmqid = { 256, 64, 64, 16 };
|
| 1377 |
+
l_mmqid_wg_denoms = { 128, 128, 1 };
|
| 1378 |
+
m_mmqid_wg_denoms = { 128, 64, 1 };
|
| 1379 |
+
s_mmqid_wg_denoms = { 64, 64, 1 };
|
| 1380 |
+
|
| 1381 |
+
l_align = 128;
|
| 1382 |
+
m_align = 64;
|
| 1383 |
+
s_align = 32;
|
| 1384 |
+
} else {
|
| 1385 |
+
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
| 1386 |
+
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
| 1387 |
+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
|
| 1388 |
+
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
| 1389 |
+
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
| 1390 |
+
s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
|
| 1391 |
+
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
| 1392 |
+
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
| 1393 |
+
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
| 1394 |
+
l_align = 128;
|
| 1395 |
+
m_align = 64;
|
| 1396 |
+
s_align = 32;
|
| 1397 |
+
|
| 1398 |
+
// Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
|
| 1399 |
+
// and tile sizes, this should handle 16KB, 32KB, and 48KB+.
|
| 1400 |
+
// This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
|
| 1401 |
+
// But the numbers happen to work out for 32KB shared memory size that when using the medium
|
| 1402 |
+
// size there's enough room for everything, and we assert for this.
|
| 1403 |
+
uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
|
| 1404 |
+
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
|
| 1405 |
+
l_warptile = m_warptile;
|
| 1406 |
+
l_wg_denoms = m_wg_denoms;
|
| 1407 |
+
shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
|
| 1408 |
+
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
|
| 1409 |
+
}
|
| 1410 |
+
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
|
| 1411 |
+
// assert mul_mat_mat_id shaders will fit.
|
| 1412 |
+
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
| 1413 |
}
|
| 1414 |
+
|
| 1415 |
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
|
| 1416 |
+
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
|
| 1417 |
+
if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
|
| 1418 |
+
l_warptile_mmq = m_warptile_mmq;
|
| 1419 |
+
l_mmq_wg_denoms = m_mmq_wg_denoms;
|
| 1420 |
+
} else {
|
| 1421 |
+
l_warptile_mmq = s_warptile_mmq;
|
| 1422 |
+
l_mmq_wg_denoms = s_mmq_wg_denoms;
|
| 1423 |
+
}
|
| 1424 |
+
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
|
| 1425 |
+
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
|
| 1426 |
+
}
|
| 1427 |
+
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
|
| 1428 |
+
// assert mul_mat_mat_id shaders will fit.
|
| 1429 |
+
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
| 1430 |
+
}
|
| 1431 |
}
|
| 1432 |
|
| 1433 |
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
|
|
| 1462 |
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
|
| 1463 |
};
|
| 1464 |
|
| 1465 |
+
#if defined(VK_NV_cooperative_matrix2)
|
| 1466 |
+
if (device->coopmat2) {
|
| 1467 |
+
|
| 1468 |
+
auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
| 1469 |
+
return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
|
| 1470 |
+
};
|
| 1471 |
+
|
| 1472 |
+
auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
| 1473 |
+
// For large number of rows, 128 invocations seems to work best.
|
| 1474 |
+
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
| 1475 |
+
// can't use 256 for D==80.
|
| 1476 |
+
uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
| 1477 |
+
auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
|
| 1478 |
+
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
|
| 1479 |
+
};
|
| 1480 |
+
|
| 1481 |
+
#define CREATE_FA2(TYPE, NAMELC, D) \
|
| 1482 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
| 1483 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
| 1484 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
| 1485 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
| 1486 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
| 1487 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
| 1488 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
| 1489 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
| 1490 |
+
|
| 1491 |
+
#define CREATE_FA(TYPE, NAMELC) \
|
| 1492 |
+
CREATE_FA2(TYPE, NAMELC, 64) \
|
| 1493 |
+
CREATE_FA2(TYPE, NAMELC, 80) \
|
| 1494 |
+
CREATE_FA2(TYPE, NAMELC, 96) \
|
| 1495 |
+
CREATE_FA2(TYPE, NAMELC, 112) \
|
| 1496 |
+
CREATE_FA2(TYPE, NAMELC, 128) \
|
| 1497 |
+
CREATE_FA2(TYPE, NAMELC, 256)
|
| 1498 |
+
|
| 1499 |
+
CREATE_FA(GGML_TYPE_F16, f16)
|
| 1500 |
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0)
|
| 1501 |
+
CREATE_FA(GGML_TYPE_Q4_1, q4_1)
|
| 1502 |
+
CREATE_FA(GGML_TYPE_Q5_0, q5_0)
|
| 1503 |
+
CREATE_FA(GGML_TYPE_Q5_1, q5_1)
|
| 1504 |
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0)
|
| 1505 |
+
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
| 1506 |
+
//CREATE_FA(GGML_TYPE_Q2_K, q2_k)
|
| 1507 |
+
//CREATE_FA(GGML_TYPE_Q3_K, q3_k)
|
| 1508 |
+
//CREATE_FA(GGML_TYPE_Q4_K, q4_k)
|
| 1509 |
+
//CREATE_FA(GGML_TYPE_Q5_K, q5_k)
|
| 1510 |
+
//CREATE_FA(GGML_TYPE_Q6_K, q6_k)
|
| 1511 |
+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
|
| 1512 |
+
#undef CREATE_FA
|
| 1513 |
+
|
| 1514 |
+
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1515 |
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1516 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
| 1517 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
| 1518 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
| 1519 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
| 1520 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
| 1521 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
| 1522 |
+
|
| 1523 |
+
// Create 2 variants, {f16,f32} accumulator
|
| 1524 |
+
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1525 |
+
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1526 |
+
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1527 |
+
|
| 1528 |
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
| 1529 |
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
| 1530 |
+
|
| 1531 |
+
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
| 1532 |
+
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
| 1533 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1534 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1535 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1536 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1537 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1538 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
| 1539 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
| 1540 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
| 1541 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
| 1542 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
| 1543 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1544 |
+
|
| 1545 |
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1546 |
+
CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1547 |
+
CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1548 |
+
|
| 1549 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1550 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1551 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1552 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1553 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1554 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1555 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1556 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1557 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1558 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1559 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1560 |
+
#undef CREATE_MM
|
| 1561 |
+
#undef CREATE_MM2
|
| 1562 |
+
} else
|
| 1563 |
+
#endif
|
| 1564 |
if (device->fp16) {
|
| 1565 |
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1566 |
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
|
|
| 1847 |
device->physical_device = physical_devices[dev_num];
|
| 1848 |
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
| 1849 |
|
| 1850 |
+
bool fp16_storage = false;
|
| 1851 |
+
bool fp16_compute = false;
|
| 1852 |
bool maintenance4_support = false;
|
| 1853 |
bool sm_builtins = false;
|
| 1854 |
+
bool pipeline_robustness = false;
|
| 1855 |
+
bool coopmat2_support = false;
|
| 1856 |
|
| 1857 |
// Check if maintenance4 is supported
|
| 1858 |
for (const auto& properties : ext_props) {
|
| 1859 |
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
| 1860 |
maintenance4_support = true;
|
| 1861 |
+
} else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
| 1862 |
+
fp16_storage = true;
|
| 1863 |
+
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
|
| 1864 |
+
fp16_compute = true;
|
| 1865 |
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
| 1866 |
sm_builtins = true;
|
| 1867 |
+
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
|
| 1868 |
+
pipeline_robustness = true;
|
| 1869 |
+
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
| 1870 |
+
!getenv("GGML_VULKAN_DISABLE_COOPMAT2")) {
|
| 1871 |
+
coopmat2_support = true;
|
| 1872 |
}
|
| 1873 |
}
|
| 1874 |
|
|
|
|
| 1891 |
last_struct = (VkBaseOutStructure *)&sm_props;
|
| 1892 |
}
|
| 1893 |
|
| 1894 |
+
#if defined(VK_NV_cooperative_matrix2)
|
| 1895 |
+
vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
|
| 1896 |
+
if (coopmat2_support) {
|
| 1897 |
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
|
| 1898 |
+
last_struct = (VkBaseOutStructure *)&coopmat2_props;
|
| 1899 |
+
}
|
| 1900 |
+
#endif
|
| 1901 |
+
|
| 1902 |
device->physical_device.getProperties2(&props2);
|
| 1903 |
device->properties = props2.properties;
|
| 1904 |
|
|
|
|
| 1921 |
device->shader_core_count = 0;
|
| 1922 |
}
|
| 1923 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1924 |
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
|
| 1925 |
const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
| 1926 |
|
|
|
|
| 1963 |
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
|
| 1964 |
vk11_features.pNext = &vk12_features;
|
| 1965 |
|
| 1966 |
+
last_struct = (VkBaseOutStructure *)&vk12_features;
|
| 1967 |
+
|
| 1968 |
VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
|
| 1969 |
pl_robustness_features.pNext = nullptr;
|
| 1970 |
pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
|
| 1971 |
pl_robustness_features.pipelineRobustness = VK_FALSE;
|
| 1972 |
|
| 1973 |
if (pipeline_robustness) {
|
| 1974 |
+
last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
|
| 1975 |
+
last_struct = (VkBaseOutStructure *)&pl_robustness_features;
|
| 1976 |
device_extensions.push_back("VK_EXT_pipeline_robustness");
|
| 1977 |
}
|
| 1978 |
|
| 1979 |
+
#if defined(VK_NV_cooperative_matrix2)
|
| 1980 |
+
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
|
| 1981 |
+
coopmat2_features.pNext = nullptr;
|
| 1982 |
+
coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
|
| 1983 |
+
if (coopmat2_support) {
|
| 1984 |
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
|
| 1985 |
+
last_struct = (VkBaseOutStructure *)&coopmat2_features;
|
| 1986 |
+
device_extensions.push_back("VK_NV_cooperative_matrix2");
|
| 1987 |
+
}
|
| 1988 |
+
#endif
|
| 1989 |
+
|
| 1990 |
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
| 1991 |
|
| 1992 |
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
| 1993 |
|
| 1994 |
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
| 1995 |
|
| 1996 |
+
if (coopmat2_support) {
|
| 1997 |
+
#if defined(VK_NV_cooperative_matrix2)
|
| 1998 |
+
if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
|
| 1999 |
+
coopmat2_features.cooperativeMatrixFlexibleDimensions &&
|
| 2000 |
+
coopmat2_features.cooperativeMatrixReductions &&
|
| 2001 |
+
coopmat2_features.cooperativeMatrixConversions &&
|
| 2002 |
+
coopmat2_features.cooperativeMatrixPerElementOperations &&
|
| 2003 |
+
coopmat2_features.cooperativeMatrixTensorAddressing &&
|
| 2004 |
+
coopmat2_features.cooperativeMatrixBlockLoads &&
|
| 2005 |
+
vk12_features.bufferDeviceAddress) {
|
| 2006 |
+
|
| 2007 |
+
std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
|
| 2008 |
+
uint32_t count = 0;
|
| 2009 |
+
|
| 2010 |
+
PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
|
| 2011 |
+
_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
|
| 2012 |
+
(PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
|
| 2013 |
+
vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
|
| 2014 |
+
|
| 2015 |
+
_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
|
| 2016 |
+
|
| 2017 |
+
VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
|
| 2018 |
+
empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
|
| 2019 |
+
flexible_dimensions.resize(count, empty_prop);
|
| 2020 |
+
|
| 2021 |
+
_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
|
| 2022 |
+
|
| 2023 |
+
bool found_fp16_128 = false,
|
| 2024 |
+
found_fp16_256 = false,
|
| 2025 |
+
found_fp32_128 = false,
|
| 2026 |
+
found_fp32_256 = false;
|
| 2027 |
+
// need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
|
| 2028 |
+
// with 32x16x16 and 256 with 32x32x16.
|
| 2029 |
+
for (auto &prop : flexible_dimensions) {
|
| 2030 |
+
if (prop.saturatingAccumulation == VK_FALSE &&
|
| 2031 |
+
prop.scope == VK_SCOPE_WORKGROUP_KHR &&
|
| 2032 |
+
prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
| 2033 |
+
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
| 2034 |
+
|
| 2035 |
+
if (prop.workgroupInvocations == 128 &&
|
| 2036 |
+
prop.MGranularity <= 32 &&
|
| 2037 |
+
prop.NGranularity <= 16 &&
|
| 2038 |
+
prop.KGranularity <= 16) {
|
| 2039 |
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
| 2040 |
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
| 2041 |
+
found_fp16_128 = true;
|
| 2042 |
+
}
|
| 2043 |
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
| 2044 |
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
| 2045 |
+
found_fp32_128 = true;
|
| 2046 |
+
}
|
| 2047 |
+
}
|
| 2048 |
+
if (prop.workgroupInvocations == 256 &&
|
| 2049 |
+
prop.MGranularity <= 32 &&
|
| 2050 |
+
prop.NGranularity <= 32 &&
|
| 2051 |
+
prop.KGranularity <= 16) {
|
| 2052 |
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
| 2053 |
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
| 2054 |
+
found_fp16_256 = true;
|
| 2055 |
+
}
|
| 2056 |
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
| 2057 |
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
| 2058 |
+
found_fp32_256 = true;
|
| 2059 |
+
}
|
| 2060 |
+
}
|
| 2061 |
+
}
|
| 2062 |
+
}
|
| 2063 |
+
if (found_fp16_128 && found_fp16_256 &&
|
| 2064 |
+
found_fp32_128 && found_fp32_256 &&
|
| 2065 |
+
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
|
| 2066 |
+
device->coopmat2 = true;
|
| 2067 |
+
}
|
| 2068 |
+
}
|
| 2069 |
+
#endif
|
| 2070 |
+
}
|
| 2071 |
+
|
| 2072 |
if (!vk11_features.storageBuffer16BitAccess) {
|
| 2073 |
std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
|
| 2074 |
throw std::runtime_error("Unsupported device");
|
|
|
|
| 2420 |
return ctx->device->pipeline_dequant[type];
|
| 2421 |
}
|
| 2422 |
|
| 2423 |
+
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
| 2424 |
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
| 2425 |
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
| 2426 |
return ctx->device->pipeline_matmul_f32;
|
|
|
|
| 2428 |
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
| 2429 |
return ctx->device->pipeline_matmul_f32_f16;
|
| 2430 |
}
|
| 2431 |
+
if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) {
|
| 2432 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
| 2433 |
+
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
| 2434 |
+
}
|
| 2435 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
| 2436 |
+
return ctx->device->pipeline_matmul_f16.f16acc;
|
| 2437 |
+
}
|
| 2438 |
+
} else {
|
| 2439 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
| 2440 |
+
return ctx->device->pipeline_matmul_f16_f32.f32acc;
|
| 2441 |
+
}
|
| 2442 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
| 2443 |
+
return ctx->device->pipeline_matmul_f16.f32acc;
|
| 2444 |
+
}
|
| 2445 |
}
|
| 2446 |
|
| 2447 |
+
if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
|
| 2448 |
return nullptr;
|
| 2449 |
}
|
| 2450 |
|
|
|
|
| 2465 |
return nullptr;
|
| 2466 |
}
|
| 2467 |
|
| 2468 |
+
if (ctx->device->coopmat2) {
|
| 2469 |
+
assert(src1_type == GGML_TYPE_F16);
|
| 2470 |
+
return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
|
| 2471 |
+
}
|
| 2472 |
return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
| 2473 |
}
|
| 2474 |
|
|
|
|
| 3153 |
break;
|
| 3154 |
}
|
| 3155 |
|
| 3156 |
+
if (ctx->device->coopmat2) {
|
| 3157 |
+
if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
|
| 3158 |
+
return aligned ? mmp->a_l : mmp->l;
|
| 3159 |
+
}
|
| 3160 |
+
if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
|
| 3161 |
+
return aligned ? mmp->a_m : mmp->m;
|
| 3162 |
+
}
|
| 3163 |
+
return aligned ? mmp->a_s : mmp->s;
|
| 3164 |
+
}
|
| 3165 |
+
|
| 3166 |
if (m <= 32 || n <= 32) {
|
| 3167 |
return aligned ? mmp->a_s : mmp->s;
|
| 3168 |
}
|
|
|
|
| 3327 |
}
|
| 3328 |
|
| 3329 |
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
|
| 3330 |
+
// Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
|
| 3331 |
+
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
| 3332 |
+
!ggml_vk_dim01_contiguous(src1);
|
| 3333 |
|
| 3334 |
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
| 3335 |
|
| 3336 |
+
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
| 3337 |
|
| 3338 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 3339 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
| 3340 |
|
| 3341 |
if (qx_needs_dequant) {
|
| 3342 |
// Fall back to dequant + f16 mulmat
|
| 3343 |
+
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
| 3344 |
}
|
| 3345 |
|
| 3346 |
// Not implemented
|
|
|
|
| 4251 |
}
|
| 4252 |
}
|
| 4253 |
|
| 4254 |
+
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
| 4255 |
+
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
| 4256 |
+
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
| 4257 |
+
std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
|
| 4258 |
+
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
| 4259 |
+
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
| 4260 |
+
|
| 4261 |
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
| 4262 |
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
| 4263 |
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
| 4264 |
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
| 4265 |
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
| 4266 |
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
| 4267 |
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
| 4268 |
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
| 4269 |
+
|
| 4270 |
+
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
| 4271 |
+
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
|
| 4272 |
+
|
| 4273 |
+
const uint32_t D = neq0;
|
| 4274 |
+
const uint32_t N = neq1;
|
| 4275 |
+
const uint32_t KV = nek1;
|
| 4276 |
+
|
| 4277 |
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(nev1 == nek1);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    assert(dst->type == GGML_TYPE_F32);
+    assert(q->type == GGML_TYPE_F32);
+    assert(k->type == v->type);
+
+    vk_pipeline *pipelines;
+    // XXX TODO other backends may be changing accumulator precision to default to f32 soon
+    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
+    bool small_rows = N <= flash_attention_num_small_rows;
+    switch (D) {
+    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+    default:
+        assert(!"unsupported D value");
+        return;
+    }
+    assert(pipelines);
+
+    bool aligned = (KV % pipelines[1]->align) == 0;
+    vk_pipeline pipeline = pipelines[aligned];
+    assert(pipeline);
+
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head_kv   = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_Q, d_K, d_V, d_D, d_M;
+    uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
+
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
+        ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
+        Q_uma = d_Q != nullptr;
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        D_uma = d_D != nullptr;
+        if (mask) {
+            ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
+            M_uma = d_M != nullptr;
+        }
+    }
+
+
+    ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+
+    if (!Q_uma) {
+        d_Q = q_buf_ctx->dev_buffer;
+        q_buf_offset = vk_tensor_offset(q) + q->view_offs;
+    }
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_buf_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_buf_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!D_uma) {
+        d_D = d_buf_ctx->dev_buffer;
+        d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    if (!M_uma) {
+        d_M = d_Q;
+        m_buf_offset = q_buf_offset;
+        if (mask) {
+            ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
+            d_M = m_buf_ctx->dev_buffer;
+            m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
+        }
+    }
+
+    const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+            vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+        },
+        sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+}
+
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
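A note on the m0/m1 push constants computed in ggml_vk_flash_attn above: together with n_head_log2 they reproduce the standard ALiBi slope per attention head inside the shader. A minimal scalar sketch of that relationship (the helper name is ours, not part of the patch):

    #include <cmath>
    #include <cstdint>

    // Illustrative only: heads below n_head_log2 use base m0 with exponent h+1,
    // the remaining heads use base m1 with exponent 2*(h - n_head_log2) + 1,
    // which is how the flash attention shader derives its per-head `slope`.
    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        return h < n_head_log2 ? std::pow(m0, float(h + 1))
                               : std::pow(m1, float(2 * (h - n_head_log2) + 1));
    }
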
     ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);

     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    ggml_vk_ctx_begin(ctx->device, subctx);
     for (size_t i = 0; i < num_it; i++) {
         ggml_vk_matmul(
             ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
             split_k, batch, batch, batch, 1, 1
         );
     }
+    ggml_vk_ctx_end(subctx);

     auto begin = std::chrono::high_resolution_clock::now();
     ggml_vk_submit(subctx, ctx->fence);

     ggml_vk_buffer_write(y_buf, 0, y, y_sz);

     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    ggml_vk_ctx_begin(ctx->device, subctx);
     for (size_t i = 0; i < num_it; i++) {
         ggml_vk_matmul(
             ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
             split_k, batch, batch, batch, 1, 1
         );
     }
+    ggml_vk_ctx_end(subctx);

     auto begin = std::chrono::high_resolution_clock::now();

         4096, 512, 11008,
         32000, 512, 4096,
     };
+    const size_t num_it = 100;
+
     for (size_t i = 0; i < vals.size(); i += 3) {
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);

     const ggml_tensor * src0 = node->src[0];
     const ggml_tensor * src1 = node->src[1];
     const ggml_tensor * src2 = node->src[2];
+    const ggml_tensor * src3 = node->src[3];

     switch (node->op) {
     // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor

     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_FLASH_ATTN_EXT:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;

     case GGML_OP_MUL_MAT_ID:
         ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);

+        break;
+
+    case GGML_OP_FLASH_ATTN_EXT:
+        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+
         break;
     default:
         return false;

         break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
+    case GGML_OP_FLASH_ATTN_EXT:
         buf = tensor->buffer;

         break;


             return true;
         } break;
+    case GGML_OP_FLASH_ATTN_EXT:
+        {
+            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+            if (!ggml_vk_get_device(ctx->device)->coopmat2) {
+                return false;
+            }
+            switch (op->src[0]->ne[0]) {
+            case 64:
+            case 80:
+            case 96:
+            case 112:
+            case 128:
+            case 256:
+                break;
+            default:
+                return false;
+            }
+            if (op->src[0]->type != GGML_TYPE_F32) {
+                return false;
+            }
+            if (op->type != GGML_TYPE_F32) {
+                return false;
+            }
+            if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
+                return false;
+            }
+            // It's straightforward to support different K/V dequant, but would
+            // significantly increase the number of pipelines
+            if (op->src[1]->type != op->src[2]->type) {
+                return false;
+            }
+            switch (op->src[1]->type) {
+            case GGML_TYPE_F16:
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+            // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
+            //case GGML_TYPE_Q2_K:
+            //case GGML_TYPE_Q3_K:
+            //case GGML_TYPE_Q4_K:
+            //case GGML_TYPE_Q5_K:
+            //case GGML_TYPE_Q6_K:
+            case GGML_TYPE_IQ4_NL:
+                break;
+            default:
+                return false;
+            }
+            return true;
+        }
     case GGML_OP_GET_ROWS:
         {
             switch (op->src[0]->type) {

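The support matrix above is easiest to read as a predicate over head size and K/V type. A condensed restatement (illustrative helper, not part of the patch; it relies only on the public ggml_type enum):

    #include <cstdint>
    #include "ggml.h"

    // Mirrors the checks in the GGML_OP_FLASH_ATTN_EXT case above: only these
    // head sizes and K/V storage types are routed to the coopmat2 pipelines.
    static bool vk_coopmat2_fa_supported(int64_t head_size, ggml_type kv_type) {
        const bool head_ok = head_size == 64  || head_size == 80  || head_size == 96 ||
                             head_size == 112 || head_size == 128 || head_size == 256;
        const bool type_ok = kv_type == GGML_TYPE_F16  || kv_type == GGML_TYPE_Q4_0 ||
                             kv_type == GGML_TYPE_Q4_1 || kv_type == GGML_TYPE_Q5_0 ||
                             kv_type == GGML_TYPE_Q5_1 || kv_type == GGML_TYPE_Q8_0 ||
                             kv_type == GGML_TYPE_IQ4_NL;
        return head_ok && type_ok;
    }
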
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
     ggml_tensor * src2 = tensor->src[2];
+    ggml_tensor * src3 = tensor->src[3];

     struct ggml_init_params iparams = {
         /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,

     struct ggml_tensor * src0_clone = nullptr;
     struct ggml_tensor * src1_clone = nullptr;
     struct ggml_tensor * src2_clone = nullptr;
+    struct ggml_tensor * src3_clone = nullptr;
     struct ggml_tensor * tensor_clone = nullptr;

     size_t src0_size;
     size_t src1_size;
     size_t src2_size;
+    size_t src3_size;

     void * src0_buffer = nullptr;
     void * src1_buffer = nullptr;
     void * src2_buffer = nullptr;
+    void * src3_buffer = nullptr;

     if (src0 != nullptr) {
         src0_clone = ggml_dup_tensor(ggml_ctx, src0);

             ggml_vk_print_tensor(src2, "src2");
         }
     }
+    if (src3 != nullptr) {
+        src3_clone = ggml_dup_tensor(ggml_ctx, src3);
+
+        src3_size = ggml_nbytes(src3);
+
+        src3_buffer = malloc(src3_size);
+        src3_clone->data = src3_buffer;
+        if (ggml_backend_buffer_is_host(src3->buffer)) {
+            memcpy(src3_clone->data, src3->data, src3_size);
+            memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
+        } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
+            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
+            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
+            uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
+            if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
+                for (int i3 = 0; i3 < src3->ne[3]; i3++) {
+                    for (int i2 = 0; i2 < src3->ne[2]; i2++) {
+                        const int idx = i3*src3->ne[2] + i2;
+                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
+                    }
+                }
+
+                src3_clone->nb[0] = src3->nb[0];
+                src3_clone->nb[1] = src3->nb[1];
+                for (int i = 2; i < GGML_MAX_DIMS; i++) {
+                    src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
+                }
+            } else {
+                if (offset + src3_size >= buffer_gpu->size) {
+                    src3_size = buffer_gpu->size - offset;
+                }
+                ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
+                memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
+            }
+        } else {
+            GGML_ABORT("fatal error");
+        }
+
+        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+            ggml_vk_print_tensor(src3, "src3");
+        }
+    }

+    if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
+        const float *params = (const float *)tensor->op_params;
+        tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
+    } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
         tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
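Since the reference path above forwards params[2] (logit_softcap) unchanged while the device path in ggml_vk_flash_attn folds it into the scale (scale /= logit_softcap), it may help to see why the two agree. A scalar sketch, assuming the usual tanh soft-capping formulation:

    #include <cmath>

    // With scale' = scale / softcap, softcap * tanh(scale' * s) is identical to
    // softcap * tanh((scale * s) / softcap), i.e. the folded and unfolded forms
    // of the capped logit match; softcap == 0 means no capping is applied.
    static float softcapped_logit(float s, float scale, float softcap) {
        if (softcap == 0.0f) {
            return scale * s;
        }
        return softcap * std::tanh((scale * s) / softcap);
    }
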
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
CHANGED
@@ -1,7 +1,9 @@
 find_package (Threads REQUIRED)
+find_package(Vulkan COMPONENTS glslc REQUIRED)

 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
+target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
ADDED
|
@@ -0,0 +1,305 @@
|
| 1 |
+
|
| 2 |
+
#include "types.comp"
|
| 3 |
+
|
| 4 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
| 5 |
+
block_q4_0_packed16 block;
|
| 6 |
+
};
|
| 7 |
+
|
| 8 |
+
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 9 |
+
{
|
| 10 |
+
const float16_t d = bl.block.d;
|
| 11 |
+
const uint idx = coordInBlock[1];
|
| 12 |
+
const uint shift = (idx & 0x10) >> 2;
|
| 13 |
+
uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
|
| 14 |
+
qs >>= shift;
|
| 15 |
+
qs &= 0xF;
|
| 16 |
+
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
|
| 17 |
+
return ret;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
|
| 21 |
+
block_q4_1 block;
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 25 |
+
{
|
| 26 |
+
const float16_t d = bl.block.d;
|
| 27 |
+
const float16_t m = bl.block.m;
|
| 28 |
+
const uint idx = coordInBlock[1];
|
| 29 |
+
const uint iqs = idx & 0xF;
|
| 30 |
+
const uint shift = (idx & 0x10) >> 2;
|
| 31 |
+
uint32_t qs = bl.block.qs[iqs];
|
| 32 |
+
qs >>= shift;
|
| 33 |
+
qs &= 0xF;
|
| 34 |
+
float16_t ret = float16_t(qs) * d + m;
|
| 35 |
+
return ret;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
|
| 39 |
+
block_q5_0 block;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 43 |
+
{
|
| 44 |
+
const float16_t d = bl.block.d;
|
| 45 |
+
const uint idx = coordInBlock[1];
|
| 46 |
+
const uint iqs = idx & 0xF;
|
| 47 |
+
|
| 48 |
+
const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
|
| 49 |
+
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
|
| 50 |
+
|
| 51 |
+
const uint shift = (idx & 0x10) >> 2;
|
| 52 |
+
uint32_t qs = bl.block.qs[iqs];
|
| 53 |
+
qs >>= shift;
|
| 54 |
+
qs &= 0xF;
|
| 55 |
+
|
| 56 |
+
float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
|
| 57 |
+
return ret;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
|
| 61 |
+
block_q5_1 block;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 65 |
+
{
|
| 66 |
+
const float16_t d = bl.block.d;
|
| 67 |
+
const float16_t m = bl.block.m;
|
| 68 |
+
const uint idx = coordInBlock[1];
|
| 69 |
+
const uint iqs = idx & 0xF;
|
| 70 |
+
|
| 71 |
+
const uint uint_qh = bl.block.qh;
|
| 72 |
+
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
|
| 73 |
+
|
| 74 |
+
const uint shift = (idx & 0x10) >> 2;
|
| 75 |
+
uint32_t qs = bl.block.qs[iqs];
|
| 76 |
+
qs >>= shift;
|
| 77 |
+
qs &= 0xF;
|
| 78 |
+
|
| 79 |
+
float16_t ret = float16_t(qs | qh) * d + m;
|
| 80 |
+
return ret;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
|
| 84 |
+
block_q8_0_packed16 block;
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 88 |
+
{
|
| 89 |
+
const float16_t d = bl.block.d;
|
| 90 |
+
const uint idx = coordInBlock[1];
|
| 91 |
+
const uint iqs = idx;
|
| 92 |
+
|
| 93 |
+
// Load 16b and select the byte for this element
|
| 94 |
+
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
|
| 95 |
+
float16_t ret = float16_t(qs) * d;
|
| 96 |
+
return ret;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
|
| 100 |
+
block_q2_K block;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 104 |
+
{
|
| 105 |
+
const f16vec2 d = bl.block.d;
|
| 106 |
+
const uint idx = coordInBlock[1];
|
| 107 |
+
const uint iqs = idx;
|
| 108 |
+
|
| 109 |
+
const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
|
| 110 |
+
const uint scalesi = iqs / 16; // 0..15
|
| 111 |
+
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
|
| 112 |
+
|
| 113 |
+
uint32_t qs = bl.block.qs[qsi];
|
| 114 |
+
const uint scales = bl.block.scales[scalesi];
|
| 115 |
+
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
|
| 116 |
+
return ret;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
|
| 120 |
+
block_q3_K block;
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 124 |
+
{
|
| 125 |
+
const uint idx = coordInBlock[1];
|
| 126 |
+
const uint iqs = idx;
|
| 127 |
+
|
| 128 |
+
const uint n = iqs / 128; // 0,1
|
| 129 |
+
const uint qsi = n * 32 + (iqs % 32); // 0..63
|
| 130 |
+
const uint hmi = (iqs % 32); // 0..31
|
| 131 |
+
const uint j = (iqs % 128) / 8; // 0..15
|
| 132 |
+
const uint is = iqs / 16; // 0..15
|
| 133 |
+
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
|
| 134 |
+
const uint qsshift = halfsplit * 2; // 0,2,4,6
|
| 135 |
+
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
|
| 136 |
+
|
| 137 |
+
uint32_t scaleidx0 = (is < 8) ? is : (is-8);
|
| 138 |
+
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
|
| 139 |
+
uint32_t scaleidx1 = is + 8 - (is/4)*4;
|
| 140 |
+
uint32_t scaleidx1shift = (is/4)*2;
|
| 141 |
+
|
| 142 |
+
const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
|
| 143 |
+
|
| 144 |
+
const float16_t dl = bl.block.d * float16_t(us - 32);
|
| 145 |
+
|
| 146 |
+
float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
|
| 147 |
+
|
| 148 |
+
return ret;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
|
| 152 |
+
block_q4_K block;
|
| 153 |
+
};
|
| 154 |
+
|
| 155 |
+
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 156 |
+
{
|
| 157 |
+
const uint idx = coordInBlock[1];
|
| 158 |
+
const uint iqs = idx;
|
| 159 |
+
|
| 160 |
+
const uint n = iqs / 64; // 0,1,2,3
|
| 161 |
+
const uint b = (iqs % 64) / 32; // 0,1
|
| 162 |
+
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 163 |
+
const uint qsi = n * 32 + (iqs % 32); // 0..127
|
| 164 |
+
|
| 165 |
+
const f16vec2 loadd = bl.block.d;
|
| 166 |
+
|
| 167 |
+
uint32_t sc;
|
| 168 |
+
uint32_t mbyte;
|
| 169 |
+
|
| 170 |
+
uint32_t scidx0 = (is < 4) ? is : (is + 4);
|
| 171 |
+
uint32_t scidx1 = (is < 4) ? is : (is - 4);
|
| 172 |
+
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 173 |
+
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
|
| 174 |
+
uint32_t mbidx0 = is + 4;
|
| 175 |
+
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
|
| 176 |
+
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
| 177 |
+
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
|
| 178 |
+
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 179 |
+
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
|
| 180 |
+
|
| 181 |
+
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
|
| 182 |
+
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
| 183 |
+
|
| 184 |
+
const float16_t d = loadd.x * float16_t(sc);
|
| 185 |
+
const float16_t m = loadd.y * float16_t(mbyte);
|
| 186 |
+
|
| 187 |
+
uint32_t dmask = 0xF << (b * 4);
|
| 188 |
+
|
| 189 |
+
float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
|
| 190 |
+
|
| 191 |
+
return ret;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
|
| 195 |
+
block_q5_K block;
|
| 196 |
+
};
|
| 197 |
+
|
| 198 |
+
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 199 |
+
{
|
| 200 |
+
const uint idx = coordInBlock[1];
|
| 201 |
+
const uint iqs = idx;
|
| 202 |
+
|
| 203 |
+
const uint n = iqs / 64; // 0,1,2,3
|
| 204 |
+
const uint b = (iqs % 64) / 32; // 0,1
|
| 205 |
+
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 206 |
+
const uint qsi = n * 32 + (iqs % 32); // 0..127
|
| 207 |
+
const uint qhi = (iqs % 32); // 0..31
|
| 208 |
+
|
| 209 |
+
const uint8_t hm = uint8_t(1 << (iqs / 32));
|
| 210 |
+
|
| 211 |
+
const f16vec2 loadd = bl.block.d;
|
| 212 |
+
|
| 213 |
+
uint32_t sc;
|
| 214 |
+
uint32_t mbyte;
|
| 215 |
+
|
| 216 |
+
uint32_t scidx0 = (is < 4) ? is : (is + 4);
|
| 217 |
+
uint32_t scidx1 = (is < 4) ? is : (is - 4);
|
| 218 |
+
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 219 |
+
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
|
| 220 |
+
uint32_t mbidx0 = is + 4;
|
| 221 |
+
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
|
| 222 |
+
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
| 223 |
+
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
|
| 224 |
+
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 225 |
+
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
|
| 226 |
+
|
| 227 |
+
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
|
| 228 |
+
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
| 229 |
+
|
| 230 |
+
const float16_t d = loadd.x * float16_t(sc);
|
| 231 |
+
const float16_t m = loadd.y * float16_t(mbyte);
|
| 232 |
+
|
| 233 |
+
uint32_t dmask = 0xF << (b * 4);
|
| 234 |
+
|
| 235 |
+
float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
|
| 236 |
+
|
| 237 |
+
return ret;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
|
| 241 |
+
block_q6_K block;
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 245 |
+
{
|
| 246 |
+
const uint idx = coordInBlock[1];
|
| 247 |
+
const uint iqs = idx;
|
| 248 |
+
|
| 249 |
+
const uint n = iqs / 128; // 0,1
|
| 250 |
+
const uint b = (iqs % 128) / 64; // 0,1
|
| 251 |
+
const uint is_b = (iqs % 32) / 16; // 0,1
|
| 252 |
+
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
|
| 253 |
+
const uint is = 8 * n + qhshift + is_b; // 0..15
|
| 254 |
+
const uint qsi = n * 64 + (iqs % 64); // 0..127
|
| 255 |
+
const uint qhi = n * 32 + (iqs % 32); // 0..63
|
| 256 |
+
|
| 257 |
+
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
|
| 258 |
+
|
| 259 |
+
float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
|
| 260 |
+
|
| 261 |
+
return ret;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
#if defined(DATA_A_IQ4_NL)
|
| 265 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
|
| 266 |
+
block_iq4_nl block;
|
| 267 |
+
};
|
| 268 |
+
|
| 269 |
+
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 270 |
+
{
|
| 271 |
+
const float16_t d = bl.block.d;
|
| 272 |
+
const uint idx = coordInBlock[1];
|
| 273 |
+
const uint iqs = idx & 0xF;
|
| 274 |
+
const uint shift = (idx & 0x10) >> 2;
|
| 275 |
+
uint32_t qs = bl.block.qs[iqs];
|
| 276 |
+
qs >>= shift;
|
| 277 |
+
qs &= 0xF;
|
| 278 |
+
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
|
| 279 |
+
return ret;
|
| 280 |
+
}
|
| 281 |
+
#endif
|
| 282 |
+
|
| 283 |
+
#if defined(DATA_A_Q4_0)
|
| 284 |
+
#define dequantFuncA dequantFuncQ4_0
|
| 285 |
+
#elif defined(DATA_A_Q4_1)
|
| 286 |
+
#define dequantFuncA dequantFuncQ4_1
|
| 287 |
+
#elif defined(DATA_A_Q5_0)
|
| 288 |
+
#define dequantFuncA dequantFuncQ5_0
|
| 289 |
+
#elif defined(DATA_A_Q5_1)
|
| 290 |
+
#define dequantFuncA dequantFuncQ5_1
|
| 291 |
+
#elif defined(DATA_A_Q8_0)
|
| 292 |
+
#define dequantFuncA dequantFuncQ8_0
|
| 293 |
+
#elif defined(DATA_A_Q2_K)
|
| 294 |
+
#define dequantFuncA dequantFuncQ2_K
|
| 295 |
+
#elif defined(DATA_A_Q3_K)
|
| 296 |
+
#define dequantFuncA dequantFuncQ3_K
|
| 297 |
+
#elif defined(DATA_A_Q4_K)
|
| 298 |
+
#define dequantFuncA dequantFuncQ4_K
|
| 299 |
+
#elif defined(DATA_A_Q5_K)
|
| 300 |
+
#define dequantFuncA dequantFuncQ5_K
|
| 301 |
+
#elif defined(DATA_A_Q6_K)
|
| 302 |
+
#define dequantFuncA dequantFuncQ6_K
|
| 303 |
+
#elif defined(DATA_A_IQ4_NL)
|
| 304 |
+
#define dequantFuncA dequantFuncIQ4_NL
|
| 305 |
+
#endif
|
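For readers less familiar with the block formats these decode callbacks handle, a scalar sketch of Q4_0 (the first decode function in the file above) may help; the helper name and signature are ours, and d is assumed to be the block's fp16 scale already converted to float:

    #include <cstdint>

    // Illustrative scalar equivalent of dequantFuncQ4_0: a Q4_0 block is one fp16
    // scale plus 32 packed 4-bit values. Elements 0..15 sit in the low nibbles of
    // qs[0..15], elements 16..31 in the high nibbles; element i decodes to (q - 8) * d.
    static float dequant_q4_0_elem(const uint8_t qs[16], float d, int i) {
        const uint8_t byte = qs[i & 15];
        const uint8_t q    = (i < 16) ? (byte & 0x0F) : (byte >> 4);
        return (int(q) - 8) * d;
    }
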
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
ADDED
|
@@ -0,0 +1,289 @@
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 4 |
+
#extension GL_EXT_shader_16bit_storage : require
|
| 5 |
+
|
| 6 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
| 7 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
| 8 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
| 9 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
| 10 |
+
|
| 11 |
+
#extension GL_KHR_memory_scope_semantics : enable
|
| 12 |
+
#extension GL_KHR_cooperative_matrix : enable
|
| 13 |
+
#extension GL_NV_cooperative_matrix2 : enable
|
| 14 |
+
#extension GL_EXT_buffer_reference : enable
|
| 15 |
+
#extension GL_KHR_shader_subgroup_ballot : enable
|
| 16 |
+
#extension GL_KHR_shader_subgroup_vote : enable
|
| 17 |
+
#extension GL_EXT_null_initializer : enable
|
| 18 |
+
|
| 19 |
+
#include "types.comp"
|
| 20 |
+
#include "dequant_funcs_cm2.comp"
|
| 21 |
+
|
| 22 |
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
| 23 |
+
|
| 24 |
+
layout (constant_id = 1) const uint32_t Br = 32;
|
| 25 |
+
layout (constant_id = 2) const uint32_t Bc = 32;
|
| 26 |
+
layout (constant_id = 3) const uint32_t D = 32;
|
| 27 |
+
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
|
| 28 |
+
|
| 29 |
+
layout (push_constant) uniform parameter {
|
| 30 |
+
uint32_t N;
|
| 31 |
+
uint32_t KV;
|
| 32 |
+
|
| 33 |
+
uint32_t ne1;
|
| 34 |
+
uint32_t ne2;
|
| 35 |
+
uint32_t ne3;
|
| 36 |
+
|
| 37 |
+
uint32_t neq2;
|
| 38 |
+
uint32_t neq3;
|
| 39 |
+
uint32_t nek2;
|
| 40 |
+
uint32_t nek3;
|
| 41 |
+
uint32_t nev2;
|
| 42 |
+
uint32_t nev3;
|
| 43 |
+
uint32_t nem1;
|
| 44 |
+
|
| 45 |
+
uint32_t nb02;
|
| 46 |
+
uint32_t nb03;
|
| 47 |
+
uint32_t nb12;
|
| 48 |
+
uint32_t nb13;
|
| 49 |
+
uint32_t nb22;
|
| 50 |
+
uint32_t nb23;
|
| 51 |
+
uint32_t nb31;
|
| 52 |
+
|
| 53 |
+
float scale;
|
| 54 |
+
float max_bias;
|
| 55 |
+
float logit_softcap;
|
| 56 |
+
|
| 57 |
+
uint32_t mask;
|
| 58 |
+
uint32_t n_head_log2;
|
| 59 |
+
float m0;
|
| 60 |
+
float m1;
|
| 61 |
+
} p;
|
| 62 |
+
|
| 63 |
+
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
| 64 |
+
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
| 65 |
+
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
| 66 |
+
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
|
| 67 |
+
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
| 68 |
+
|
| 69 |
+
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
| 70 |
+
|
| 71 |
+
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
| 72 |
+
return max(x, y);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
| 76 |
+
return x;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
// Replace matrix elements >= numRows or numCols with 'replace'
|
| 80 |
+
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
|
| 81 |
+
if (row >= numRows || col >= numCols) {
|
| 82 |
+
return replace;
|
| 83 |
+
}
|
| 84 |
+
return elem;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
|
| 88 |
+
{
|
| 89 |
+
return exp(elem);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
|
| 93 |
+
{
|
| 94 |
+
return max(elem0, elem1);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
#if defined(BLOCK_SIZE)
|
| 98 |
+
#define DECODEFUNC , DEQUANTFUNC
|
| 99 |
+
#else
|
| 100 |
+
#define DECODEFUNC
|
| 101 |
+
#endif
|
| 102 |
+
|
| 103 |
+
void main() {
|
| 104 |
+
#if defined(DATA_A_IQ4_NL)
|
| 105 |
+
init_iq4nl_shmem();
|
| 106 |
+
#endif
|
| 107 |
+
|
| 108 |
+
const uint32_t N = p.N;
|
| 109 |
+
const uint32_t KV = p.KV;
|
| 110 |
+
|
| 111 |
+
const uint32_t Tr = CEIL_DIV(N, Br);
|
| 112 |
+
const uint32_t Tc = CEIL_DIV(KV, Bc);
|
| 113 |
+
|
| 114 |
+
const uint32_t i = gl_WorkGroupID.x;
|
| 115 |
+
|
| 116 |
+
const uint32_t iq2 = gl_WorkGroupID.y;
|
| 117 |
+
const uint32_t iq3 = gl_WorkGroupID.z;
|
| 118 |
+
|
| 119 |
+
// broadcast factors
|
| 120 |
+
const uint32_t rk2 = p.neq2/p.nek2;
|
| 121 |
+
const uint32_t rk3 = p.neq3/p.nek3;
|
| 122 |
+
|
| 123 |
+
const uint32_t rv2 = p.neq2/p.nev2;
|
| 124 |
+
const uint32_t rv3 = p.neq3/p.nev3;
|
| 125 |
+
|
| 126 |
+
// k indices
|
| 127 |
+
const uint32_t ik3 = iq3 / rk3;
|
| 128 |
+
const uint32_t ik2 = iq2 / rk2;
|
| 129 |
+
|
| 130 |
+
// v indices
|
| 131 |
+
const uint32_t iv3 = iq3 / rv3;
|
| 132 |
+
const uint32_t iv2 = iq2 / rv2;
|
| 133 |
+
|
| 134 |
+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 135 |
+
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
|
| 136 |
+
tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
|
| 137 |
+
|
| 138 |
+
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
| 139 |
+
|
| 140 |
+
#if defined(BLOCK_SIZE)
|
| 141 |
+
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
|
| 142 |
+
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
|
| 143 |
+
#endif
|
| 144 |
+
|
| 145 |
+
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
|
| 146 |
+
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
| 147 |
+
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
| 148 |
+
|
| 149 |
+
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
|
| 150 |
+
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
|
| 151 |
+
|
| 152 |
+
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
| 153 |
+
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
|
| 154 |
+
|
| 155 |
+
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
|
| 156 |
+
Qf16 *= float16_t(p.scale);
|
| 157 |
+
|
| 158 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
| 159 |
+
|
| 160 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
| 161 |
+
|
| 162 |
+
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 163 |
+
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
|
| 164 |
+
|
| 165 |
+
ACC_TYPE slope = ACC_TYPE(1.0);
|
| 166 |
+
|
| 167 |
+
// ALiBi
|
| 168 |
+
if (p.max_bias > 0.0f) {
|
| 169 |
+
const uint32_t h = iq2;
|
| 170 |
+
|
| 171 |
+
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
| 172 |
+
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
| 173 |
+
|
| 174 |
+
slope = pow(base, ACC_TYPE(exph));
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
[[dont_unroll]]
|
| 178 |
+
for (uint32_t j = 0; j < Tc; ++j) {
|
| 179 |
+
|
| 180 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 181 |
+
|
| 182 |
+
coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
|
| 183 |
+
|
| 184 |
+
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
| 185 |
+
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
|
| 186 |
+
S = coopMatMulAdd(Qf16, K_T, S);
|
| 187 |
+
|
| 188 |
+
if (p.logit_softcap != 0.0f) {
|
| 189 |
+
[[unroll]]
|
| 190 |
+
for (int k = 0; k < S.length(); ++k) {
|
| 191 |
+
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
if (p.mask != 0) {
|
| 196 |
+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 197 |
+
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
| 198 |
+
|
| 199 |
+
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
| 200 |
+
|
| 201 |
+
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
| 202 |
+
|
| 203 |
+
S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
// Clear padding elements to -inf, so they don't contribute to rowmax
|
| 207 |
+
if (Clamp != 0 &&
|
| 208 |
+
((j + 1) * Bc > KV ||
|
| 209 |
+
(i + 1) * Br > N)) {
|
| 210 |
+
|
| 211 |
+
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
|
| 212 |
+
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
|
| 213 |
+
|
| 214 |
+
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
|
| 218 |
+
|
| 219 |
+
coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
|
| 220 |
+
|
| 221 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
|
| 222 |
+
|
| 223 |
+
// M = max(rowmax, Mold)
|
| 224 |
+
// P = e^(S - M)
|
| 225 |
+
// eM = e^(Mold - M)
|
| 226 |
+
coopMatPerElementNV(M, rowmax, Max, Mold);
|
| 227 |
+
coopMatPerElementNV(P, S - M, Exp);
|
| 228 |
+
coopMatPerElementNV(eM, Mold - M, Exp);
|
| 229 |
+
|
| 230 |
+
// Clear padding elements to 0, so they don't contribute to rowsum
|
| 231 |
+
if (Clamp != 0 &&
|
| 232 |
+
((j + 1) * Bc > KV ||
|
| 233 |
+
(i + 1) * Br > N)) {
|
| 234 |
+
|
| 235 |
+
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
|
| 236 |
+
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
|
| 237 |
+
|
| 238 |
+
coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
|
| 242 |
+
|
| 243 |
+
// compute rowsum by multiplying by matrix of all ones.
|
| 244 |
+
coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
|
| 245 |
+
|
| 246 |
+
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
| 247 |
+
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
| 248 |
+
|
| 249 |
+
coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
|
| 250 |
+
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
| 251 |
+
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
|
| 252 |
+
|
| 253 |
+
L = eM*L + rowsum;
|
| 254 |
+
|
| 255 |
+
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
| 256 |
+
// multiply rather than matrix multiply it has the diagonal element smeared
|
| 257 |
+
// across the row
|
| 258 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
|
| 259 |
+
|
| 260 |
+
// resize eM by using smear/reduce
|
| 261 |
+
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
| 262 |
+
|
| 263 |
+
O = eMdiag * O;
|
| 264 |
+
|
| 265 |
+
O = coopMatMulAdd(P_A, V, O);
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
|
| 269 |
+
|
| 270 |
+
// resize L by using smear/reduce
|
| 271 |
+
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
| 272 |
+
|
| 273 |
+
[[unroll]]
|
| 274 |
+
for (int k = 0; k < Ldiag.length(); ++k) {
|
| 275 |
+
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
O = Ldiag*O;
|
| 279 |
+
|
| 280 |
+
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
|
| 281 |
+
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
|
| 282 |
+
|
| 283 |
+
// permute dimensions
|
| 284 |
+
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
| 285 |
+
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
| 286 |
+
|
| 287 |
+
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
|
| 288 |
+
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
|
| 289 |
+
}
|
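The M/L/eM bookkeeping in the loop above is the standard streaming ("online") softmax used by flash attention: each KV tile updates a running row maximum, rescales the partial output, and accumulates a running row sum, so only one pass over KV is needed. A scalar per-query sketch of the same update (names are ours, not part of the shader):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // One tile step: s holds this tile's scores for a single query row, v the
    // tile's V rows. O is the unnormalized output, M the running max, L the
    // running sum (M starts at -inf, L and O at zero).
    static void online_softmax_step(const std::vector<float>& s,
                                    const std::vector<std::vector<float>>& v,
                                    std::vector<float>& O, float& M, float& L) {
        float m_new = M;
        for (float x : s) m_new = std::max(m_new, x);   // rowmax merged with old max
        const float eM = std::exp(M - m_new);           // rescale factor for earlier tiles
        for (float& o : O) o *= eM;                     // O = eMdiag * O
        float rowsum = 0.0f;
        for (size_t j = 0; j < s.size(); ++j) {
            const float p = std::exp(s[j] - m_new);     // P = exp(S - M)
            rowsum += p;
            for (size_t d = 0; d < O.size(); ++d) {
                O[d] += p * v[j][d];                    // O = P*V + O
            }
        }
        L = eM * L + rowsum;                            // L = eM*L + rowsum
        M = m_new;
        // After the last tile the caller divides O by L, as the shader does with Ldiag.
    }
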
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ADDED
|
@@ -0,0 +1,328 @@
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 4 |
+
#extension GL_EXT_shader_16bit_storage : require
|
| 5 |
+
|
| 6 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
| 7 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
| 8 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
| 9 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
| 10 |
+
|
| 11 |
+
#extension GL_KHR_memory_scope_semantics : enable
|
| 12 |
+
#extension GL_KHR_cooperative_matrix : enable
|
| 13 |
+
#extension GL_NV_cooperative_matrix2 : enable
|
| 14 |
+
#extension GL_EXT_buffer_reference : enable
|
| 15 |
+
#extension GL_KHR_shader_subgroup_ballot : enable
|
| 16 |
+
#extension GL_KHR_shader_subgroup_vote : enable
|
| 17 |
+
|
| 18 |
+
#include "types.comp"
|
| 19 |
+
|
| 20 |
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
| 21 |
+
|
| 22 |
+
layout (constant_id = 1) const uint BM = 64;
|
| 23 |
+
layout (constant_id = 2) const uint BN = 64;
|
| 24 |
+
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
| 25 |
+
|
| 26 |
+
layout (push_constant) uniform parameter
|
| 27 |
+
{
|
| 28 |
+
uint M;
|
| 29 |
+
uint N;
|
| 30 |
+
uint K;
|
| 31 |
+
uint stride_a;
|
| 32 |
+
uint stride_b;
|
| 33 |
+
uint stride_d;
|
| 34 |
+
|
| 35 |
+
uint batch_stride_a;
|
| 36 |
+
uint batch_stride_b;
|
| 37 |
+
uint batch_stride_d;
|
| 38 |
+
|
| 39 |
+
#ifdef MUL_MAT_ID
|
| 40 |
+
uint nei0;
|
| 41 |
+
uint nei1;
|
| 42 |
+
uint nbi1;
|
| 43 |
+
uint ne11;
|
| 44 |
+
#else
|
| 45 |
+
uint k_split;
|
| 46 |
+
uint ne02;
|
| 47 |
+
uint ne12;
|
| 48 |
+
uint broadcast2;
|
| 49 |
+
uint broadcast3;
|
| 50 |
+
#endif
|
| 51 |
+
} p;
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
| 55 |
+
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
| 56 |
+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
| 57 |
+
|
| 58 |
+
#if QUANT_K > 1
|
| 59 |
+
#define DECODEFUNCA , dequantFuncA
|
| 60 |
+
#define MAT_A_TYPE float16_t
|
| 61 |
+
|
| 62 |
+
#include "dequant_funcs_cm2.comp"
|
| 63 |
+
|
| 64 |
+
#else
|
| 65 |
+
#define DECODEFUNCA
|
| 66 |
+
#define MAT_A_TYPE A_TYPE
|
| 67 |
+
#endif
|
| 68 |
+
|
| 69 |
+
#define MAT_B_TYPE B_TYPE
|
| 70 |
+
|
| 71 |
+
#ifdef MUL_MAT_ID
|
| 72 |
+
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
| 73 |
+
|
| 74 |
+
shared u16vec4 row_ids[3072];
|
| 75 |
+
|
| 76 |
+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
| 77 |
+
B_TYPE b[];
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
uint _ne1;
|
| 81 |
+
shared uint _ne1_sh;
|
| 82 |
+
|
| 83 |
+
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 84 |
+
{
|
| 85 |
+
const uint row_i = blockCoords[0];
|
| 86 |
+
|
| 87 |
+
if (row_i >= _ne1) {
|
| 88 |
+
return B_TYPE(0.0);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
const u16vec4 row_idx = row_ids[row_i];
|
| 92 |
+
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
|
| 93 |
+
|
| 94 |
+
return ret;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
|
| 98 |
+
{
|
| 99 |
+
uint dr = ir * BM + r;
|
| 100 |
+
uint dc = ic * BN + c;
|
| 101 |
+
|
| 102 |
+
if (dr < p.M && dc < _ne1) {
|
| 103 |
+
uint row_i = dc;
|
| 104 |
+
const u16vec4 row_idx = row_ids[row_i];
|
| 105 |
+
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
|
| 106 |
+
}
|
| 107 |
+
return elem;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
#endif
|
| 111 |
+
|
| 112 |
+
void main() {
|
| 113 |
+
#if defined(DATA_A_IQ4_NL)
|
| 114 |
+
init_iq4nl_shmem();
|
| 115 |
+
#endif
|
| 116 |
+
|
| 117 |
+
#ifdef MUL_MAT_ID
|
| 118 |
+
const uint expert_idx = gl_GlobalInvocationID.z;
|
| 119 |
+
#else
|
| 120 |
+
const uint batch_idx = gl_GlobalInvocationID.z;
|
| 121 |
+
|
| 122 |
+
const uint i13 = batch_idx / p.ne12;
|
| 123 |
+
const uint i12 = batch_idx % p.ne12;
|
| 124 |
+
|
| 125 |
+
const uint i03 = i13 / p.broadcast3;
|
| 126 |
+
const uint i02 = i12 / p.broadcast2;
|
| 127 |
+
|
| 128 |
+
const uint batch_idx_a = i03 * p.ne02 + i02;
|
| 129 |
+
#endif
|
| 130 |
+
|
| 131 |
+
const uint blocks_m = (p.M + BM - 1) / BM;
|
| 132 |
+
const uint ir = gl_WorkGroupID.x % blocks_m;
|
| 133 |
+
const uint ik = gl_WorkGroupID.x / blocks_m;
|
| 134 |
+
const uint ic = gl_WorkGroupID.y;
|
| 135 |
+
|
| 136 |
+
#ifdef MUL_MAT_ID
|
| 137 |
+
// Spread the search across all elements in the first subgroup
|
| 138 |
+
if (gl_SubgroupID == 0) {
|
| 139 |
+
_ne1 = 0;
|
| 140 |
+
uint num_elements = p.nei1 * p.nei0;
|
| 141 |
+
|
| 142 |
+
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
|
| 143 |
+
bool in_range = i < num_elements;
|
| 144 |
+
uint ii0 = i % p.nei0;
|
| 145 |
+
uint ii1 = i / p.nei0;
|
| 146 |
+
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
| 147 |
+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
| 148 |
+
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
| 149 |
+
if (in_range && id == expert_idx) {
|
| 150 |
+
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
| 151 |
+
}
|
| 152 |
+
_ne1 += subgroupBallotBitCount(ballot);
|
| 153 |
+
}
|
| 154 |
+
_ne1_sh = _ne1;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
barrier();
|
| 158 |
+
|
| 159 |
+
_ne1 = _ne1_sh;
|
| 160 |
+
|
| 161 |
+
// Workgroup has no work
|
| 162 |
+
if (ic * BN >= _ne1) return;
|
| 163 |
+
#endif
|
| 164 |
+
|
| 165 |
+
#ifdef MUL_MAT_ID
|
| 166 |
+
uint start_k = 0;
|
| 167 |
+
const uint end_k = p.K;
|
| 168 |
+
#else
|
| 169 |
+
uint start_k = ik * p.k_split;
|
| 170 |
+
const uint end_k = min(p.K, (ik + 1) * p.k_split);
|
| 171 |
+
#endif
|
| 172 |
+
|
| 173 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
| 174 |
+
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
| 175 |
+
|
| 176 |
+
#ifdef MUL_MAT_ID
|
| 177 |
+
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
|
| 178 |
+
uint pos_b = 0;
|
| 179 |
+
#else
|
| 180 |
+
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
|
| 181 |
+
uint pos_b = batch_idx * p.batch_stride_b;
|
| 182 |
+
#endif
|
| 183 |
+
|
| 184 |
+
uint stride_a = p.stride_a / QUANT_K;
|
| 185 |
+
uint stride_b = p.stride_b;
|
| 186 |
+
|
| 187 |
+
// Hint to the compiler that values are aligned (want 16B alignment).
|
| 188 |
+
// Quants are always block-aligned, no alignment needed.
|
| 189 |
+
#if ALIGNED
|
| 190 |
+
#if QUANT_K == 1
|
| 191 |
+
stride_a &= ~7;
|
| 192 |
+
#endif
|
| 193 |
+
stride_b &= ~7;
|
| 194 |
+
#endif
|
| 195 |
+
|
| 196 |
+
// Create layouts for both clamped and unclamped accesses
|
| 197 |
+
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
|
| 198 |
+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 199 |
+
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
|
| 200 |
+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 201 |
+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 202 |
+
|
| 203 |
+
#if QUANT_K > 1
|
| 204 |
+
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
|
| 205 |
+
tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
|
| 206 |
+
#endif
|
| 207 |
+
|
| 208 |
+
// Use end_k rather than p.K as the dimension because that's what
|
| 209 |
+
// we need to bound check against when using split_k
|
| 210 |
+
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
|
| 211 |
+
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
|
| 212 |
+
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
|
| 213 |
+
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
|
| 214 |
+
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
|
| 215 |
+
|
| 216 |
+
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
| 217 |
+
|
| 218 |
+
#if !defined(MUL_MAT_ID)
|
| 219 |
+
// Detect a fast path where all loads are entirely in bounds and no clamping is required
|
| 220 |
+
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
|
| 221 |
+
#if QUANT_K == 1
|
| 222 |
+
(stride_a % 8) == 0 &&
|
| 223 |
+
#endif
|
| 224 |
+
(stride_b % 8) == 0 && (start_k % 8) == 0) {
|
| 225 |
+
// Hint to the compiler that values are aligned (want 16B alignment)
|
| 226 |
+
start_k &= ~7;
|
| 227 |
+
stride_b &= ~7;
|
| 228 |
+
#if QUANT_K == 1
|
| 229 |
+
stride_a &= ~7;
|
| 230 |
+
#endif
|
| 231 |
+
|
| 232 |
+
tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
|
| 233 |
+
+        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+        uint k_iters = (end_k - start_k + BK - 1) / BK;
+
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+
+            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+
+            sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+        }
+    } else
+#endif // !defined(MUL_MAT_ID)
+    {
+        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+
+        tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
+
+        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+        tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
+
+        [[dont_unroll]]
+        for (uint block_k = start_k; block_k < end_k; block_k += BK) {
+
+            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
+
+            // Clamping is expensive, so detect different code paths for each combination
+            // of A and B needing clamping.
+            bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
+#ifdef MUL_MAT_ID
+            bool unclampedB = true;
+#else
+            bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
+#endif
+            if (unclampedA && unclampedB) {
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
+#endif
+                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            } else if (unclampedA && !unclampedB) {
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            } else if (!unclampedA && unclampedB) {
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
+#endif
+                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            } else if (!unclampedA && !unclampedB) {
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            }
+        }
+    }
+
+    // Convert from ACC_TYPE to D_TYPE
+    coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
+    mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+#ifdef MUL_MAT_ID
+    // Call callback to store each element, remapping row through shared memory
+    coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+#else
+    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
+
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+#endif
+}
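The loop above is the core of the new coopmat2 matmul path: each BK-wide step either takes the fast path (both tiles known to be fully in bounds) or falls back to the clamped tensor layouts, and the four if/else branches cover every combination of A and B needing clamping. As a rough illustration only (not part of the commit, with made-up tile and matrix sizes), the same predicates can be evaluated on the host to see which branch a tile on the ragged edge of M would take:

    // Standalone sketch: mirrors the unclampedA/unclampedB predicates from the
    // shader loop above. BM/BN/BK, M/N and the K range are hypothetical values.
    #include <cstdio>

    int main() {
        const unsigned BM = 64, BN = 64, BK = 16;
        const unsigned M = 1000, N = 512, end_k = 4096;

        const unsigned ir = M / BM;   // last tile row: rows 960..1023, spills past M = 1000
        const unsigned ic = 0;        // first tile column: fully inside N
        const unsigned block_k = 0;   // first K block

        const bool unclampedA = (ir + 1) * BM <= M && block_k + BK <= end_k && (block_k % 8) == 0;
        const bool unclampedB = (ic + 1) * BN <= N && block_k + BK <= end_k && (block_k % 8) == 0;

        std::printf("unclampedA=%d unclampedB=%d\n", unclampedA, unclampedB);
        return 0;
    }

With these numbers the A predicate fails ((15 + 1) * 64 = 1024 > 1000) while the B predicate holds, so the shader would take the "!unclampedA && unclampedB" branch and route only the A load through tensorLayoutAClamp.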
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
CHANGED
@@ -30,6 +30,8 @@
 #include <fcntl.h>
 #endif
 
+#include <vulkan/vulkan_core.h>
+
 #define ASYNCIO_CONCURRENCY 64
 
 std::mutex lock;
@@ -196,15 +198,17 @@ static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;
 
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
-    std::string name = _name + (fp16 ? "" : "_fp32");
+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
+    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
     std::string out_fname = join_paths(output_dir, name + ".spv");
     std::string in_path = join_paths(input_dir, in_fname);
 
+    std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
+
 #ifdef _WIN32
-    std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute",
+    std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
 #else
-    std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute",
+    std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", in_path, "-o", out_fname};
 #endif
 
 #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
@@ -254,7 +258,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
 }
 
 static std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
     {
         // wait until fewer than N compiles are in progress.
         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@@ -265,15 +269,15 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
         }
         compile_count++;
     }
-    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16));
+    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc));
 }
 
-void matmul_shaders(bool fp16, bool matmul_id) {
-    std::string load_vec = fp16 ? "8" : "4";
-    std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
-    std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
+void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
+    std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
+    std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
+    std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
+    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
     std::string shader_name = "matmul";
 
     if (matmul_id) {
@@ -285,21 +289,31 @@ void matmul_shaders(bool fp16, bool matmul_id) {
         base_dict["FLOAT16"] = "1";
     }
 
+    base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+
+    std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
+
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16",
-    string_to_spv(shader_name + "_f32_f16_aligned",
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
 
-    string_to_spv(shader_name + "
-    string_to_spv(shader_name + "
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc);
 
     for (const auto& tname : type_names) {
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
         // For aligned matmul loads
-        std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
-
-        string_to_spv(shader_name + "_" + tname + "
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
+
+        string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
+        string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+
+        if (tname != "f16" && tname != "f32") {
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+        }
     }
 }
 
@@ -307,11 +321,50 @@ void process_shaders() {
     std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
 
+    // matmul
     for (const auto& fp16 : {false, true}) {
-
-
+        for (const auto& matmul_id : {false, true}) {
+            for (const auto& coopmat2 : {false, true}) {
+                for (const auto& f16acc : {false, true}) {
+#if !defined(VK_NV_cooperative_matrix2)
+                    if (coopmat2) {
+                        continue;
+                    }
+#endif
+                    if (coopmat2 && !fp16) {
+                        continue;
+                    }
+                    if (!coopmat2 && f16acc) {
+                        continue;
+                    }
+                    matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
+                }
+            }
+        }
     }
 
+#if defined(VK_NV_cooperative_matrix2)
+    // flash attention
+    for (const auto& f16acc : {false, true}) {
+        std::string acctype = f16acc ? "float16_t" : "float";
+
+        for (const auto& tname : type_names) {
+            if (tname == "f32") {
+                continue;
+            }
+
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc);
+            } else {
+                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc);
+            }
+        }
+    }
+#endif
+
     for (const auto& tname : type_names) {
         // mul mat vec
        std::string data_a_key = "DATA_A_" + to_uppercase(tname);