Spaces:
Running
Running
ggml : IQ3_S improvements (llama/5829)
Browse files
* iq3_s: somewhat faster AVX2 dot product
On a Ryzen 7950X, TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.
* iq3_s: somewhat faster ARM_NEON dot product
Still dog slow - 10.7 t/s up from 9.9 t/s.
* iq3_s: another small ARM_NEON improvement
10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.
* iq3_s: minor improvement on Metal
49.4 t/s -> 50.3 t/s
* iq3_s: PPL improvement
E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.
* iq3_s: use new grid everywhere
* Fix ARM_NEON
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
- ggml-cuda.cu +71 -72
- ggml-metal.metal +77 -75
- ggml-quants.c +162 -118
ggml-cuda.cu
CHANGED
|
@@ -2061,74 +2061,73 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
|
|
| 2061 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 2062 |
};
|
| 2063 |
|
| 2064 |
-
static const __device__ uint32_t
|
| 2065 |
-
|
| 2066 |
-
|
| 2067 |
-
|
| 2068 |
-
|
| 2069 |
-
|
| 2070 |
-
|
| 2071 |
-
|
| 2072 |
-
|
| 2073 |
-
|
| 2074 |
-
|
| 2075 |
-
|
| 2076 |
-
|
| 2077 |
-
|
| 2078 |
-
|
| 2079 |
-
|
| 2080 |
-
|
| 2081 |
-
|
| 2082 |
-
|
| 2083 |
-
|
| 2084 |
-
|
| 2085 |
-
|
| 2086 |
-
|
| 2087 |
-
|
| 2088 |
-
|
| 2089 |
-
|
| 2090 |
-
|
| 2091 |
-
|
| 2092 |
-
|
| 2093 |
-
|
| 2094 |
-
|
| 2095 |
-
|
| 2096 |
-
|
| 2097 |
-
|
| 2098 |
-
|
| 2099 |
-
|
| 2100 |
-
|
| 2101 |
-
|
| 2102 |
-
|
| 2103 |
-
|
| 2104 |
-
|
| 2105 |
-
|
| 2106 |
-
|
| 2107 |
-
|
| 2108 |
-
|
| 2109 |
-
|
| 2110 |
-
|
| 2111 |
-
|
| 2112 |
-
|
| 2113 |
-
|
| 2114 |
-
|
| 2115 |
-
|
| 2116 |
-
|
| 2117 |
-
|
| 2118 |
-
|
| 2119 |
-
|
| 2120 |
-
|
| 2121 |
-
|
| 2122 |
-
|
| 2123 |
-
|
| 2124 |
-
|
| 2125 |
-
|
| 2126 |
-
|
| 2127 |
-
|
| 2128 |
-
|
| 2129 |
};
|
| 2130 |
|
| 2131 |
-
|
| 2132 |
static const __device__ uint64_t iq1s_grid[512] = {
|
| 2133 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 2134 |
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
@@ -2435,9 +2434,9 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
|
|
| 2435 |
const int ib = tid%8; // 0...7
|
| 2436 |
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 2437 |
const uint8_t * qs = x[i].qs + 8*ib;
|
| 2438 |
-
const uint8_t * grid1 = (const uint8_t *)(
|
| 2439 |
-
const uint8_t * grid2 = (const uint8_t *)(
|
| 2440 |
-
const float d = (float)x[i].d * (
|
| 2441 |
const uint8_t signs = x[i].signs[4*ib + il];
|
| 2442 |
for (int j = 0; j < 4; ++j) {
|
| 2443 |
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
|
@@ -5254,8 +5253,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
|
| 5254 |
const int8_t * q8 = bq8_1[ib32].qs;
|
| 5255 |
int sumi = 0;
|
| 5256 |
for (int l = 0; l < 4; ++l) {
|
| 5257 |
-
const uint32_t * grid1 =
|
| 5258 |
-
const uint32_t * grid2 =
|
| 5259 |
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
| 5260 |
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
| 5261 |
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
|
|
@@ -5264,7 +5263,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
|
| 5264 |
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
|
| 5265 |
q8 += 8;
|
| 5266 |
}
|
| 5267 |
-
const float d = (float)bq2->d * (
|
| 5268 |
return d * sumi;
|
| 5269 |
#else
|
| 5270 |
assert(false);
|
|
|
|
| 2061 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 2062 |
};
|
| 2063 |
|
| 2064 |
+
static const __device__ uint32_t iq3s_grid[512] = {
|
| 2065 |
+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
| 2066 |
+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
| 2067 |
+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
| 2068 |
+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
| 2069 |
+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
| 2070 |
+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
| 2071 |
+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
| 2072 |
+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
| 2073 |
+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
| 2074 |
+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
| 2075 |
+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
| 2076 |
+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
| 2077 |
+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
| 2078 |
+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
| 2079 |
+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
| 2080 |
+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
| 2081 |
+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
| 2082 |
+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
| 2083 |
+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
| 2084 |
+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
| 2085 |
+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
| 2086 |
+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
| 2087 |
+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
| 2088 |
+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
| 2089 |
+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
| 2090 |
+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
| 2091 |
+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
| 2092 |
+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
| 2093 |
+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
| 2094 |
+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
| 2095 |
+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
| 2096 |
+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
| 2097 |
+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
| 2098 |
+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
| 2099 |
+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
| 2100 |
+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
| 2101 |
+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
| 2102 |
+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
| 2103 |
+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
| 2104 |
+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
| 2105 |
+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
| 2106 |
+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
| 2107 |
+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
| 2108 |
+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
| 2109 |
+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
| 2110 |
+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
| 2111 |
+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
| 2112 |
+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
| 2113 |
+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
| 2114 |
+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
| 2115 |
+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
| 2116 |
+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
| 2117 |
+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
| 2118 |
+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
| 2119 |
+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
| 2120 |
+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
| 2121 |
+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
| 2122 |
+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
| 2123 |
+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
| 2124 |
+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
| 2125 |
+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
| 2126 |
+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
| 2127 |
+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
| 2128 |
+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
| 2129 |
};
|
| 2130 |
|
|
|
|
| 2131 |
static const __device__ uint64_t iq1s_grid[512] = {
|
| 2132 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 2133 |
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
|
|
| 2434 |
const int ib = tid%8; // 0...7
|
| 2435 |
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 2436 |
const uint8_t * qs = x[i].qs + 8*ib;
|
| 2437 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
| 2438 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
|
| 2439 |
+
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
|
| 2440 |
const uint8_t signs = x[i].signs[4*ib + il];
|
| 2441 |
for (int j = 0; j < 4; ++j) {
|
| 2442 |
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
|
|
|
| 5253 |
const int8_t * q8 = bq8_1[ib32].qs;
|
| 5254 |
int sumi = 0;
|
| 5255 |
for (int l = 0; l < 4; ++l) {
|
| 5256 |
+
const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
|
| 5257 |
+
const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
|
| 5258 |
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
| 5259 |
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
| 5260 |
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
|
|
|
|
| 5263 |
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
|
| 5264 |
q8 += 8;
|
| 5265 |
}
|
| 5266 |
+
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
|
| 5267 |
return d * sumi;
|
| 5268 |
#else
|
| 5269 |
assert(false);
|
ggml-metal.metal
CHANGED
|
@@ -4130,71 +4130,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
| 4130 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 4131 |
};
|
| 4132 |
|
| 4133 |
-
constexpr constant static uint32_t
|
| 4134 |
-
|
| 4135 |
-
|
| 4136 |
-
|
| 4137 |
-
|
| 4138 |
-
|
| 4139 |
-
|
| 4140 |
-
|
| 4141 |
-
|
| 4142 |
-
|
| 4143 |
-
|
| 4144 |
-
|
| 4145 |
-
|
| 4146 |
-
|
| 4147 |
-
|
| 4148 |
-
|
| 4149 |
-
|
| 4150 |
-
|
| 4151 |
-
|
| 4152 |
-
|
| 4153 |
-
|
| 4154 |
-
|
| 4155 |
-
|
| 4156 |
-
|
| 4157 |
-
|
| 4158 |
-
|
| 4159 |
-
|
| 4160 |
-
|
| 4161 |
-
|
| 4162 |
-
|
| 4163 |
-
|
| 4164 |
-
|
| 4165 |
-
|
| 4166 |
-
|
| 4167 |
-
|
| 4168 |
-
|
| 4169 |
-
|
| 4170 |
-
|
| 4171 |
-
|
| 4172 |
-
|
| 4173 |
-
|
| 4174 |
-
|
| 4175 |
-
|
| 4176 |
-
|
| 4177 |
-
|
| 4178 |
-
|
| 4179 |
-
|
| 4180 |
-
|
| 4181 |
-
|
| 4182 |
-
|
| 4183 |
-
|
| 4184 |
-
|
| 4185 |
-
|
| 4186 |
-
|
| 4187 |
-
|
| 4188 |
-
|
| 4189 |
-
|
| 4190 |
-
|
| 4191 |
-
|
| 4192 |
-
|
| 4193 |
-
|
| 4194 |
-
|
| 4195 |
-
|
| 4196 |
-
|
| 4197 |
-
|
| 4198 |
};
|
| 4199 |
|
| 4200 |
#define NGRID_IQ1S 512
|
|
@@ -4785,7 +4785,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4785 |
{
|
| 4786 |
int nval = 8;
|
| 4787 |
int pos = (32*sgitg + tiisg)*nval;
|
| 4788 |
-
for (int i = 0; i < nval; ++i) values[pos + i] =
|
| 4789 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4790 |
}
|
| 4791 |
|
|
@@ -4812,12 +4812,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4812 |
for (int row = 0; row < N_DST; row++) {
|
| 4813 |
|
| 4814 |
const float db = dh[0];
|
| 4815 |
-
const float d = db * (
|
| 4816 |
|
| 4817 |
float2 sum = {0};
|
| 4818 |
for (int l = 0; l < 4; ++l) {
|
| 4819 |
-
const threadgroup
|
| 4820 |
-
const threadgroup
|
|
|
|
|
|
|
| 4821 |
for (int j = 0; j < 4; ++j) {
|
| 4822 |
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
| 4823 |
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
|
@@ -4838,7 +4840,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4838 |
for (int row = 0; row < N_DST; ++row) {
|
| 4839 |
all_sum = simd_sum(sumf[row]);
|
| 4840 |
if (tiisg == 0) {
|
| 4841 |
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum
|
| 4842 |
}
|
| 4843 |
}
|
| 4844 |
}
|
|
@@ -5728,15 +5730,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
|
|
| 5728 |
device const uint8_t * qs = xb->qs + 8*ib32;
|
| 5729 |
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
| 5730 |
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
| 5731 |
-
const float dl = d * (
|
| 5732 |
-
constant uint8_t * grid1 = (constant uint8_t *)(
|
| 5733 |
-
constant uint8_t * grid2 = (constant uint8_t *)(
|
| 5734 |
for (int i = 0; i < 4; ++i) {
|
| 5735 |
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
| 5736 |
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
| 5737 |
}
|
| 5738 |
-
grid1 = (constant uint8_t *)(
|
| 5739 |
-
grid2 = (constant uint8_t *)(
|
| 5740 |
for (int i = 0; i < 4; ++i) {
|
| 5741 |
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
| 5742 |
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
|
|
|
| 4130 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 4131 |
};
|
| 4132 |
|
| 4133 |
+
constexpr constant static uint32_t iq3s_grid[512] = {
|
| 4134 |
+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
| 4135 |
+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
| 4136 |
+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
| 4137 |
+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
| 4138 |
+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
| 4139 |
+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
| 4140 |
+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
| 4141 |
+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
| 4142 |
+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
| 4143 |
+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
| 4144 |
+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
| 4145 |
+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
| 4146 |
+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
| 4147 |
+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
| 4148 |
+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
| 4149 |
+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
| 4150 |
+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
| 4151 |
+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
| 4152 |
+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
| 4153 |
+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
| 4154 |
+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
| 4155 |
+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
| 4156 |
+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
| 4157 |
+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
| 4158 |
+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
| 4159 |
+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
| 4160 |
+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
| 4161 |
+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
| 4162 |
+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
| 4163 |
+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
| 4164 |
+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
| 4165 |
+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
| 4166 |
+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
| 4167 |
+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
| 4168 |
+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
| 4169 |
+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
| 4170 |
+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
| 4171 |
+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
| 4172 |
+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
| 4173 |
+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
| 4174 |
+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
| 4175 |
+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
| 4176 |
+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
| 4177 |
+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
| 4178 |
+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
| 4179 |
+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
| 4180 |
+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
| 4181 |
+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
| 4182 |
+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
| 4183 |
+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
| 4184 |
+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
| 4185 |
+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
| 4186 |
+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
| 4187 |
+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
| 4188 |
+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
| 4189 |
+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
| 4190 |
+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
| 4191 |
+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
| 4192 |
+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
| 4193 |
+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
| 4194 |
+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
| 4195 |
+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
| 4196 |
+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
| 4197 |
+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
| 4198 |
};
|
| 4199 |
|
| 4200 |
#define NGRID_IQ1S 512
|
|
|
|
| 4785 |
{
|
| 4786 |
int nval = 8;
|
| 4787 |
int pos = (32*sgitg + tiisg)*nval;
|
| 4788 |
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
|
| 4789 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4790 |
}
|
| 4791 |
|
|
|
|
| 4812 |
for (int row = 0; row < N_DST; row++) {
|
| 4813 |
|
| 4814 |
const float db = dh[0];
|
| 4815 |
+
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
|
| 4816 |
|
| 4817 |
float2 sum = {0};
|
| 4818 |
for (int l = 0; l < 4; ++l) {
|
| 4819 |
+
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
|
| 4820 |
+
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
|
| 4821 |
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
| 4822 |
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
| 4823 |
for (int j = 0; j < 4; ++j) {
|
| 4824 |
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
| 4825 |
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
|
|
|
| 4840 |
for (int row = 0; row < N_DST; ++row) {
|
| 4841 |
all_sum = simd_sum(sumf[row]);
|
| 4842 |
if (tiisg == 0) {
|
| 4843 |
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
| 4844 |
}
|
| 4845 |
}
|
| 4846 |
}
|
|
|
|
| 5730 |
device const uint8_t * qs = xb->qs + 8*ib32;
|
| 5731 |
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
| 5732 |
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
| 5733 |
+
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
|
| 5734 |
+
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
| 5735 |
+
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
| 5736 |
for (int i = 0; i < 4; ++i) {
|
| 5737 |
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
| 5738 |
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
| 5739 |
}
|
| 5740 |
+
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
| 5741 |
+
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
| 5742 |
for (int i = 0; i < 4; ++i) {
|
| 5743 |
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
| 5744 |
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
ggml-quants.c
CHANGED
|
@@ -3818,71 +3818,71 @@ static const uint32_t iq3xxs_grid[256] = {
|
|
| 3818 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3819 |
};
|
| 3820 |
|
| 3821 |
-
static const uint32_t
|
| 3822 |
-
|
| 3823 |
-
|
| 3824 |
-
|
| 3825 |
-
|
| 3826 |
-
|
| 3827 |
-
|
| 3828 |
-
|
| 3829 |
-
|
| 3830 |
-
|
| 3831 |
-
|
| 3832 |
-
|
| 3833 |
-
|
| 3834 |
-
|
| 3835 |
-
|
| 3836 |
-
|
| 3837 |
-
|
| 3838 |
-
|
| 3839 |
-
|
| 3840 |
-
|
| 3841 |
-
|
| 3842 |
-
|
| 3843 |
-
|
| 3844 |
-
|
| 3845 |
-
|
| 3846 |
-
|
| 3847 |
-
|
| 3848 |
-
|
| 3849 |
-
|
| 3850 |
-
|
| 3851 |
-
|
| 3852 |
-
|
| 3853 |
-
|
| 3854 |
-
|
| 3855 |
-
|
| 3856 |
-
|
| 3857 |
-
|
| 3858 |
-
|
| 3859 |
-
|
| 3860 |
-
|
| 3861 |
-
|
| 3862 |
-
|
| 3863 |
-
|
| 3864 |
-
|
| 3865 |
-
|
| 3866 |
-
|
| 3867 |
-
|
| 3868 |
-
|
| 3869 |
-
|
| 3870 |
-
|
| 3871 |
-
|
| 3872 |
-
|
| 3873 |
-
|
| 3874 |
-
|
| 3875 |
-
|
| 3876 |
-
|
| 3877 |
-
|
| 3878 |
-
|
| 3879 |
-
|
| 3880 |
-
|
| 3881 |
-
|
| 3882 |
-
|
| 3883 |
-
|
| 3884 |
-
|
| 3885 |
-
|
| 3886 |
};
|
| 3887 |
|
| 3888 |
#define NGRID_IQ2XXS 512
|
|
@@ -4162,11 +4162,11 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
|
|
| 4162 |
const uint8_t * signs = x[i].signs;
|
| 4163 |
|
| 4164 |
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 4165 |
-
const float db1 = d * (
|
| 4166 |
-
const float db2 = d * (
|
| 4167 |
for (int l = 0; l < 4; ++l) {
|
| 4168 |
-
const uint8_t * grid1 = (const uint8_t *)(
|
| 4169 |
-
const uint8_t * grid2 = (const uint8_t *)(
|
| 4170 |
for (int j = 0; j < 4; ++j) {
|
| 4171 |
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 4172 |
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
|
@@ -4176,8 +4176,8 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
|
|
| 4176 |
qs += 8;
|
| 4177 |
signs += 4;
|
| 4178 |
for (int l = 0; l < 4; ++l) {
|
| 4179 |
-
const uint8_t * grid1 = (const uint8_t *)(
|
| 4180 |
-
const uint8_t * grid2 = (const uint8_t *)(
|
| 4181 |
for (int j = 0; j < 4; ++j) {
|
| 4182 |
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 4183 |
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
|
@@ -10089,18 +10089,34 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10089 |
|
| 10090 |
#if defined(__ARM_NEON)
|
| 10091 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10092 |
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
| 10093 |
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
| 10094 |
};
|
| 10095 |
|
| 10096 |
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
|
| 10097 |
|
| 10098 |
-
const
|
| 10099 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10100 |
|
| 10101 |
uint8x16x2_t vs;
|
| 10102 |
ggml_int8x16x4_t q3s;
|
| 10103 |
ggml_int8x16x4_t q8b;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10104 |
|
| 10105 |
float sumf = 0;
|
| 10106 |
for (int i = 0; i < nb; ++i) {
|
|
@@ -10109,47 +10125,63 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10109 |
const uint8_t * restrict qh = x[i].qh;
|
| 10110 |
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
|
| 10111 |
const int8_t * restrict q8 = y[i].qs;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10112 |
int sumi1 = 0, sumi2 = 0;
|
| 10113 |
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 10114 |
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
| 10115 |
-
|
| 10116 |
-
|
| 10117 |
-
|
| 10118 |
-
|
| 10119 |
-
|
| 10120 |
-
|
| 10121 |
-
|
| 10122 |
-
|
| 10123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10124 |
|
| 10125 |
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
| 10126 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 10127 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 10128 |
-
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
| 10129 |
-
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
| 10130 |
|
| 10131 |
-
q3s.val[0] =
|
| 10132 |
-
q3s.val[1] =
|
| 10133 |
|
| 10134 |
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
| 10135 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 10136 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 10137 |
-
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
| 10138 |
-
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
| 10139 |
|
| 10140 |
signs += 4;
|
| 10141 |
|
| 10142 |
-
q3s.val[2] =
|
| 10143 |
-
q3s.val[3] =
|
| 10144 |
|
| 10145 |
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
| 10146 |
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10147 |
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
|
| 10148 |
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
|
|
|
|
| 10149 |
}
|
| 10150 |
sumf += d*(sumi1 + sumi2);
|
| 10151 |
}
|
| 10152 |
-
*s =
|
| 10153 |
|
| 10154 |
#elif defined(__AVX2__)
|
| 10155 |
|
|
@@ -10164,6 +10196,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10164 |
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
|
| 10165 |
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
|
| 10166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10167 |
__m256 accumf = _mm256_setzero_ps();
|
| 10168 |
for (int i = 0; i < nb; ++i) {
|
| 10169 |
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
@@ -10176,24 +10218,25 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10176 |
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 10177 |
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 10178 |
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 10179 |
-
const __m256i
|
| 10180 |
-
|
| 10181 |
-
|
| 10182 |
-
|
| 10183 |
-
|
| 10184 |
-
|
| 10185 |
-
|
| 10186 |
-
|
| 10187 |
-
|
| 10188 |
-
const __m256i
|
| 10189 |
-
|
| 10190 |
-
|
| 10191 |
-
|
| 10192 |
-
|
| 10193 |
-
|
| 10194 |
-
|
| 10195 |
-
|
| 10196 |
-
|
|
|
|
| 10197 |
|
| 10198 |
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
| 10199 |
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
|
@@ -10221,7 +10264,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10221 |
|
| 10222 |
}
|
| 10223 |
|
| 10224 |
-
*s =
|
| 10225 |
|
| 10226 |
#else
|
| 10227 |
|
|
@@ -10238,8 +10281,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10238 |
const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
|
| 10239 |
int32_t sumi = 0;
|
| 10240 |
for (int l = 0; l < 4; ++l) {
|
| 10241 |
-
const uint8_t * grid1 = (const uint8_t *)(
|
| 10242 |
-
const uint8_t * grid2 = (const uint8_t *)(
|
| 10243 |
for (int j = 0; j < 4; ++j) {
|
| 10244 |
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
|
| 10245 |
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
|
|
@@ -10251,8 +10294,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10251 |
bsum += sumi * ls1;
|
| 10252 |
sumi = 0;
|
| 10253 |
for (int l = 0; l < 4; ++l) {
|
| 10254 |
-
const uint8_t * grid1 = (const uint8_t *)(
|
| 10255 |
-
const uint8_t * grid2 = (const uint8_t *)(
|
| 10256 |
for (int j = 0; j < 4; ++j) {
|
| 10257 |
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
|
| 10258 |
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
|
|
@@ -10265,7 +10308,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|
| 10265 |
}
|
| 10266 |
sumf += d * bsum;
|
| 10267 |
}
|
| 10268 |
-
*s =
|
| 10269 |
#endif
|
| 10270 |
}
|
| 10271 |
|
|
@@ -11912,7 +11955,8 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
|
|
| 11912 |
}
|
| 11913 |
float best = 0;
|
| 11914 |
float scale = max/(2*kMaxQ-1);
|
| 11915 |
-
for (int
|
|
|
|
| 11916 |
float id = (2*kMaxQ-1+is*0.2f)/max;
|
| 11917 |
float this_scale = 1/id;
|
| 11918 |
for (int k = 0; k < bs4; ++k) {
|
|
@@ -11948,7 +11992,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
|
|
| 11948 |
if (n_not_ongrid > 0 && scale > 0) {
|
| 11949 |
float id = 1/scale;
|
| 11950 |
for (int k = 0; k < bs4; ++k) {
|
| 11951 |
-
if (is_on_grid[k]) continue;
|
| 11952 |
uint16_t u = 0;
|
| 11953 |
for (int i = 0; i < 4; ++i) {
|
| 11954 |
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
|
@@ -12004,7 +12048,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
|
|
| 12004 |
}
|
| 12005 |
|
| 12006 |
float d = max_scale/31;
|
| 12007 |
-
y[ibl].d = GGML_FP32_TO_FP16(d);
|
| 12008 |
float id = 1/d;
|
| 12009 |
for (int ib = 0; ib < QK_K/block_size; ib += 2) {
|
| 12010 |
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
|
|
|
|
| 3818 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3819 |
};
|
| 3820 |
|
| 3821 |
+
static const uint32_t iq3s_grid[512] = {
|
| 3822 |
+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
| 3823 |
+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
| 3824 |
+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
| 3825 |
+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
| 3826 |
+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
| 3827 |
+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
| 3828 |
+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
| 3829 |
+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
| 3830 |
+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
| 3831 |
+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
| 3832 |
+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
| 3833 |
+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
| 3834 |
+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
| 3835 |
+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
| 3836 |
+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
| 3837 |
+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
| 3838 |
+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
| 3839 |
+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
| 3840 |
+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
| 3841 |
+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
| 3842 |
+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
| 3843 |
+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
| 3844 |
+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
| 3845 |
+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
| 3846 |
+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
| 3847 |
+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
| 3848 |
+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
| 3849 |
+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
| 3850 |
+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
| 3851 |
+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
| 3852 |
+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
| 3853 |
+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
| 3854 |
+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
| 3855 |
+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
| 3856 |
+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
| 3857 |
+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
| 3858 |
+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
| 3859 |
+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
| 3860 |
+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
| 3861 |
+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
| 3862 |
+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
| 3863 |
+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
| 3864 |
+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
| 3865 |
+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
| 3866 |
+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
| 3867 |
+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
| 3868 |
+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
| 3869 |
+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
| 3870 |
+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
| 3871 |
+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
| 3872 |
+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
| 3873 |
+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
| 3874 |
+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
| 3875 |
+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
| 3876 |
+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
| 3877 |
+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
| 3878 |
+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
| 3879 |
+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
| 3880 |
+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
| 3881 |
+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
| 3882 |
+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
| 3883 |
+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
| 3884 |
+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
| 3885 |
+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
| 3886 |
};
|
| 3887 |
|
| 3888 |
#define NGRID_IQ2XXS 512
|
|
|
|
| 4162 |
const uint8_t * signs = x[i].signs;
|
| 4163 |
|
| 4164 |
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 4165 |
+
const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
|
| 4166 |
+
const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
|
| 4167 |
for (int l = 0; l < 4; ++l) {
|
| 4168 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
| 4169 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
| 4170 |
for (int j = 0; j < 4; ++j) {
|
| 4171 |
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 4172 |
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
|
|
|
| 4176 |
qs += 8;
|
| 4177 |
signs += 4;
|
| 4178 |
for (int l = 0; l < 4; ++l) {
|
| 4179 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
|
| 4180 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
|
| 4181 |
for (int j = 0; j < 4; ++j) {
|
| 4182 |
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 4183 |
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
|
|
|
| 10089 |
|
| 10090 |
#if defined(__ARM_NEON)
|
| 10091 |
|
| 10092 |
+
typedef union {
|
| 10093 |
+
uint16x8_t vec_index;
|
| 10094 |
+
uint16_t index[8];
|
| 10095 |
+
} vec_index_t;
|
| 10096 |
+
|
| 10097 |
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
| 10098 |
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
| 10099 |
};
|
| 10100 |
|
| 10101 |
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
|
| 10102 |
|
| 10103 |
+
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
|
| 10104 |
+
|
| 10105 |
+
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
|
| 10106 |
+
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
| 10107 |
+
const int16x8_t hshift = vld1q_s16(k_shift);
|
| 10108 |
+
const uint16x8_t m256 = vdupq_n_u16(256);
|
| 10109 |
+
const uint8x16_t m1 = vdupq_n_u8(1);
|
| 10110 |
|
| 10111 |
uint8x16x2_t vs;
|
| 10112 |
ggml_int8x16x4_t q3s;
|
| 10113 |
ggml_int8x16x4_t q8b;
|
| 10114 |
+
vec_index_t idx;
|
| 10115 |
+
|
| 10116 |
+
#if QK_K == 256
|
| 10117 |
+
uint32_t scales32[2];
|
| 10118 |
+
const uint8_t * scales8 = (const uint8_t *)scales32;
|
| 10119 |
+
#endif
|
| 10120 |
|
| 10121 |
float sumf = 0;
|
| 10122 |
for (int i = 0; i < nb; ++i) {
|
|
|
|
| 10125 |
const uint8_t * restrict qh = x[i].qh;
|
| 10126 |
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
|
| 10127 |
const int8_t * restrict q8 = y[i].qs;
|
| 10128 |
+
|
| 10129 |
+
#if QK_K == 256
|
| 10130 |
+
memcpy(scales32, x[i].scales, 4);
|
| 10131 |
+
scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
|
| 10132 |
+
scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
|
| 10133 |
+
#endif
|
| 10134 |
+
|
| 10135 |
int sumi1 = 0, sumi2 = 0;
|
| 10136 |
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 10137 |
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
| 10138 |
+
|
| 10139 |
+
const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
|
| 10140 |
+
idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
|
| 10141 |
+
const uint32x4_t aux32x4_0 = {iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
|
| 10142 |
+
iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]};
|
| 10143 |
+
const uint32x4_t aux32x4_1 = {iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
|
| 10144 |
+
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]};
|
| 10145 |
+
idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
|
| 10146 |
+
const uint32x4_t aux32x4_2 = {iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
|
| 10147 |
+
iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]};
|
| 10148 |
+
const uint32x4_t aux32x4_3 = {iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
|
| 10149 |
+
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]};
|
| 10150 |
+
|
| 10151 |
|
| 10152 |
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
| 10153 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 10154 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 10155 |
+
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
| 10156 |
+
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
|
| 10157 |
|
| 10158 |
+
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
|
| 10159 |
+
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
|
| 10160 |
|
| 10161 |
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
| 10162 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 10163 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 10164 |
+
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
| 10165 |
+
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
|
| 10166 |
|
| 10167 |
signs += 4;
|
| 10168 |
|
| 10169 |
+
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
|
| 10170 |
+
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
|
| 10171 |
|
| 10172 |
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
| 10173 |
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
| 10174 |
+
#if QK_K == 256
|
| 10175 |
+
sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
|
| 10176 |
+
sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
|
| 10177 |
+
#else
|
| 10178 |
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
|
| 10179 |
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
|
| 10180 |
+
#endif
|
| 10181 |
}
|
| 10182 |
sumf += d*(sumi1 + sumi2);
|
| 10183 |
}
|
| 10184 |
+
*s = sumf;
|
| 10185 |
|
| 10186 |
#elif defined(__AVX2__)
|
| 10187 |
|
|
|
|
| 10196 |
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
|
| 10197 |
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
|
| 10198 |
|
| 10199 |
+
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
|
| 10200 |
+
const __m256i idx_mask = _mm256_set1_epi32(256);
|
| 10201 |
+
|
| 10202 |
+
typedef union {
|
| 10203 |
+
__m256i vec[2];
|
| 10204 |
+
uint32_t index[16];
|
| 10205 |
+
} index_t;
|
| 10206 |
+
|
| 10207 |
+
index_t idx;
|
| 10208 |
+
|
| 10209 |
__m256 accumf = _mm256_setzero_ps();
|
| 10210 |
for (int i = 0; i < nb; ++i) {
|
| 10211 |
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
|
|
| 10218 |
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 10219 |
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 10220 |
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 10221 |
+
const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
|
| 10222 |
+
idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
|
| 10223 |
+
idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
|
| 10224 |
+
idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
|
| 10225 |
+
idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
|
| 10226 |
+
idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
|
| 10227 |
+
idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
|
| 10228 |
+
|
| 10229 |
+
// At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
|
| 10230 |
+
//const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
|
| 10231 |
+
//const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
|
| 10232 |
+
const __m256i q2_1 = _mm256_set_epi32(
|
| 10233 |
+
iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
|
| 10234 |
+
// _mm256_set_epi32 lists its arguments from the highest lane down, so this row fills lanes 3..0 from idx.index[3..0].
iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
|
| 10235 |
+
);
|
| 10236 |
+
const __m256i q2_2 = _mm256_set_epi32(
|
| 10237 |
+
iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
|
| 10238 |
+
iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
|
| 10239 |
+
);
|
| 10240 |
|
| 10241 |
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
| 10242 |
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
|
|
|
| 10264 |
|
| 10265 |
}
|
| 10266 |
|
| 10267 |
+
*s = hsum_float_8(accumf);
|
| 10268 |
|
| 10269 |
#else
|
| 10270 |
|
|
|
|
| 10281 |
const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
|
| 10282 |
int32_t sumi = 0;
|
| 10283 |
for (int l = 0; l < 4; ++l) {
|
| 10284 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
|
| 10285 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
|
| 10286 |
for (int j = 0; j < 4; ++j) {
|
| 10287 |
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
|
| 10288 |
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
|
|
|
|
| 10294 |
bsum += sumi * ls1;
|
| 10295 |
sumi = 0;
|
| 10296 |
for (int l = 0; l < 4; ++l) {
|
| 10297 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
|
| 10298 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
|
| 10299 |
for (int j = 0; j < 4; ++j) {
|
| 10300 |
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
|
| 10301 |
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
|
|
|
|
| 10308 |
}
|
| 10309 |
sumf += d * bsum;
|
| 10310 |
}
|
| 10311 |
+
*s = sumf;
|
| 10312 |
#endif
|
| 10313 |
}
|
| 10314 |
|
|
|
|
| 11955 |
}
|
| 11956 |
float best = 0;
|
| 11957 |
float scale = max/(2*kMaxQ-1);
|
| 11958 |
+
for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;
|
| 11959 |
+
for (int is = -9; is <= 9; ++is) {
|
| 11960 |
float id = (2*kMaxQ-1+is*0.2f)/max;
|
| 11961 |
float this_scale = 1/id;
|
| 11962 |
for (int k = 0; k < bs4; ++k) {
|
|
|
|
| 11992 |
if (n_not_ongrid > 0 && scale > 0) {
|
| 11993 |
float id = 1/scale;
|
| 11994 |
for (int k = 0; k < bs4; ++k) {
|
| 11995 |
+
//if (is_on_grid[k]) continue;
|
| 11996 |
uint16_t u = 0;
|
| 11997 |
for (int i = 0; i < 4; ++i) {
|
| 11998 |
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
|
|
|
| 12048 |
}
|
| 12049 |
|
| 12050 |
float d = max_scale/31;
|
| 12051 |
+
y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f);
|
| 12052 |
float id = 1/d;
|
| 12053 |
for (int ib = 0; ib < QK_K/block_size; ib += 2) {
|
| 12054 |
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
|