Spaces:
Running
Running
metal : improve dequantize precision to match CPU (llama/4836)
Browse files- ggml-metal.metal +8 -8
ggml-metal.metal
CHANGED
|
@@ -3841,8 +3841,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
| 3841 |
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
| 3842 |
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
| 3843 |
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
| 3844 |
-
|
| 3845 |
-
const
|
| 3846 |
|
| 3847 |
il = (il/2) & 3;
|
| 3848 |
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
@@ -3909,7 +3909,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
| 3909 |
uint8_t ul = 1 << (il/2);
|
| 3910 |
il = il & 3;
|
| 3911 |
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
| 3912 |
-
const float d = il < 2 ? xb->d : xb->d / 16.
|
| 3913 |
const float min = xb->dmin;
|
| 3914 |
const float dl = d * sc[0];
|
| 3915 |
const float ml = min * sc[1];
|
|
@@ -3942,17 +3942,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
| 3942 |
#if QK_K == 256
|
| 3943 |
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
| 3944 |
qh = qh + 32*(il/8) + 16*(il&1);
|
| 3945 |
-
|
| 3946 |
il = (il/2) & 3;
|
| 3947 |
#else
|
| 3948 |
ql = ql + 16 * (il&1);
|
| 3949 |
-
|
| 3950 |
#endif
|
| 3951 |
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
| 3952 |
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
| 3953 |
-
const
|
| 3954 |
-
const
|
| 3955 |
-
const
|
| 3956 |
for (int i = 0; i < 16; ++i) {
|
| 3957 |
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
| 3958 |
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
|
|
|
| 3841 |
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
| 3842 |
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
| 3843 |
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
| 3844 |
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
| 3845 |
+
const float ml = 4.f * dl;
|
| 3846 |
|
| 3847 |
il = (il/2) & 3;
|
| 3848 |
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
|
|
| 3909 |
uint8_t ul = 1 << (il/2);
|
| 3910 |
il = il & 3;
|
| 3911 |
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
| 3912 |
+
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
| 3913 |
const float min = xb->dmin;
|
| 3914 |
const float dl = d * sc[0];
|
| 3915 |
const float ml = min * sc[1];
|
|
|
|
| 3942 |
#if QK_K == 256
|
| 3943 |
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
| 3944 |
qh = qh + 32*(il/8) + 16*(il&1);
|
| 3945 |
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
| 3946 |
il = (il/2) & 3;
|
| 3947 |
#else
|
| 3948 |
ql = ql + 16 * (il&1);
|
| 3949 |
+
float sc = scales[il];
|
| 3950 |
#endif
|
| 3951 |
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
| 3952 |
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
| 3953 |
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
| 3954 |
+
const float ml = d_all * sc * 32.f;
|
| 3955 |
+
const float dl = d_all * sc * coef;
|
| 3956 |
for (int i = 0; i < 16; ++i) {
|
| 3957 |
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
| 3958 |
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|