jeffbolznv committed on
Commit
5e70c43
·
1 Parent(s): 21f8a02

vulkan: optimize coopmat2 dequant functions (llama/10855)

Browse files

Change the code to do 16-bit loads when possible and extract the appropriate
component late, so the code effectively decodes a pair of elements and
then selects one. This can allow the compiler to do more commoning
(common-subexpression elimination) when neighboring elements are loaded.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp CHANGED
@@ -10,9 +10,10 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
10
  const float16_t d = bl.block.d;
11
  const uint idx = coordInBlock[1];
12
  const uint shift = (idx & 0x10) >> 2;
13
- uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
14
  qs >>= shift;
15
- qs &= 0xF;
 
16
  float16_t ret = (float16_t(qs) - float16_t(8)) * d;
17
  return ret;
18
  }
@@ -152,15 +153,17 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
152
  block_q4_K block;
153
  };
154
 
 
 
 
 
155
  float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
156
  {
 
157
  const uint idx = coordInBlock[1];
158
- const uint iqs = idx;
159
 
160
- const uint n = iqs / 64; // 0,1,2,3
161
- const uint b = (iqs % 64) / 32; // 0,1
162
  const uint is = (idx & 0xE0) >> 5; // 0..7
163
- const uint qsi = n * 32 + (iqs % 32); // 0..127
164
 
165
  const f16vec2 loadd = bl.block.d;
166
 
@@ -184,9 +187,11 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
184
  const float16_t d = loadd.x * float16_t(sc);
185
  const float16_t m = loadd.y * float16_t(mbyte);
186
 
187
- uint32_t dmask = 0xF << (b * 4);
 
 
188
 
189
- float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
190
 
191
  return ret;
192
  }
@@ -195,18 +200,19 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
195
  block_q5_K block;
196
  };
197
 
 
 
 
 
198
  float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
199
  {
 
200
  const uint idx = coordInBlock[1];
201
- const uint iqs = idx;
202
 
203
- const uint n = iqs / 64; // 0,1,2,3
204
- const uint b = (iqs % 64) / 32; // 0,1
205
  const uint is = (idx & 0xE0) >> 5; // 0..7
206
- const uint qsi = n * 32 + (iqs % 32); // 0..127
207
- const uint qhi = (iqs % 32); // 0..31
208
 
209
- const uint8_t hm = uint8_t(1 << (iqs / 32));
210
 
211
  const f16vec2 loadd = bl.block.d;
212
 
@@ -230,9 +236,15 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
230
  const float16_t d = loadd.x * float16_t(sc);
231
  const float16_t m = loadd.y * float16_t(mbyte);
232
 
233
- uint32_t dmask = 0xF << (b * 4);
 
 
234
 
235
- float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
 
 
 
 
236
 
237
  return ret;
238
  }
@@ -241,22 +253,30 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_
241
  block_q6_K block;
242
  };
243
 
 
 
 
 
244
  float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
245
  {
 
246
  const uint idx = coordInBlock[1];
247
- const uint iqs = idx;
248
 
249
- const uint n = iqs / 128; // 0,1
250
- const uint b = (iqs % 128) / 64; // 0,1
251
- const uint is_b = (iqs % 32) / 16; // 0,1
252
- const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
253
- const uint is = 8 * n + qhshift + is_b; // 0..15
254
- const uint qsi = n * 64 + (iqs % 64); // 0..127
255
- const uint qhi = n * 32 + (iqs % 32); // 0..63
256
 
257
  const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
258
 
259
- float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
 
 
 
 
 
 
 
 
260
 
261
  return ret;
262
  }
 
10
  const float16_t d = bl.block.d;
11
  const uint idx = coordInBlock[1];
12
  const uint shift = (idx & 0x10) >> 2;
13
+ uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
14
  qs >>= shift;
15
+ qs &= 0x0F0F;
16
+ qs = unpack8(qs)[idx & 1];
17
  float16_t ret = (float16_t(qs) - float16_t(8)) * d;
18
  return ret;
19
  }
 
153
  block_q4_K block;
154
  };
155
 
156
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
157
+ block_q4_K_packed16 block;
158
+ };
159
+
160
  float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
161
  {
162
+ decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
163
  const uint idx = coordInBlock[1];
 
164
 
165
+ const uint b = (idx & 0x20) >> 5; // 0,1
 
166
  const uint is = (idx & 0xE0) >> 5; // 0..7
 
167
 
168
  const f16vec2 loadd = bl.block.d;
169
 
 
187
  const float16_t d = loadd.x * float16_t(sc);
188
  const float16_t m = loadd.y * float16_t(mbyte);
189
 
190
+ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
191
+ qs = (qs >> (b * 4)) & 0x0F0F;
192
+ qs = unpack8(qs)[idx & 1];
193
 
194
+ float16_t ret = d * float16_t(qs) - m;
195
 
196
  return ret;
197
  }
 
200
  block_q5_K block;
201
  };
202
 
203
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
204
+ block_q5_K_packed16 block;
205
+ };
206
+
207
  float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
208
  {
209
+ decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
210
  const uint idx = coordInBlock[1];
 
211
 
212
+ const uint b = (idx & 0x20) >> 5; // 0,1
 
213
  const uint is = (idx & 0xE0) >> 5; // 0..7
 
 
214
 
215
+ const uint32_t hm = 0x0101 << is;
216
 
217
  const f16vec2 loadd = bl.block.d;
218
 
 
236
  const float16_t d = loadd.x * float16_t(sc);
237
  const float16_t m = loadd.y * float16_t(mbyte);
238
 
239
+ uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
240
+ qh = qh & hm;
241
+ qh = unpack8(qh)[idx & 1];
242
 
243
+ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
244
+ qs = (qs >> (b * 4)) & 0x0F0F;
245
+ qs = unpack8(qs)[idx & 1];
246
+
247
+ float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
248
 
249
  return ret;
250
  }
 
253
  block_q6_K block;
254
  };
255
 
256
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
257
+ block_q6_K_packed16 block;
258
+ };
259
+
260
  float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
261
  {
262
+ decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
263
  const uint idx = coordInBlock[1];
 
264
 
265
+ const uint b = (idx & 0x40) >> 6; // 0,1
266
+ const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
267
+ const uint is = (idx & 0xF0) >> 4; // 0..15
 
 
 
 
268
 
269
  const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
270
 
271
+ uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
272
+ ql = (ql >> (b * 4)) & 0x0F0F;
273
+
274
+ uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
275
+ qh = ((qh >> qhshift) & 0x0303) << 4;
276
+
277
+ int q = unpack8(ql | qh)[idx & 1];
278
+
279
+ float16_t ret = dscale * float16_t(q - 32);
280
 
281
  return ret;
282
  }