Spaces:
Running
Running
Commit
·
5e70c43
1
Parent(s):
21f8a02
vulkan: optimize coopmat2 dequant functions (llama/10855)
Browse filesChange the code to do 16b loads when possible and extract the appropriate
component late, so the code is effectively decoding a pair of elements and
then selecting one. This can allow more commoning to happen in the compiler
when neighboring elements are loaded.
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
CHANGED
|
@@ -10,9 +10,10 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
|
|
| 10 |
const float16_t d = bl.block.d;
|
| 11 |
const uint idx = coordInBlock[1];
|
| 12 |
const uint shift = (idx & 0x10) >> 2;
|
| 13 |
-
uint32_t qs =
|
| 14 |
qs >>= shift;
|
| 15 |
-
qs &=
|
|
|
|
| 16 |
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
|
| 17 |
return ret;
|
| 18 |
}
|
|
@@ -152,15 +153,17 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
|
|
| 152 |
block_q4_K block;
|
| 153 |
};
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 156 |
{
|
|
|
|
| 157 |
const uint idx = coordInBlock[1];
|
| 158 |
-
const uint iqs = idx;
|
| 159 |
|
| 160 |
-
const uint
|
| 161 |
-
const uint b = (iqs % 64) / 32; // 0,1
|
| 162 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 163 |
-
const uint qsi = n * 32 + (iqs % 32); // 0..127
|
| 164 |
|
| 165 |
const f16vec2 loadd = bl.block.d;
|
| 166 |
|
|
@@ -184,9 +187,11 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
|
| 184 |
const float16_t d = loadd.x * float16_t(sc);
|
| 185 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 186 |
|
| 187 |
-
|
|
|
|
|
|
|
| 188 |
|
| 189 |
-
float16_t ret = d * float16_t(
|
| 190 |
|
| 191 |
return ret;
|
| 192 |
}
|
|
@@ -195,18 +200,19 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
|
|
| 195 |
block_q5_K block;
|
| 196 |
};
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 199 |
{
|
|
|
|
| 200 |
const uint idx = coordInBlock[1];
|
| 201 |
-
const uint iqs = idx;
|
| 202 |
|
| 203 |
-
const uint
|
| 204 |
-
const uint b = (iqs % 64) / 32; // 0,1
|
| 205 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 206 |
-
const uint qsi = n * 32 + (iqs % 32); // 0..127
|
| 207 |
-
const uint qhi = (iqs % 32); // 0..31
|
| 208 |
|
| 209 |
-
const
|
| 210 |
|
| 211 |
const f16vec2 loadd = bl.block.d;
|
| 212 |
|
|
@@ -230,9 +236,15 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
|
| 230 |
const float16_t d = loadd.x * float16_t(sc);
|
| 231 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 232 |
|
| 233 |
-
|
|
|
|
|
|
|
| 234 |
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
return ret;
|
| 238 |
}
|
|
@@ -241,22 +253,30 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_
|
|
| 241 |
block_q6_K block;
|
| 242 |
};
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 245 |
{
|
|
|
|
| 246 |
const uint idx = coordInBlock[1];
|
| 247 |
-
const uint iqs = idx;
|
| 248 |
|
| 249 |
-
const uint
|
| 250 |
-
const uint
|
| 251 |
-
const uint
|
| 252 |
-
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
|
| 253 |
-
const uint is = 8 * n + qhshift + is_b; // 0..15
|
| 254 |
-
const uint qsi = n * 64 + (iqs % 64); // 0..127
|
| 255 |
-
const uint qhi = n * 32 + (iqs % 32); // 0..63
|
| 256 |
|
| 257 |
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
|
| 258 |
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
return ret;
|
| 262 |
}
|
|
|
|
| 10 |
const float16_t d = bl.block.d;
|
| 11 |
const uint idx = coordInBlock[1];
|
| 12 |
const uint shift = (idx & 0x10) >> 2;
|
| 13 |
+
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
|
| 14 |
qs >>= shift;
|
| 15 |
+
qs &= 0x0F0F;
|
| 16 |
+
qs = unpack8(qs)[idx & 1];
|
| 17 |
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
|
| 18 |
return ret;
|
| 19 |
}
|
|
|
|
| 153 |
block_q4_K block;
|
| 154 |
};
|
| 155 |
|
| 156 |
+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
|
| 157 |
+
block_q4_K_packed16 block;
|
| 158 |
+
};
|
| 159 |
+
|
| 160 |
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 161 |
{
|
| 162 |
+
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
|
| 163 |
const uint idx = coordInBlock[1];
|
|
|
|
| 164 |
|
| 165 |
+
const uint b = (idx & 0x20) >> 5; // 0,1
|
|
|
|
| 166 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
|
|
|
| 167 |
|
| 168 |
const f16vec2 loadd = bl.block.d;
|
| 169 |
|
|
|
|
| 187 |
const float16_t d = loadd.x * float16_t(sc);
|
| 188 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 189 |
|
| 190 |
+
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
| 191 |
+
qs = (qs >> (b * 4)) & 0x0F0F;
|
| 192 |
+
qs = unpack8(qs)[idx & 1];
|
| 193 |
|
| 194 |
+
float16_t ret = d * float16_t(qs) - m;
|
| 195 |
|
| 196 |
return ret;
|
| 197 |
}
|
|
|
|
| 200 |
block_q5_K block;
|
| 201 |
};
|
| 202 |
|
| 203 |
+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
|
| 204 |
+
block_q5_K_packed16 block;
|
| 205 |
+
};
|
| 206 |
+
|
| 207 |
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 208 |
{
|
| 209 |
+
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
|
| 210 |
const uint idx = coordInBlock[1];
|
|
|
|
| 211 |
|
| 212 |
+
const uint b = (idx & 0x20) >> 5; // 0,1
|
|
|
|
| 213 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
const uint32_t hm = 0x0101 << is;
|
| 216 |
|
| 217 |
const f16vec2 loadd = bl.block.d;
|
| 218 |
|
|
|
|
| 236 |
const float16_t d = loadd.x * float16_t(sc);
|
| 237 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 238 |
|
| 239 |
+
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
|
| 240 |
+
qh = qh & hm;
|
| 241 |
+
qh = unpack8(qh)[idx & 1];
|
| 242 |
|
| 243 |
+
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
| 244 |
+
qs = (qs >> (b * 4)) & 0x0F0F;
|
| 245 |
+
qs = unpack8(qs)[idx & 1];
|
| 246 |
+
|
| 247 |
+
float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
|
| 248 |
|
| 249 |
return ret;
|
| 250 |
}
|
|
|
|
| 253 |
block_q6_K block;
|
| 254 |
};
|
| 255 |
|
| 256 |
+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
|
| 257 |
+
block_q6_K_packed16 block;
|
| 258 |
+
};
|
| 259 |
+
|
| 260 |
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 261 |
{
|
| 262 |
+
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
|
| 263 |
const uint idx = coordInBlock[1];
|
|
|
|
| 264 |
|
| 265 |
+
const uint b = (idx & 0x40) >> 6; // 0,1
|
| 266 |
+
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
|
| 267 |
+
const uint is = (idx & 0xF0) >> 4; // 0..15
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
|
| 270 |
|
| 271 |
+
uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
|
| 272 |
+
ql = (ql >> (b * 4)) & 0x0F0F;
|
| 273 |
+
|
| 274 |
+
uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
|
| 275 |
+
qh = ((qh >> qhshift) & 0x0303) << 4;
|
| 276 |
+
|
| 277 |
+
int q = unpack8(ql | qh)[idx & 1];
|
| 278 |
+
|
| 279 |
+
float16_t ret = dscale * float16_t(q - 32);
|
| 280 |
|
| 281 |
return ret;
|
| 282 |
}
|