ggerganov commited on
Commit
f2da2a4
·
unverified ·
1 Parent(s): efed5ba

metal : improve dequantize precision to match CPU (llama/4836)

Browse files
Files changed (1) hide show
  1. ggml-metal.metal +8 -8
ggml-metal.metal CHANGED
@@ -3841,8 +3841,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
3841
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3842
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3843
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3844
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3845
- const half ml = 4.h * dl;
3846
 
3847
  il = (il/2) & 3;
3848
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3909,7 +3909,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3909
  uint8_t ul = 1 << (il/2);
3910
  il = il & 3;
3911
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3912
- const float d = il < 2 ? xb->d : xb->d / 16.h;
3913
  const float min = xb->dmin;
3914
  const float dl = d * sc[0];
3915
  const float ml = min * sc[1];
@@ -3942,17 +3942,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3942
  #if QK_K == 256
3943
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3944
  qh = qh + 32*(il/8) + 16*(il&1);
3945
- half sc = scales[(il%2) + 2 * ((il/2))];
3946
  il = (il/2) & 3;
3947
  #else
3948
  ql = ql + 16 * (il&1);
3949
- half sc = scales[il];
3950
  #endif
3951
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3952
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3953
- const half coef = il>1 ? 1.f/16.h : 1.h;
3954
- const half ml = d_all * sc * 32.h;
3955
- const half dl = d_all * sc * coef;
3956
  for (int i = 0; i < 16; ++i) {
3957
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3958
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
 
3841
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3842
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3843
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3844
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
3845
+ const float ml = 4.f * dl;
3846
 
3847
  il = (il/2) & 3;
3848
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
 
3909
  uint8_t ul = 1 << (il/2);
3910
  il = il & 3;
3911
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3912
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
3913
  const float min = xb->dmin;
3914
  const float dl = d * sc[0];
3915
  const float ml = min * sc[1];
 
3942
  #if QK_K == 256
3943
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3944
  qh = qh + 32*(il/8) + 16*(il&1);
3945
+ float sc = scales[(il%2) + 2 * ((il/2))];
3946
  il = (il/2) & 3;
3947
  #else
3948
  ql = ql + 16 * (il&1);
3949
+ float sc = scales[il];
3950
  #endif
3951
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3952
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3953
+ const float coef = il>1 ? 1.f/16.f : 1.f;
3954
+ const float ml = d_all * sc * 32.f;
3955
+ const float dl = d_all * sc * coef;
3956
  for (int i = 0; i < 16; ++i) {
3957
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3958
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));