ggerganov committed
Commit ebacb3e · 1 Parent(s): deb934d

ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (llama/14435)

ggml/CMakeLists.txt CHANGED
@@ -360,13 +360,6 @@ write_basic_package_version_file(
360
  VERSION ${GGML_INSTALL_VERSION}
361
  COMPATIBILITY SameMajorVersion)
362
 
363
- target_compile_definitions(ggml-base PRIVATE
364
- GGML_VERSION="${GGML_INSTALL_VERSION}"
365
- GGML_COMMIT="${GGML_BUILD_COMMIT}"
366
- )
367
- message(STATUS "ggml version: ${GGML_INSTALL_VERSION}")
368
- message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}")
369
-
370
  install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
371
  ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
372
  DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
 
360
  VERSION ${GGML_INSTALL_VERSION}
361
  COMPATIBILITY SameMajorVersion)
362
 
 
 
 
 
 
 
 
363
  install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
364
  ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
365
  DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
ggml/include/ggml.h CHANGED
@@ -646,9 +646,6 @@ extern "C" {
646
 
647
  // misc
648
 
649
- GGML_API const char * ggml_version(void);
650
- GGML_API const char * ggml_commit(void);
651
-
652
  GGML_API void ggml_time_init(void); // call this once at the beginning of the program
653
  GGML_API int64_t ggml_time_ms(void);
654
  GGML_API int64_t ggml_time_us(void);
@@ -1513,8 +1510,14 @@ extern "C" {
1513
  struct ggml_context * ctx,
1514
  struct ggml_tensor * a);
1515
 
 
 
 
 
 
 
 
1516
  // fused soft_max(a*scale + mask*(ALiBi slope))
1517
- // mask is optional
1518
  // max_bias = 0.0f for no ALiBi
1519
  GGML_API struct ggml_tensor * ggml_soft_max_ext(
1520
  struct ggml_context * ctx,
@@ -1977,11 +1980,16 @@ extern "C" {
1977
 
1978
  #define GGML_KQ_MASK_PAD 64
1979
 
1980
- // q: [n_embd_k, n_batch, n_head, 1]
1981
- // k: [n_embd_k, n_kv, n_head_kv, 1]
1982
- // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1983
- // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1984
- // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
 
 
 
 
 
1985
  GGML_API struct ggml_tensor * ggml_flash_attn_ext(
1986
  struct ggml_context * ctx,
1987
  struct ggml_tensor * q,
 
646
 
647
  // misc
648
 
 
 
 
649
  GGML_API void ggml_time_init(void); // call this once at the beginning of the program
650
  GGML_API int64_t ggml_time_ms(void);
651
  GGML_API int64_t ggml_time_us(void);
 
1510
  struct ggml_context * ctx,
1511
  struct ggml_tensor * a);
1512
 
1513
+ // a [ne0, ne01, ne02, ne03]
1514
+ // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
1515
+ //
1516
+ // broadcast:
1517
+ // ne02 % ne12 == 0
1518
+ // ne03 % ne13 == 0
1519
+ //
1520
  // fused soft_max(a*scale + mask*(ALiBi slope))
 
1521
  // max_bias = 0.0f for no ALiBi
1522
  GGML_API struct ggml_tensor * ggml_soft_max_ext(
1523
  struct ggml_context * ctx,
 
1980
 
1981
  #define GGML_KQ_MASK_PAD 64
1982
 
1983
+ // q: [n_embd_k, n_batch, n_head, ne3]
1984
+ // k: [n_embd_k, n_kv, n_head_kv, ne3]
1985
+ // v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
1986
+ // mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1987
+ // res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
1988
+ //
1989
+ // broadcast:
1990
+ // n_head % n_head_kv == 0
1991
+ // ne3 % ne32 == 0
1992
+ //
1993
  GGML_API struct ggml_tensor * ggml_flash_attn_ext(
1994
  struct ggml_context * ctx,
1995
  struct ggml_tensor * q,
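Note on the new header comments above: ggml_soft_max_ext now documents a mask that may be smaller than a along dims 2 and 3 and is repeated (broadcast) when ne02 % ne12 == 0 and ne03 % ne13 == 0, and ggml_flash_attn_ext gains a 4th dimension ne3 with its mask broadcast along ne32. A minimal sketch of calling the updated API with a broadcast mask follows; the shapes and the 16 MiB context size are illustrative only, not taken from the commit.

    // Sketch: soft_max with a broadcast mask via the public ggml C API.
    // Illustrative shapes: a has 8 head slices, the mask provides 4 -> 8 % 4 == 0.
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // a:    [ne00=128, ne01=32, ne02=8, ne03=2]
        // mask: [ne00=128, ne11=32, ne12=4, ne13=1]  (ne11 >= ne01, F16 or F32)
        struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 8, 2);
        struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 4, 1);

        const float scale    = 0.125f; // e.g. 1/sqrt(n_embd_head)
        const float max_bias = 0.0f;   // 0.0f -> no ALiBi

        struct ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        // ... evaluate gf with a backend that supports the broadcast (e.g. CPU) ...

        ggml_free(ctx);
        return 0;
    }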
ggml/src/ggml-cann/ggml-cann.cpp CHANGED
@@ -2187,7 +2187,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2187
  case GGML_OP_SQRT:
2188
  case GGML_OP_CLAMP:
2189
  case GGML_OP_DIAG_MASK_INF:
2190
- case GGML_OP_SOFT_MAX:
2191
  case GGML_OP_SUM_ROWS:
2192
  case GGML_OP_ARGSORT:
2193
  case GGML_OP_ACC:
@@ -2205,6 +2204,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2205
  case GGML_OP_PAD_REFLECT_1D:
2206
  case GGML_OP_COUNT_EQUAL:
2207
  return true;
 
 
 
 
2208
  case GGML_OP_FLASH_ATTN_EXT:{
2209
  // derived from [ggml-cuda.cu]
2210
  if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2227,6 +2230,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2227
  // DeepSeek MLA
2228
  return false;
2229
  }
 
 
2230
  if (op->src[0]->ne[3] != 1) {
2231
  return false;
2232
  }
 
2187
  case GGML_OP_SQRT:
2188
  case GGML_OP_CLAMP:
2189
  case GGML_OP_DIAG_MASK_INF:
 
2190
  case GGML_OP_SUM_ROWS:
2191
  case GGML_OP_ARGSORT:
2192
  case GGML_OP_ACC:
 
2204
  case GGML_OP_PAD_REFLECT_1D:
2205
  case GGML_OP_COUNT_EQUAL:
2206
  return true;
2207
+ case GGML_OP_SOFT_MAX:
2208
+ // TODO: support broadcast
2209
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
2210
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
2211
  case GGML_OP_FLASH_ATTN_EXT:{
2212
  // derived from [ggml-cuda.cu]
2213
  if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
 
2230
  // DeepSeek MLA
2231
  return false;
2232
  }
2233
+ // TODO: support broadcast
2234
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
2235
  if (op->src[0]->ne[3] != 1) {
2236
  return false;
2237
  }
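Note: this hunk takes GGML_OP_SOFT_MAX out of the unconditional "return true" group because the CANN kernel has not been taught the new mask broadcast yet; nodes the backend rejects are scheduled onto another backend. The added case is equivalent to the following standalone predicate (hypothetical helper name, shown only to spell out the condition):

    // Accept soft_max only when the optional mask does not need broadcasting
    // across dims 2 and 3 (same condition as the new GGML_OP_SOFT_MAX case).
    static bool soft_max_mask_not_broadcast(const struct ggml_tensor * op) {
        const struct ggml_tensor * mask = op->src[1]; // optional
        return mask == NULL || (mask->ne[2] == 1 && mask->ne[3] == 1);
    }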
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32(
5232
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
5233
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5234
 
5235
- // TODO: handle transposed/permuted matrices
5236
-
5237
  const int ith = params->ith;
5238
  const int nth = params->nth;
5239
 
5240
  GGML_TENSOR_UNARY_OP_LOCALS
5241
 
5242
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
 
 
 
 
 
5243
 
5244
  // TODO: is this supposed to be ceil instead of floor?
5245
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32(
5249
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5250
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5251
 
5252
- const int nc = src0->ne[0];
5253
- const int nr = ggml_nrows(src0);
5254
-
5255
- // rows per thread
5256
- const int dr = (nr + nth - 1)/nth;
5257
-
5258
- // row range for this thread
5259
- const int ir0 = dr*ith;
5260
- const int ir1 = MIN(ir0 + dr, nr);
5261
-
5262
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
5263
 
5264
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5265
 
5266
- for (int i1 = ir0; i1 < ir1; i1++) {
5267
- // ALiBi
5268
- const uint32_t h = (i1/ne01)%ne02; // head
5269
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5270
-
5271
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
5272
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
5273
-
5274
- // broadcast the mask across rows
5275
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
5276
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
5277
-
5278
- ggml_vec_cpy_f32 (nc, wp, sp);
5279
- ggml_vec_scale_f32(nc, wp, scale);
5280
- if (mp_f32) {
5281
- if (use_f16) {
5282
- for (int i = 0; i < nc; ++i) {
5283
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5284
- }
5285
- } else {
5286
- for (int i = 0; i < nc; ++i) {
5287
- wp[i] += slope*mp_f32[i];
 
 
 
 
 
 
 
 
5288
  }
5289
- }
5290
- }
5291
 
5292
  #ifndef NDEBUG
5293
- for (int i = 0; i < nc; ++i) {
5294
- //printf("p[%d] = %f\n", i, p[i]);
5295
- assert(!isnan(wp[i]));
5296
- }
5297
  #endif
5298
 
5299
- float max = -INFINITY;
5300
- ggml_vec_max_f32(nc, &max, wp);
5301
 
5302
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
5303
- assert(sum > 0.0);
5304
 
5305
- sum = 1.0/sum;
5306
- ggml_vec_scale_f32(nc, dp, sum);
5307
 
5308
  #ifndef NDEBUG
5309
- for (int i = 0; i < nc; ++i) {
5310
- assert(!isnan(dp[i]));
5311
- assert(!isinf(dp[i]));
5312
- }
5313
  #endif
 
 
5314
  }
5315
  }
5316
 
@@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7766
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
7767
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
7768
 
7769
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
7770
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
7771
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
7772
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7798
  memset(VKQ32, 0, DV*sizeof(float));
7799
  }
7800
 
7801
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
7802
 
7803
  // k indices
7804
  const int ik3 = iq3 / rk3;
 
5232
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
5233
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5234
 
 
 
5235
  const int ith = params->ith;
5236
  const int nth = params->nth;
5237
 
5238
  GGML_TENSOR_UNARY_OP_LOCALS
5239
 
5240
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
5241
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
5242
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
5243
+
5244
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
5245
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
5246
 
5247
  // TODO: is this supposed to be ceil instead of floor?
5248
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
 
5252
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5253
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5254
 
5255
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
 
 
 
 
 
 
 
 
 
5256
 
5257
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5258
 
5259
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5260
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5261
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5262
+ const int64_t i11 = i01;
5263
+ const int64_t i12 = i02%ne12;
5264
+ const int64_t i13 = i03%ne13;
5265
+
5266
+ // ALiBi
5267
+ const uint32_t h = i02; // head
5268
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5269
+
5270
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5271
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5272
+
5273
+ // broadcast the mask across rows
5274
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5275
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5276
+
5277
+ ggml_vec_cpy_f32 (ne00, wp, sp);
5278
+ ggml_vec_scale_f32(ne00, wp, scale);
5279
+ if (mp_f32) {
5280
+ if (use_f16) {
5281
+ for (int i = 0; i < ne00; ++i) {
5282
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5283
+ }
5284
+ } else {
5285
+ for (int i = 0; i < ne00; ++i) {
5286
+ wp[i] += slope*mp_f32[i];
5287
+ }
5288
+ }
5289
  }
 
 
5290
 
5291
  #ifndef NDEBUG
5292
+ for (int i = 0; i < ne00; ++i) {
5293
+ //printf("p[%d] = %f\n", i, p[i]);
5294
+ assert(!isnan(wp[i]));
5295
+ }
5296
  #endif
5297
 
5298
+ float max = -INFINITY;
5299
+ ggml_vec_max_f32(ne00, &max, wp);
5300
 
5301
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5302
+ assert(sum > 0.0);
5303
 
5304
+ sum = 1.0/sum;
5305
+ ggml_vec_scale_f32(ne00, dp, sum);
5306
 
5307
  #ifndef NDEBUG
5308
+ for (int i = 0; i < ne00; ++i) {
5309
+ assert(!isnan(dp[i]));
5310
+ assert(!isinf(dp[i]));
5311
+ }
5312
  #endif
5313
+ }
5314
+ }
5315
  }
5316
  }
5317
 
 
7767
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
7768
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
7769
 
7770
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
7771
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
7772
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
7773
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
 
7799
  memset(VKQ32, 0, DV*sizeof(float));
7800
  }
7801
 
7802
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
7803
 
7804
  // k indices
7805
  const int ik3 = iq3 / rk3;
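The heart of the CPU change is the new row addressing: ggml_compute_forward_soft_max_f32 now iterates (i03, i02, i01) explicitly and picks the mask row with i12 = i02 % ne12 and i13 = i03 % ne13, so a mask with fewer slices along dims 2/3 is reused cyclically, and the flash-attention path applies the same idea with iq3 % mask->ne[2]. A small standalone sketch of that addressing (hypothetical helper, byte strides as in ggml):

    // Map a src0 row (i01, i02, i03) to its (possibly broadcast) mask row.
    // nb11/nb12/nb13 are the mask's byte strides, ne12/ne13 its sizes on dims 2/3.
    #include <stdint.h>

    static const float * soft_max_mask_row(const char * mask_data,
                                           int64_t i01, int64_t i02, int64_t i03,
                                           int64_t ne12, int64_t ne13,
                                           size_t nb11, size_t nb12, size_t nb13) {
        const int64_t i11 = i01;        // one mask row per src0 row
        const int64_t i12 = i02 % ne12; // broadcast across dim 2 (heads)
        const int64_t i13 = i03 % ne13; // broadcast across dim 3 (sequences)
        return (const float *)(mask_data + i11*nb11 + i12*nb12 + i13*nb13);
    }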
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3327,8 +3327,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3327
  case GGML_OP_CONT:
3328
  return op->src[0]->type != GGML_TYPE_BF16;
3329
  case GGML_OP_DIAG_MASK_INF:
3330
- case GGML_OP_SOFT_MAX:
3331
  return true;
 
 
 
 
 
 
 
 
3332
  case GGML_OP_SOFT_MAX_BACK: {
3333
  float max_bias = 0.0f;
3334
  memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
@@ -3375,6 +3382,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3375
  if (op->src[0]->ne[0] == 192) {
3376
  return false;
3377
  }
 
 
3378
  if (op->src[0]->ne[3] != 1) {
3379
  return false;
3380
  }
 
3327
  case GGML_OP_CONT:
3328
  return op->src[0]->type != GGML_TYPE_BF16;
3329
  case GGML_OP_DIAG_MASK_INF:
 
3330
  return true;
3331
+ case GGML_OP_SOFT_MAX:
3332
+ // TODO: support batching
3333
+ if (op->src[0]->ne[3] != 1) {
3334
+ return false;
3335
+ }
3336
+ // TODO: support broadcast
3337
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
3338
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
3339
  case GGML_OP_SOFT_MAX_BACK: {
3340
  float max_bias = 0.0f;
3341
  memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
 
3382
  if (op->src[0]->ne[0] == 192) {
3383
  return false;
3384
  }
3385
+ // TODO: support broadcast
3386
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
3387
  if (op->src[0]->ne[3] != 1) {
3388
  return false;
3389
  }
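Note: CUDA takes the same conservative route as CANN. GGML_OP_SOFT_MAX is no longer unconditionally supported; the op is rejected when src0 has a 4th dimension or when the mask would have to be broadcast, both flagged as TODOs pointing at the PR. When supports_op returns false, the graph scheduler places the node on a backend that does accept it (typically the CPU), so such graphs still run, only slower. Spelled out as a predicate with a few illustrative shape cases (not from the commit):

    // Combined condition added for GGML_OP_SOFT_MAX in the CUDA backend.
    static bool cuda_soft_max_supported(const struct ggml_tensor * op) {
        if (op->src[0]->ne[3] != 1) {
            return false; // TODO: support batching
        }
        const struct ggml_tensor * mask = op->src[1];
        return mask == NULL || (mask->ne[2] == 1 && mask->ne[3] == 1); // TODO: support broadcast
    }
    // src0 [128,32,8,1], mask [128,32,1,1] -> supported
    // src0 [128,32,8,1], mask [128,32,4,1] -> rejected (mask broadcast)
    // src0 [128,32,8,2], any mask          -> rejected (ne[3] != 1)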
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -229,7 +229,9 @@ typedef struct {
229
  uint64_t nb21;
230
  uint64_t nb22;
231
  uint64_t nb23;
 
232
  uint64_t nb31;
 
233
  int32_t ne1;
234
  int32_t ne2;
235
  float scale;
@@ -461,9 +463,21 @@ typedef struct {
461
  } ggml_metal_kargs_sum_rows;
462
 
463
  typedef struct {
464
- int64_t ne00;
465
- int64_t ne01;
466
- int64_t ne02;
 
 
 
 
 
 
 
 
 
 
 
 
467
  float scale;
468
  float max_bias;
469
  float m0;
 
229
  uint64_t nb21;
230
  uint64_t nb22;
231
  uint64_t nb23;
232
+ int32_t ne32;
233
  uint64_t nb31;
234
+ uint64_t nb32;
235
  int32_t ne1;
236
  int32_t ne2;
237
  float scale;
 
463
  } ggml_metal_kargs_sum_rows;
464
 
465
  typedef struct {
466
+ int32_t ne00;
467
+ int32_t ne01;
468
+ int32_t ne02;
469
+ uint64_t nb01;
470
+ uint64_t nb02;
471
+ uint64_t nb03;
472
+ int32_t ne11;
473
+ int32_t ne12;
474
+ int32_t ne13;
475
+ uint64_t nb11;
476
+ uint64_t nb12;
477
+ uint64_t nb13;
478
+ uint64_t nb1;
479
+ uint64_t nb2;
480
+ uint64_t nb3;
481
  float scale;
482
  float max_bias;
483
  float m0;
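Note: the argument structs now carry the full set of element counts (ne*) and byte strides (nb*) for src0, the mask and dst, plus ne32/nb32 for the flash-attention mask, so the kernels can address views, permuted tensors and broadcast masks instead of assuming a densely packed layout. For reference, the ggml ne/nb convention these fields mirror, as a small sketch (illustrative code, not part of the commit):

    // ggml shape convention: ne[i] = element counts, nb[i] = byte strides.
    // For a contiguous F32 tensor nb[0] == sizeof(float) and nb[i] == nb[i-1]*ne[i-1];
    // kernels that take nb also work on views and permuted tensors.
    #include <stddef.h>
    #include <stdint.h>

    struct shape4 { int64_t ne[4]; size_t nb[4]; };

    static struct shape4 contiguous_f32(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
        struct shape4 s = {{ne0, ne1, ne2, ne3}, {sizeof(float), 0, 0, 0}};
        for (int i = 1; i < 4; ++i) {
            s.nb[i] = s.nb[i-1]*(size_t) s.ne[i-1];
        }
        return s;
    }

    // byte offset of element (i0, i1, i2, i3)
    static size_t elem_offs(const struct shape4 * s, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        return i0*s->nb[0] + i1*s->nb[1] + i2*s->nb[2] + i3*s->nb[3];
    }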
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -1725,7 +1725,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1725
  case GGML_OP_MEAN:
1726
  case GGML_OP_SOFT_MAX:
1727
  case GGML_OP_GROUP_NORM:
1728
- return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
1729
  case GGML_OP_RMS_NORM:
1730
  case GGML_OP_L2_NORM:
1731
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -2644,10 +2644,7 @@ static bool ggml_metal_encode_node(
2644
  memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
2645
  memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
2646
 
2647
- const int64_t nrows_x = ggml_nrows(src0);
2648
- const int64_t nrows_y = src0->ne[1];
2649
-
2650
- const uint32_t n_head = nrows_x/nrows_y;
2651
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2652
 
2653
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2707,6 +2704,18 @@ static bool ggml_metal_encode_node(
2707
  /*.ne00 =*/ ne00,
2708
  /*.ne01 =*/ ne01,
2709
  /*.ne02 =*/ ne02,
 
 
 
 
 
 
 
 
 
 
 
 
2710
  /*.scale =*/ scale,
2711
  /*.max_bias =*/ max_bias,
2712
  /*.m0 =*/ m0,
@@ -2726,7 +2735,7 @@ static bool ggml_metal_encode_node(
2726
 
2727
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2728
 
2729
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2730
  } break;
2731
  case GGML_OP_DIAG_MASK_INF:
2732
  {
@@ -4979,7 +4988,9 @@ static bool ggml_metal_encode_node(
4979
  /*.nb21 =*/ nb21,
4980
  /*.nb22 =*/ nb22,
4981
  /*.nb23 =*/ nb23,
 
4982
  /*.nb31 =*/ nb31,
 
4983
  /*.ne1 =*/ ne1,
4984
  /*.ne2 =*/ ne2,
4985
  /*.scale =*/ scale,
 
1725
  case GGML_OP_MEAN:
1726
  case GGML_OP_SOFT_MAX:
1727
  case GGML_OP_GROUP_NORM:
1728
+ return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
1729
  case GGML_OP_RMS_NORM:
1730
  case GGML_OP_L2_NORM:
1731
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
 
2644
  memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
2645
  memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
2646
 
2647
+ const uint32_t n_head = src0->ne[2];
 
 
 
2648
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2649
 
2650
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
 
2704
  /*.ne00 =*/ ne00,
2705
  /*.ne01 =*/ ne01,
2706
  /*.ne02 =*/ ne02,
2707
+ /*.nb01 =*/ nb01,
2708
+ /*.nb02 =*/ nb02,
2709
+ /*.nb03 =*/ nb03,
2710
+ /*.ne11 =*/ ne11,
2711
+ /*.ne12 =*/ ne12,
2712
+ /*.ne13 =*/ ne13,
2713
+ /*.nb11 =*/ nb11,
2714
+ /*.nb12 =*/ nb12,
2715
+ /*.nb13 =*/ nb13,
2716
+ /*.nb1 =*/ nb1,
2717
+ /*.nb2 =*/ nb2,
2718
+ /*.nb3 =*/ nb3,
2719
  /*.scale =*/ scale,
2720
  /*.max_bias =*/ max_bias,
2721
  /*.m0 =*/ m0,
 
2735
 
2736
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2737
 
2738
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2739
  } break;
2740
  case GGML_OP_DIAG_MASK_INF:
2741
  {
 
4988
  /*.nb21 =*/ nb21,
4989
  /*.nb22 =*/ nb22,
4990
  /*.nb23 =*/ nb23,
4991
+ /*.ne32 =*/ ne32,
4992
  /*.nb31 =*/ nb31,
4993
+ /*.nb32 =*/ nb32,
4994
  /*.ne1 =*/ ne1,
4995
  /*.ne2 =*/ ne2,
4996
  /*.scale =*/ scale,
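Note: besides passing the extra strides, the host side changes how the soft_max kernel is launched: previously a flat grid of ne01*ne02*ne03 threadgroups was dispatched and the kernel decoded (i01, i02, i03) from a single index; now the grid is the 3-D MTLSizeMake(ne01, ne02, ne03) and the kernel reads tgpig.x/y/z directly, which keeps the broadcast index math trivial. n_head is likewise taken from src0->ne[2] instead of being derived from row counts. The two indexings enumerate the same triples; a tiny sketch checking the equivalence (illustrative, not from the commit):

    // Old flat-index decomposition vs. new 3-D grid indexing for kernel_soft_max.
    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const int64_t ne01 = 3, ne02 = 4, ne03 = 2;
        for (int64_t tg = 0; tg < ne01*ne02*ne03; ++tg) {
            // old: decode a flat threadgroup index inside the kernel
            const int64_t i03 = tg/(ne02*ne01);
            const int64_t i02 = (tg - i03*ne02*ne01)/ne01;
            const int64_t i01 =  tg - i03*ne02*ne01 - i02*ne01;
            // new: a [ne01, ne02, ne03] grid provides the triple as tgpig.xyz
            const int64_t x = tg % ne01, y = (tg/ne01) % ne02, z = tg/(ne01*ne02);
            assert(i01 == x && i02 == y && i03 == z);
        }
        return 0;
    }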
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1320,24 +1320,28 @@ kernel void kernel_soft_max(
1320
  device char * dst,
1321
  constant ggml_metal_kargs_soft_max & args,
1322
  threadgroup float * buf [[threadgroup(0)]],
1323
- uint tgpig[[threadgroup_position_in_grid]],
1324
- uint tpitg[[thread_position_in_threadgroup]],
1325
  uint sgitg[[simdgroup_index_in_threadgroup]],
1326
  uint tiisg[[thread_index_in_simdgroup]],
1327
- uint ntg[[threads_per_threadgroup]]) {
1328
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1329
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1330
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
 
 
 
1331
 
1332
- device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1333
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
1334
- device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1335
 
1336
  float slope = 1.0f;
1337
 
1338
  // ALiBi
1339
  if (args.max_bias > 0.0f) {
1340
- const int64_t h = i02;
1341
 
1342
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1343
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1348,13 +1352,13 @@ kernel void kernel_soft_max(
1348
  // parallel max
1349
  float lmax = -INFINITY;
1350
 
1351
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1352
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1353
  }
1354
 
1355
  // find the max value in the block
1356
  float max_val = simd_max(lmax);
1357
- if (ntg > N_SIMDWIDTH) {
1358
  if (sgitg == 0) {
1359
  buf[tiisg] = -INFINITY;
1360
  }
@@ -1373,7 +1377,7 @@ kernel void kernel_soft_max(
1373
 
1374
  // parallel sum
1375
  float lsum = 0.0f;
1376
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1377
  const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1378
  lsum += exp_psrc0;
1379
  pdst[i00] = exp_psrc0;
@@ -1385,7 +1389,7 @@ kernel void kernel_soft_max(
1385
 
1386
  float sum = simd_sum(lsum);
1387
 
1388
- if (ntg > N_SIMDWIDTH) {
1389
  if (sgitg == 0) {
1390
  buf[tiisg] = 0.0f;
1391
  }
@@ -1404,7 +1408,7 @@ kernel void kernel_soft_max(
1404
 
1405
  const float inv_sum = 1.0f/sum;
1406
 
1407
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1408
  pdst[i00] *= inv_sum;
1409
  }
1410
  }
@@ -1416,23 +1420,27 @@ kernel void kernel_soft_max_4(
1416
  device char * dst,
1417
  constant ggml_metal_kargs_soft_max & args,
1418
  threadgroup float * buf [[threadgroup(0)]],
1419
- uint tgpig[[threadgroup_position_in_grid]],
1420
- uint tpitg[[thread_position_in_threadgroup]],
1421
  uint sgitg[[simdgroup_index_in_threadgroup]],
1422
  uint tiisg[[thread_index_in_simdgroup]],
1423
- uint ntg[[threads_per_threadgroup]]) {
1424
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1425
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1426
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
 
 
 
1427
 
1428
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1429
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1430
- device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1431
 
1432
  float slope = 1.0f;
1433
 
1434
  if (args.max_bias > 0.0f) {
1435
- const int64_t h = i02;
1436
 
1437
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1438
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1443,14 +1451,14 @@ kernel void kernel_soft_max_4(
1443
  // parallel max
1444
  float4 lmax4 = -INFINITY;
1445
 
1446
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1447
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1448
  }
1449
 
1450
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1451
 
1452
  float max_val = simd_max(lmax);
1453
- if (ntg > N_SIMDWIDTH) {
1454
  if (sgitg == 0) {
1455
  buf[tiisg] = -INFINITY;
1456
  }
@@ -1469,7 +1477,7 @@ kernel void kernel_soft_max_4(
1469
 
1470
  // parallel sum
1471
  float4 lsum4 = 0.0f;
1472
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1473
  const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1474
  lsum4 += exp_psrc4;
1475
  pdst4[i00] = exp_psrc4;
@@ -1483,7 +1491,7 @@ kernel void kernel_soft_max_4(
1483
 
1484
  float sum = simd_sum(lsum);
1485
 
1486
- if (ntg > N_SIMDWIDTH) {
1487
  if (sgitg == 0) {
1488
  buf[tiisg] = 0.0f;
1489
  }
@@ -1502,7 +1510,7 @@ kernel void kernel_soft_max_4(
1502
 
1503
  const float inv_sum = 1.0f/sum;
1504
 
1505
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1506
  pdst4[i00] *= inv_sum;
1507
  }
1508
  }
@@ -3776,7 +3784,7 @@ kernel void kernel_flash_attn_ext(
3776
  // load the mask in shared memory
3777
  #pragma unroll(Q)
3778
  for (short j = 0; j < Q; ++j) {
3779
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
3780
 
3781
  const float m = pm[ic + tiisg];
3782
 
@@ -4262,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec(
4262
  const bool has_mask = mask != q;
4263
 
4264
  // pointer to the mask
4265
- device const half * pm = (device const half *) (mask + iq1*args.nb31);
4266
 
4267
  float slope = 1.0f;
4268
 
 
1320
  device char * dst,
1321
  constant ggml_metal_kargs_soft_max & args,
1322
  threadgroup float * buf [[threadgroup(0)]],
1323
+ uint3 tgpig[[threadgroup_position_in_grid]],
1324
+ uint3 tpitg[[thread_position_in_threadgroup]],
1325
  uint sgitg[[simdgroup_index_in_threadgroup]],
1326
  uint tiisg[[thread_index_in_simdgroup]],
1327
+ uint3 tptg[[threads_per_threadgroup]]) {
1328
+ const int32_t i03 = tgpig.z;
1329
+ const int32_t i02 = tgpig.y;
1330
+ const int32_t i01 = tgpig.x;
1331
+
1332
+ const int32_t i13 = i03%args.ne13;
1333
+ const int32_t i12 = i02%args.ne12;
1334
+ const int32_t i11 = i01;
1335
 
1336
+ device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1337
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1338
+ device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1339
 
1340
  float slope = 1.0f;
1341
 
1342
  // ALiBi
1343
  if (args.max_bias > 0.0f) {
1344
+ const int32_t h = i02;
1345
 
1346
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1347
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
1352
  // parallel max
1353
  float lmax = -INFINITY;
1354
 
1355
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1356
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1357
  }
1358
 
1359
  // find the max value in the block
1360
  float max_val = simd_max(lmax);
1361
+ if (tptg.x > N_SIMDWIDTH) {
1362
  if (sgitg == 0) {
1363
  buf[tiisg] = -INFINITY;
1364
  }
 
1377
 
1378
  // parallel sum
1379
  float lsum = 0.0f;
1380
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1381
  const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1382
  lsum += exp_psrc0;
1383
  pdst[i00] = exp_psrc0;
 
1389
 
1390
  float sum = simd_sum(lsum);
1391
 
1392
+ if (tptg.x > N_SIMDWIDTH) {
1393
  if (sgitg == 0) {
1394
  buf[tiisg] = 0.0f;
1395
  }
 
1408
 
1409
  const float inv_sum = 1.0f/sum;
1410
 
1411
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1412
  pdst[i00] *= inv_sum;
1413
  }
1414
  }
 
1420
  device char * dst,
1421
  constant ggml_metal_kargs_soft_max & args,
1422
  threadgroup float * buf [[threadgroup(0)]],
1423
+ uint3 tgpig[[threadgroup_position_in_grid]],
1424
+ uint3 tpitg[[thread_position_in_threadgroup]],
1425
  uint sgitg[[simdgroup_index_in_threadgroup]],
1426
  uint tiisg[[thread_index_in_simdgroup]],
1427
+ uint3 tptg[[threads_per_threadgroup]]) {
1428
+ const int32_t i03 = tgpig.z;
1429
+ const int32_t i02 = tgpig.y;
1430
+ const int32_t i01 = tgpig.x;
1431
+
1432
+ const int32_t i13 = i03%args.ne13;
1433
+ const int32_t i12 = i02%args.ne12;
1434
+ const int32_t i11 = i01;
1435
 
1436
+ device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1437
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1438
+ device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1439
 
1440
  float slope = 1.0f;
1441
 
1442
  if (args.max_bias > 0.0f) {
1443
+ const int32_t h = i02;
1444
 
1445
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1446
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
1451
  // parallel max
1452
  float4 lmax4 = -INFINITY;
1453
 
1454
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1455
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1456
  }
1457
 
1458
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1459
 
1460
  float max_val = simd_max(lmax);
1461
+ if (tptg.x > N_SIMDWIDTH) {
1462
  if (sgitg == 0) {
1463
  buf[tiisg] = -INFINITY;
1464
  }
 
1477
 
1478
  // parallel sum
1479
  float4 lsum4 = 0.0f;
1480
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1481
  const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1482
  lsum4 += exp_psrc4;
1483
  pdst4[i00] = exp_psrc4;
 
1491
 
1492
  float sum = simd_sum(lsum);
1493
 
1494
+ if (tptg.x > N_SIMDWIDTH) {
1495
  if (sgitg == 0) {
1496
  buf[tiisg] = 0.0f;
1497
  }
 
1510
 
1511
  const float inv_sum = 1.0f/sum;
1512
 
1513
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1514
  pdst4[i00] *= inv_sum;
1515
  }
1516
  }
 
3784
  // load the mask in shared memory
3785
  #pragma unroll(Q)
3786
  for (short j = 0; j < Q; ++j) {
3787
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
3788
 
3789
  const float m = pm[ic + tiisg];
3790
 
 
4270
  const bool has_mask = mask != q;
4271
 
4272
  // pointer to the mask
4273
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
4274
 
4275
  float slope = 1.0f;
4276
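Note: the kernel bodies above are otherwise unchanged; each row still computes a fused soft_max(x*scale + slope*mask), where the per-head ALiBi slope uses the m0/m1 factors set up on the host. As a worked example of that slope formula (values illustrative, not from the commit):

    // Per-head ALiBi slope used by the soft_max kernels (h = head index i02).
    #include <math.h>
    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        const uint32_t n_head      = 8;
        const float    max_bias    = 8.0f; // > 0.0f enables ALiBi
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (uint32_t h = 0; h < n_head; ++h) {
            const float slope = h < n_head_log2 ? powf(m0, h + 1)
                                                : powf(m1, 2*(h - n_head_log2) + 1);
            printf("head %u: slope %f\n", h, slope);
            // the kernel then evaluates soft_max(x*scale + slope*mask) for the row
        }
        return 0;
    }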
 
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -4395,9 +4395,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4395
  return true;
4396
  case GGML_OP_CONT:
4397
  return op->src[0]->type != GGML_TYPE_BF16;
4398
- case GGML_OP_DIAG_MASK_INF:
4399
  case GGML_OP_SOFT_MAX:
4400
- return true;
 
 
 
 
 
 
 
4401
  case GGML_OP_ROPE:
4402
  case GGML_OP_IM2COL:
4403
  return true;
 
4395
  return true;
4396
  case GGML_OP_CONT:
4397
  return op->src[0]->type != GGML_TYPE_BF16;
 
4398
  case GGML_OP_SOFT_MAX:
4399
+ // TODO: support batching
4400
+ if (op->src[0]->ne[3] != 1) {
4401
+ return false;
4402
+ }
4403
+ // TODO: support broadcast
4404
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
4405
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
4406
+ case GGML_OP_DIAG_MASK_INF:
4407
  case GGML_OP_ROPE:
4408
  case GGML_OP_IM2COL:
4409
  return true;
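Note: SYCL applies the same gating as CUDA. If you need to know up front whether a particular device will run a soft_max node or leave it to the scheduler's fallback, the node can be described in a no_alloc context and passed to the backend device API; a hedged sketch follows (the shapes, context size and helper name are made up, and ggml_backend_dev_supports_op is assumed to be the device-level query from ggml-backend.h):

    // Ask a device whether it accepts a soft_max node with a broadcast mask.
    #include "ggml.h"
    #include "ggml-backend.h"

    static bool device_runs_broadcast_soft_max(ggml_backend_dev_t dev) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*ggml_tensor_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true, // only tensor metadata, no data buffers
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 8, 1);
        struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 32, 4, 1); // broadcast mask
        struct ggml_tensor * node = ggml_soft_max_ext(ctx, a, mask, 0.125f, 0.0f);

        const bool ok = ggml_backend_dev_supports_op(dev, node);

        ggml_free(ctx);
        return ok;
    }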
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -410,14 +410,13 @@ struct vk_device_struct {
410
  vk_pipeline pipeline_div_norepeat[2][2][2];
411
 
412
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
413
- vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
414
  vk_pipeline pipeline_scale_f32;
415
  vk_pipeline pipeline_sqr_f32;
416
  vk_pipeline pipeline_sin_f32;
417
  vk_pipeline pipeline_cos_f32;
418
  vk_pipeline pipeline_clamp_f32;
419
  vk_pipeline pipeline_pad_f32;
420
- vk_pipeline pipeline_roll_f32;
421
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
422
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
423
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
@@ -689,37 +688,6 @@ struct vk_op_unary_push_constants {
689
  };
690
  static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
691
 
692
- static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
693
- GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
694
- ne = ne != 0 ? ne : ggml_nelements(dst);
695
- GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
696
-
697
- vk_op_unary_push_constants p{};
698
- p.ne = (uint32_t)ne;
699
-
700
- size_t src0_tsize = ggml_type_size(src0->type);
701
- p.ne00 = (uint32_t)src0->ne[0];
702
- p.ne01 = (uint32_t)src0->ne[1];
703
- p.ne02 = (uint32_t)src0->ne[2];
704
- p.ne03 = (uint32_t)src0->ne[3];
705
- p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
706
- p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
707
- p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
708
- p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
709
-
710
- size_t dst_tsize = ggml_type_size(dst->type);
711
- p.ne10 = (uint32_t)dst->ne[0];
712
- p.ne11 = (uint32_t)dst->ne[1];
713
- p.ne12 = (uint32_t)dst->ne[2];
714
- p.ne13 = (uint32_t)dst->ne[3];
715
- p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
716
- p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
717
- p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
718
- p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
719
-
720
- return p; // fastdiv values and offsets are initialized later in ggml_vk_op
721
- }
722
-
723
  // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
724
  // Precompute mp (m' in the paper) and L such that division
725
  // can be computed using a multiply (high 32b of 64b result)
@@ -881,7 +849,6 @@ struct vk_op_conv2d_dw_push_constants {
881
 
882
  struct vk_op_upscale_push_constants {
883
  uint32_t ne; uint32_t a_offset; uint32_t d_offset;
884
- uint32_t ne00; uint32_t ne01;
885
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
886
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
887
  float sf0; float sf1; float sf2; float sf3;
@@ -2775,9 +2742,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2775
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2776
  ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2777
 
2778
- ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2779
- ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2780
- ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
2781
 
2782
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2783
 
@@ -2789,8 +2754,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
2789
 
2790
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2791
 
2792
- ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2793
-
2794
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2795
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2796
 
@@ -6453,16 +6416,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6453
  }
6454
  return nullptr;
6455
  case GGML_OP_UPSCALE:
6456
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6457
- int mode = ggml_get_op_params_i32(dst, 0);
6458
- switch (mode) {
6459
- case GGML_SCALE_MODE_NEAREST:
6460
- return ctx->device->pipeline_upscale_nearest_f32;
6461
- case GGML_SCALE_MODE_BILINEAR:
6462
- return ctx->device->pipeline_upscale_bilinear_f32;
6463
- case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6464
- return ctx->device->pipeline_upscale_bilinear_ac_f32;
6465
- }
6466
  }
6467
  return nullptr;
6468
  case GGML_OP_SCALE:
@@ -6495,11 +6450,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6495
  return ctx->device->pipeline_pad_f32;
6496
  }
6497
  return nullptr;
6498
- case GGML_OP_ROLL:
6499
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6500
- return ctx->device->pipeline_roll_f32;
6501
- }
6502
- return nullptr;
6503
  case GGML_OP_REPEAT:
6504
  if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
6505
  return ctx->device->pipeline_repeat_f32;
@@ -7042,7 +6992,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7042
  case GGML_OP_COS:
7043
  case GGML_OP_CLAMP:
7044
  case GGML_OP_PAD:
7045
- case GGML_OP_ROLL:
7046
  case GGML_OP_REPEAT:
7047
  case GGML_OP_REPEAT_BACK:
7048
  case GGML_OP_CPY:
@@ -7479,21 +7428,14 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
7479
 
7480
  static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7481
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7482
- const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
7483
 
7484
- float sf0 = (float)dst->ne[0] / src0->ne[0];
7485
- float sf1 = (float)dst->ne[1] / src0->ne[1];
7486
- float sf2 = (float)dst->ne[2] / src0->ne[2];
7487
- float sf3 = (float)dst->ne[3] / src0->ne[3];
7488
-
7489
- if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7490
- sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7491
- sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7492
- }
7493
 
7494
  ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7495
  (uint32_t)ggml_nelements(dst), 0, 0,
7496
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
7497
  (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7498
  (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7499
  sf0, sf1, sf2, sf3,
@@ -7501,60 +7443,117 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
7501
  }
7502
 
7503
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7504
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7505
- p.param1 = ggml_get_op_params_f32(dst, 0);
 
7506
 
7507
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
 
 
 
 
 
 
 
7508
  }
7509
 
7510
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7511
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
 
 
 
 
 
 
 
 
 
 
7512
  }
7513
 
7514
  static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7515
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
 
 
 
 
 
 
 
 
 
 
7516
  }
7517
 
7518
  static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7519
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
 
 
 
 
 
 
 
 
 
 
7520
  }
7521
 
7522
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7523
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7524
- p.param1 = ggml_get_op_params_f32(dst, 0);
7525
- p.param2 = ggml_get_op_params_f32(dst, 1);
7526
 
7527
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
 
 
 
 
 
 
 
7528
  }
7529
 
7530
  static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7531
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7532
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7533
- }
7534
-
7535
- static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7536
- const int32_t s0 = ggml_get_op_params_i32(dst, 0);
7537
- const int32_t s1 = ggml_get_op_params_i32(dst, 1);
7538
- const int32_t s2 = ggml_get_op_params_i32(dst, 2);
7539
- const int32_t s3 = ggml_get_op_params_i32(dst, 3);
7540
- const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
7541
- const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
7542
-
7543
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7544
- memcpy(&p.param1, &s01_packed, sizeof(float));
7545
- memcpy(&p.param2, &s23_packed, sizeof(float));
7546
 
7547
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
 
 
 
 
 
 
 
7548
  }
7549
 
7550
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7551
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7552
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
 
 
 
 
 
 
 
 
 
7553
  }
7554
 
7555
  static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7556
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7557
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
 
 
 
 
 
 
 
 
 
7558
  }
7559
 
7560
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -7572,8 +7571,14 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
7572
  }
7573
  }
7574
 
7575
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7576
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
 
 
 
 
 
 
7577
  }
7578
 
7579
  static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -8885,7 +8890,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
8885
  case GGML_OP_COS:
8886
  case GGML_OP_CLAMP:
8887
  case GGML_OP_PAD:
8888
- case GGML_OP_ROLL:
8889
  case GGML_OP_CPY:
8890
  case GGML_OP_CONT:
8891
  case GGML_OP_DUP:
@@ -9055,10 +9059,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9055
  case GGML_OP_PAD:
9056
  ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
9057
 
9058
- break;
9059
- case GGML_OP_ROLL:
9060
- ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9061
-
9062
  break;
9063
  case GGML_OP_CPY:
9064
  case GGML_OP_CONT:
@@ -9276,7 +9276,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9276
  case GGML_OP_COS:
9277
  case GGML_OP_CLAMP:
9278
  case GGML_OP_PAD:
9279
- case GGML_OP_ROLL:
9280
  case GGML_OP_CPY:
9281
  case GGML_OP_CONT:
9282
  case GGML_OP_DUP:
@@ -10249,6 +10248,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10249
  if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
10250
  return false;
10251
  }
 
 
 
 
 
10252
  // It's straightforward to support different K/V dequant, but would
10253
  // significantly increase the number of pipelines
10254
  if (op->src[1]->type != op->src[2]->type) {
@@ -10401,13 +10405,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10401
  case GGML_OP_CLAMP:
10402
  return op->src[0]->type == GGML_TYPE_F32;
10403
  case GGML_OP_UPSCALE:
 
10404
  case GGML_OP_ACC:
10405
  case GGML_OP_CONCAT:
10406
  case GGML_OP_SCALE:
10407
  case GGML_OP_PAD:
10408
- case GGML_OP_ROLL:
10409
  case GGML_OP_DIAG_MASK_INF:
 
10410
  case GGML_OP_SOFT_MAX:
 
 
 
 
 
 
 
10411
  case GGML_OP_SOFT_MAX_BACK:
10412
  case GGML_OP_ARGSORT:
10413
  case GGML_OP_SUM:
 
410
  vk_pipeline pipeline_div_norepeat[2][2][2];
411
 
412
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
413
+ vk_pipeline pipeline_upscale_f32;
414
  vk_pipeline pipeline_scale_f32;
415
  vk_pipeline pipeline_sqr_f32;
416
  vk_pipeline pipeline_sin_f32;
417
  vk_pipeline pipeline_cos_f32;
418
  vk_pipeline pipeline_clamp_f32;
419
  vk_pipeline pipeline_pad_f32;
 
420
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
421
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
422
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
 
688
  };
689
  static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
  // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
692
  // Precompute mp (m' in the paper) and L such that division
693
  // can be computed using a multiply (high 32b of 64b result)
 
849
 
850
  struct vk_op_upscale_push_constants {
851
  uint32_t ne; uint32_t a_offset; uint32_t d_offset;
 
852
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
853
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
854
  float sf0; float sf1; float sf2; float sf3;
 
2742
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2743
  ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2744
 
2745
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
 
 
2746
 
2747
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2748
 
 
2754
 
2755
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2756
 
 
 
2757
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2758
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2759
 
 
6416
  }
6417
  return nullptr;
6418
  case GGML_OP_UPSCALE:
6419
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6420
+ return ctx->device->pipeline_upscale_f32;
 
 
 
 
 
 
 
 
6421
  }
6422
  return nullptr;
6423
  case GGML_OP_SCALE:
 
6450
  return ctx->device->pipeline_pad_f32;
6451
  }
6452
  return nullptr;
 
 
 
 
 
6453
  case GGML_OP_REPEAT:
6454
  if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
6455
  return ctx->device->pipeline_repeat_f32;
 
6992
  case GGML_OP_COS:
6993
  case GGML_OP_CLAMP:
6994
  case GGML_OP_PAD:
 
6995
  case GGML_OP_REPEAT:
6996
  case GGML_OP_REPEAT_BACK:
6997
  case GGML_OP_CPY:
 
7428
 
7429
  static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7430
  const uint32_t src0_type_size = ggml_type_size(src0->type);
 
7431
 
7432
+ const float sf0 = (float)dst->ne[0] / src0->ne[0];
7433
+ const float sf1 = (float)dst->ne[1] / src0->ne[1];
7434
+ const float sf2 = (float)dst->ne[2] / src0->ne[2];
7435
+ const float sf3 = (float)dst->ne[3] / src0->ne[3];
 
 
 
 
 
7436
 
7437
  ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7438
  (uint32_t)ggml_nelements(dst), 0, 0,
 
7439
  (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7440
  (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7441
  sf0, sf1, sf2, sf3,
 
7443
  }
7444
 
7445
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7446
+ float * op_params = (float *)dst->op_params;
7447
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7448
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7449
 
7450
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7451
+ (uint32_t)ggml_nelements(src0),
7452
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7453
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7454
+ 0,
7455
+ op_params[0], 0.0f,
7456
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7457
+ }, dryrun);
7458
  }
7459
 
7460
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7461
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7462
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7463
+
7464
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7465
+ (uint32_t)ggml_nelements(src0),
7466
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7467
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7468
+ 0,
7469
+ 0.0f, 0.0f,
7470
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7471
+ }, dryrun);
7472
  }
7473
 
7474
  static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7475
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7476
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7477
+
7478
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7479
+ (uint32_t)ggml_nelements(src0),
7480
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7481
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7482
+ 0,
7483
+ 0.0f, 0.0f,
7484
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7485
+ }, dryrun);
7486
  }
7487
 
7488
  static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7489
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7490
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7491
+
7492
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7493
+ (uint32_t)ggml_nelements(src0),
7494
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7495
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7496
+ 0,
7497
+ 0.0f, 0.0f,
7498
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7499
+ }, dryrun);
7500
  }
7501
 
7502
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7503
+ float * op_params = (float *)dst->op_params;
7504
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7505
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7506
 
7507
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7508
+ (uint32_t)ggml_nelements(src0),
7509
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7510
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7511
+ 0,
7512
+ op_params[0], op_params[1],
7513
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7514
+ }, dryrun);
7515
  }
7516
 
7517
  static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7518
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7519
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
 
 
 
 
 
 
 
 
 
 
 
 
 
7520
 
7521
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7522
+ (uint32_t)ggml_nelements(dst),
7523
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7524
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7525
+ 0,
7526
+ 0.0f, 0.0f,
7527
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7528
+ }, dryrun);
7529
  }
7530
 
7531
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7532
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7533
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7534
+
7535
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7536
+ (uint32_t)ggml_nelements(dst),
7537
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7538
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7539
+ 0,
7540
+ 0.0f, 0.0f,
7541
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7542
+ }, dryrun);
7543
  }
7544
 
7545
  static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7546
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7547
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7548
+
7549
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7550
+ (uint32_t)ggml_nelements(dst),
7551
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7552
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7553
+ 0,
7554
+ 0.0f, 0.0f,
7555
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7556
+ }, dryrun);
7557
  }
7558
 
7559
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
 
7571
  }
7572
  }
7573
 
7574
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7575
+ ne,
7576
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7577
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7578
+ 0,
7579
+ 0.0f, 0.0f,
7580
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7581
+ }, dryrun);
7582
  }
7583
 
7584
  static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
 
8890
  case GGML_OP_COS:
8891
  case GGML_OP_CLAMP:
8892
  case GGML_OP_PAD:
 
8893
  case GGML_OP_CPY:
8894
  case GGML_OP_CONT:
8895
  case GGML_OP_DUP:
 
9059
  case GGML_OP_PAD:
9060
  ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
9061
 
 
 
 
 
9062
  break;
9063
  case GGML_OP_CPY:
9064
  case GGML_OP_CONT:
 
9276
  case GGML_OP_COS:
9277
  case GGML_OP_CLAMP:
9278
  case GGML_OP_PAD:
 
9279
  case GGML_OP_CPY:
9280
  case GGML_OP_CONT:
9281
  case GGML_OP_DUP:
 
10248
  if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
10249
  return false;
10250
  }
10251
+ // TODO: support broadcast
10252
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
10253
+ if (op->src[0]->ne[3] != 1) {
10254
+ return false;
10255
+ }
10256
  // It's straightforward to support different K/V dequant, but would
10257
  // significantly increase the number of pipelines
10258
  if (op->src[1]->type != op->src[2]->type) {
 
10405
  case GGML_OP_CLAMP:
10406
  return op->src[0]->type == GGML_TYPE_F32;
10407
  case GGML_OP_UPSCALE:
10408
+ return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
10409
  case GGML_OP_ACC:
10410
  case GGML_OP_CONCAT:
10411
  case GGML_OP_SCALE:
10412
  case GGML_OP_PAD:
 
10413
  case GGML_OP_DIAG_MASK_INF:
10414
+ return true;
10415
  case GGML_OP_SOFT_MAX:
10416
+ // TODO: support batching
10417
+ if (op->src[0]->ne[3] != 1) {
10418
+ return false;
10419
+ }
10420
+ // TODO: support broadcast
10421
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
10422
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
10423
  case GGML_OP_SOFT_MAX_BACK:
10424
  case GGML_OP_ARGSORT:
10425
  case GGML_OP_SUM:
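Note: the ggml_vk_clamp / ggml_vk_pad / ggml_vk_repeat / ggml_vk_repeat_back / ggml_vk_cpy wrappers above all fill the same vk_op_unary_push_constants block positionally: the element count, src0 shape and per-element strides, dst shape and strides, one field passed as 0, two float parameters (only GGML_OP_CLAMP uses them, for min/max), and twelve trailing fields that these ops leave at zero. A rough C sketch of that layout follows; the field names are chosen for illustration only, and the real struct defined elsewhere in ggml-vulkan.cpp may name or pack them differently.

// Illustrative only: mirrors the argument order used by the wrappers above.
struct vk_op_unary_push_constants_sketch {
    uint32_t ne;                      // total number of elements to process
    uint32_t ne00, ne01, ne02, ne03;  // src0 dims
    uint32_t nb00, nb01, nb02, nb03;  // src0 strides, in elements (nb[i] / type_size)
    uint32_t ne10, ne11, ne12, ne13;  // dst dims
    uint32_t nb10, nb11, nb12, nb13;  // dst strides, in elements
    uint32_t offset;                  // passed as 0 by these wrappers
    float    param1, param2;          // op parameters, e.g. clamp min/max
    uint32_t reserved[12];            // remaining fields, zeroed by these ops
};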
ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp CHANGED
@@ -3,7 +3,6 @@
3
  layout (push_constant) uniform parameter
4
  {
5
  uint ne; uint a_offset; uint d_offset;
6
- uint ne00; uint ne01;
7
  uint nb00; uint nb01; uint nb02; uint nb03;
8
  uint ne10; uint ne11; uint ne12; uint ne13;
9
  float sf0; float sf1; float sf2; float sf3;
@@ -16,61 +15,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
16
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
17
  layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
18
 
19
- // from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
20
- #define NEAREST 0
21
- #define BILINEAR 1
22
- #define ALIGN_CORNERS (1 << 8)
23
-
24
- layout (constant_id = 0) const uint scale_mode = 0;
25
-
26
- float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
27
- const uint i00 = uint(i10 / p.sf0);
28
- const uint i01 = uint(i11 / p.sf1);
29
- const uint i02 = uint(i12 / p.sf2);
30
- const uint i03 = uint(i13 / p.sf3);
31
-
32
- return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
33
- }
34
-
35
- float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
36
- const uint i02 = uint(i12 / p.sf2);
37
- const uint i03 = uint(i13 / p.sf3);
38
- const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
39
-
40
- const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
41
- const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
42
- const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
43
- const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
44
-
45
- return
46
- v00 * (1.0-d.x) * (1.0-d.y) +
47
- v01 * d.x * (1.0-d.y) +
48
- v10 * (1.0-d.x) * d.y +
49
- v11 * d.x * d.y;
50
- }
51
-
52
- float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
53
- const ivec2 ne0 = ivec2(p.ne00, p.ne01);
54
-
55
- const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
56
- const vec2 c0f = floor(c);
57
- const vec2 d = c - c0f;
58
- const ivec2 c0 = max(ivec2(c0f), 0);
59
- const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
60
-
61
- return fetch_bilinear(c0, c1, d, i12, i13);
62
- }
63
-
64
- float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
65
- const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
66
- const vec2 c0f = floor(c);
67
- const vec2 d = c - c0f;
68
- const ivec2 c0 = ivec2(c0f);
69
- const ivec2 c1 = c0 + 1;
70
-
71
- return fetch_bilinear(c0, c1, d, i12, i13);
72
- }
73
-
74
  void main() {
75
  const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
76
 
@@ -83,18 +27,10 @@ void main() {
83
  const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
84
  const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
85
 
86
- float result;
87
- switch (scale_mode) {
88
- case NEAREST:
89
- result = fetch_nearest(i10, i11, i12, i13);
90
- break;
91
- case BILINEAR:
92
- result = interpolate_bilinear(i10, i11, i12, i13);
93
- break;
94
- case BILINEAR | ALIGN_CORNERS:
95
- result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
96
- break;
97
- }
98
 
99
- data_d[p.d_offset + idx] = D_TYPE(result);
100
  }
 
3
  layout (push_constant) uniform parameter
4
  {
5
  uint ne; uint a_offset; uint d_offset;
 
6
  uint nb00; uint nb01; uint nb02; uint nb03;
7
  uint ne10; uint ne11; uint ne12; uint ne13;
8
  float sf0; float sf1; float sf2; float sf3;
 
15
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
16
  layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
17

18
  void main() {
19
  const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
20
 
 
27
  const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
28
  const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
29
 
30
+ const uint i00 = uint(i10 / p.sf0);
31
+ const uint i01 = uint(i11 / p.sf1);
32
+ const uint i02 = uint(i12 / p.sf2);
33
+ const uint i03 = uint(i13 / p.sf3);
34
 
35
+ data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
36
  }
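Note: after this change the shader performs nearest-neighbour sampling only; each destination coordinate is divided by its per-dimension scale factor and truncated to pick the source element. The same index mapping written out in C for clarity; the helper name and parameters are illustrative, with sf0..sf3 and nb00..nb03 playing the roles of p.sf0..p.sf3 and p.nb00..p.nb03 above.

#include <stddef.h>
#include <stdint.h>

// Illustrative: nearest-neighbour source index for one destination coordinate.
static inline size_t upscale_nearest_src_index(
        uint32_t i10, uint32_t i11, uint32_t i12, uint32_t i13,        // dst coords
        float sf0, float sf1, float sf2, float sf3,                    // dst/src scale factors
        uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03) {  // src strides (elements)
    const uint32_t i00 = (uint32_t)(i10 / sf0);
    const uint32_t i01 = (uint32_t)(i11 / sf1);
    const uint32_t i02 = (uint32_t)(i12 / sf2);
    const uint32_t i03 = (uint32_t)(i13 / sf3);
    return (size_t)i03*nb03 + (size_t)i02*nb02 + (size_t)i01*nb01 + (size_t)i00*nb00;
}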
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -644,8 +644,6 @@ void process_shaders() {
644
  string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
645
  string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
646
 
647
- string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
648
-
649
  for (auto &c : compiles) {
650
  c.wait();
651
  }
 
644
  string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
645
  string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
646
 
 
 
647
  for (auto &c : compiles) {
648
  c.wait();
649
  }
ggml/src/ggml.c CHANGED
@@ -473,14 +473,6 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
473
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
474
  }
475
 
476
- const char * ggml_version(void) {
477
- return GGML_VERSION;
478
- }
479
-
480
- const char * ggml_commit(void) {
481
- return GGML_COMMIT;
482
- }
483
-
484
  //
485
  // timing
486
  //
@@ -3674,9 +3666,11 @@ static struct ggml_tensor * ggml_soft_max_impl(
3674
  if (mask) {
3675
  GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
3676
  GGML_ASSERT(ggml_is_contiguous(mask));
3677
- GGML_ASSERT(ggml_is_matrix(mask));
3678
  GGML_ASSERT(mask->ne[0] == a->ne[0]);
3679
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
 
 
3680
  }
3681
 
3682
  if (max_bias > 0.0f) {
@@ -4697,13 +4691,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
4697
  GGML_ASSERT(ggml_can_mul_mat(k, q));
4698
  // TODO: check if vT can be multiplied by (k*qT)
4699
 
 
 
 
4700
  if (mask) {
4701
  GGML_ASSERT(ggml_is_contiguous(mask));
4702
- GGML_ASSERT(mask->ne[2] == 1);
4703
- GGML_ASSERT(mask->ne[3] == 1);
4704
  GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
4705
  "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
4706
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
 
4707
  }
4708
 
4709
  if (max_bias > 0.0f) {
 
473
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
474
  }
475

476
  //
477
  // timing
478
  //
 
3666
  if (mask) {
3667
  GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
3668
  GGML_ASSERT(ggml_is_contiguous(mask));
3669
+ GGML_ASSERT(ggml_is_3d(mask));
3670
  GGML_ASSERT(mask->ne[0] == a->ne[0]);
3671
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
3672
+ GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
3673
+ GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
3674
  }
3675
 
3676
  if (max_bias > 0.0f) {
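Note: with the asserts above, the mask passed to ggml_soft_max_ext may be 3-D and broadcast over the higher dims of a, as long as a->ne[2] and a->ne[3] are multiples of the corresponding mask dims. A minimal call sketch; the tensor sizes, scale and max_bias values are hypothetical and ctx is assumed to be a valid ggml_context.

// a:    [128, 32, 16, 2] (F32)
// mask: [128, 32,  4, 1] (F32), broadcast over dims 2 and 3 (16 % 4 == 0, 2 % 1 == 0)
struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 16, 2);
struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32,  4, 1);

// rows of a are scaled, the broadcast mask is added, then each row is soft-maxed
struct ggml_tensor * out  = ggml_soft_max_ext(ctx, a, mask, 1.0f /*scale*/, 0.0f /*max_bias*/);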
 
4691
  GGML_ASSERT(ggml_can_mul_mat(k, q));
4692
  // TODO: check if vT can be multiplied by (k*qT)
4693
 
4694
+ GGML_ASSERT(q->ne[3] == k->ne[3]);
4695
+ GGML_ASSERT(q->ne[3] == v->ne[3]);
4696
+
4697
  if (mask) {
4698
  GGML_ASSERT(ggml_is_contiguous(mask));
4699
+ GGML_ASSERT(mask->ne[2] == q->ne[3]);
 
4700
  GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
4701
  "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
4702
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
4703
+
4704
+ GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
4705
  }
4706
 
4707
  if (max_bias > 0.0f) {
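Note: ggml_flash_attn_ext now requires q, k and v to share ne[3], and the mask's dim 2 must divide q->ne[3] (here they are equal). A minimal call sketch; the shapes, scale and ctx are hypothetical, and the KV-head broadcast (8 query heads over 2 KV heads) was already supported before this change.

// head dim 64, 1 query token, 8 query heads, 2 KV heads, 256 KV cells, ne3 = 2
struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,  64,   1, 8, 2);
struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,  64, 256, 2, 2);
struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,  64, 256, 2, 2);
// mask rows padded to GGML_KQ_MASK_PAD; ne[2] = 2 matches (and divides) q->ne[3]
struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 256, GGML_PAD(1, GGML_KQ_MASK_PAD), 2, 1);

struct ggml_tensor * out  = ggml_flash_attn_ext(ctx, q, k, v, mask,
        0.125f /*scale = 1/sqrt(64)*/, 0.0f /*max_bias*/, 0.0f /*logit_softcap*/);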