jeffbolznv committed on
Commit
80a188c
·
1 Parent(s): ad199b1

vulkan: Optimize argsort (llama/15354)

Browse files

- Launch an appropriate number of invocations (ncols rounded up to a power of two).
32 invocations is a common case, and the barrier is much cheaper there.
- Specialize for "needs bounds checking" vs not.
- Make the code less branchy and [[unroll]] the loops. In the final code,
I see no branches inside the main loop (only predicated stores) when
needs_bounds_check is false.
- Always sort ascending, then apply the ascending vs descending option when
doing the final stores to memory.
- Copy the values into shared memory, which makes them slightly cheaper to access.

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -345,6 +345,9 @@ enum vk_conv_shapes {
345
  CONV_SHAPE_COUNT,
346
  };
347
 
 
 
 
348
  struct vk_device_struct {
349
  std::recursive_mutex mutex;
350
 
@@ -505,7 +508,7 @@ struct vk_device_struct {
505
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
506
  vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
507
  vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
508
- vk_pipeline pipeline_argsort_f32;
509
  vk_pipeline pipeline_sum_rows_f32;
510
  vk_pipeline pipeline_argmax_f32;
511
  vk_pipeline pipeline_count_equal_i32;
@@ -870,7 +873,6 @@ struct vk_op_soft_max_push_constants {
870
 
871
  struct vk_op_argsort_push_constants {
872
  uint32_t ncols;
873
- uint32_t ncols_pad;
874
  int32_t order;
875
  };
876
 
@@ -3099,7 +3101,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
3099
  ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3100
  }
3101
 
3102
- ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
 
3103
 
3104
  ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3105
 
@@ -7160,7 +7164,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
7160
  }
7161
  case GGML_OP_ARGSORT:
7162
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
7163
- return ctx->device->pipeline_argsort_f32;
 
7164
  }
7165
  return nullptr;
7166
  case GGML_OP_SUM:
@@ -8485,16 +8490,8 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
8485
 
8486
  uint32_t ncols = src0->ne[0];
8487
 
8488
- uint32_t ncols_pad = 1;
8489
- while (ncols_pad < ncols) {
8490
- ncols_pad *= 2;
8491
- }
8492
-
8493
- GGML_ASSERT(ncols_pad <= 1024);
8494
-
8495
  ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
8496
  ncols,
8497
- ncols_pad,
8498
  op_params[0],
8499
  }, dryrun);
8500
  }
@@ -11367,6 +11364,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11367
  case GGML_OP_OPT_STEP_ADAMW:
11368
  case GGML_OP_OPT_STEP_SGD:
11369
  return op->src[0]->type == GGML_TYPE_F32;
 
 
11370
  case GGML_OP_UPSCALE:
11371
  case GGML_OP_ACC:
11372
  case GGML_OP_CONCAT:
@@ -11376,7 +11375,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11376
  case GGML_OP_DIAG_MASK_INF:
11377
  case GGML_OP_SOFT_MAX:
11378
  case GGML_OP_SOFT_MAX_BACK:
11379
- case GGML_OP_ARGSORT:
11380
  case GGML_OP_SUM:
11381
  case GGML_OP_SUM_ROWS:
11382
  case GGML_OP_ARGMAX:
 
345
  CONV_SHAPE_COUNT,
346
  };
347
 
348
+ static constexpr uint32_t num_argsort_pipelines = 11;
349
+ static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
350
+
351
  struct vk_device_struct {
352
  std::recursive_mutex mutex;
353
 
 
508
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
509
  vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
510
  vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
511
+ vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
512
  vk_pipeline pipeline_sum_rows_f32;
513
  vk_pipeline pipeline_argmax_f32;
514
  vk_pipeline pipeline_count_equal_i32;
 
873
 
874
  struct vk_op_argsort_push_constants {
875
  uint32_t ncols;
 
876
  int32_t order;
877
  };
878
 
 
3101
  ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3102
  }
3103
 
3104
+ for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
3105
+ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3106
+ }
3107
 
3108
  ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3109
 
 
7164
  }
7165
  case GGML_OP_ARGSORT:
7166
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
7167
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
7168
+ return ctx->device->pipeline_argsort_f32[idx];
7169
  }
7170
  return nullptr;
7171
  case GGML_OP_SUM:
 
8490
 
8491
  uint32_t ncols = src0->ne[0];
8492
 
 
 
 
 
 
 
 
8493
  ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
8494
  ncols,
 
8495
  op_params[0],
8496
  }, dryrun);
8497
  }
 
11364
  case GGML_OP_OPT_STEP_ADAMW:
11365
  case GGML_OP_OPT_STEP_SGD:
11366
  return op->src[0]->type == GGML_TYPE_F32;
11367
+ case GGML_OP_ARGSORT:
11368
+ return op->ne[0] <= max_argsort_cols;
11369
  case GGML_OP_UPSCALE:
11370
  case GGML_OP_ACC:
11371
  case GGML_OP_CONCAT:
 
11375
  case GGML_OP_DIAG_MASK_INF:
11376
  case GGML_OP_SOFT_MAX:
11377
  case GGML_OP_SOFT_MAX_BACK:
 
11378
  case GGML_OP_SUM:
11379
  case GGML_OP_SUM_ROWS:
11380
  case GGML_OP_ARGMAX:
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp CHANGED
@@ -1,22 +1,24 @@
1
  #version 450
 
2
 
3
  #include "types.comp"
4
 
5
- #define BLOCK_SIZE 1024
 
6
  #define ASC 0
7
 
8
- layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
9
 
10
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
11
  layout (binding = 1) buffer D {int data_d[];};
12
 
13
  layout (push_constant) uniform parameter {
14
  uint ncols;
15
- uint ncols_pad;
16
  uint order;
17
  } p;
18
 
19
  shared int dst_row[BLOCK_SIZE];
 
20
 
21
  void swap(uint idx0, uint idx1) {
22
  int tmp = dst_row[idx0];
@@ -24,7 +26,7 @@ void swap(uint idx0, uint idx1) {
24
  dst_row[idx1] = tmp;
25
  }
26
 
27
- void main() {
28
  // bitonic sort
29
  const int col = int(gl_LocalInvocationID.x);
30
  const uint row = gl_WorkGroupID.y;
@@ -32,38 +34,46 @@ void main() {
32
  const uint row_offset = row * p.ncols;
33
 
34
  // initialize indices
35
- if (col < p.ncols_pad) {
36
- dst_row[col] = col;
37
- }
38
  barrier();
39
 
40
- for (uint k = 2; k <= p.ncols_pad; k *= 2) {
41
- for (uint j = k / 2; j > 0; j /= 2) {
42
- const uint ixj = col ^ j;
43
- if (col < p.ncols_pad && ixj > col) {
44
- if ((col & k) == 0) {
45
- if (dst_row[col] >= p.ncols ||
46
- (dst_row[ixj] < p.ncols && (p.order == ASC ?
47
- data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
48
- data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
49
- ) {
50
- swap(col, ixj);
51
- }
52
- } else {
53
- if (dst_row[ixj] >= p.ncols ||
54
- (dst_row[col] < p.ncols && (p.order == ASC ?
55
- data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
56
- data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
57
- ) {
58
- swap(col, ixj);
59
- }
60
- }
61
  }
 
62
  barrier();
63
  }
64
  }
65
 
66
  if (col < p.ncols) {
67
- data_d[row_offset + col] = dst_row[col];
 
 
 
 
 
 
 
 
 
 
 
 
68
  }
69
  }
 
1
  #version 450
2
+ #extension GL_EXT_control_flow_attributes : enable
3
 
4
  #include "types.comp"
5
 
6
+ layout(constant_id = 0) const int BLOCK_SIZE = 1024;
7
+ layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
8
  #define ASC 0
9
 
10
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
11
 
12
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
13
  layout (binding = 1) buffer D {int data_d[];};
14
 
15
  layout (push_constant) uniform parameter {
16
  uint ncols;
 
17
  uint order;
18
  } p;
19
 
20
  shared int dst_row[BLOCK_SIZE];
21
+ shared A_TYPE a_sh[BLOCK_SIZE];
22
 
23
  void swap(uint idx0, uint idx1) {
24
  int tmp = dst_row[idx0];
 
26
  dst_row[idx1] = tmp;
27
  }
28
 
29
+ void argsort(bool needs_bounds_check) {
30
  // bitonic sort
31
  const int col = int(gl_LocalInvocationID.x);
32
  const uint row = gl_WorkGroupID.y;
 
34
  const uint row_offset = row * p.ncols;
35
 
36
  // initialize indices
37
+ dst_row[col] = col;
38
+ a_sh[col] = data_a[row_offset + col];
 
39
  barrier();
40
 
41
+ uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
42
+ [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
43
+ uint num_inner_loop_iters = outer_idx + 1;
44
+ [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
45
+ const int ixj = int(col ^ j);
46
+
47
+ int idx_0 = (col & k) == 0 ? col : ixj;
48
+ int idx_1 = (col & k) == 0 ? ixj : col;
49
+
50
+ int sh_idx_0 = dst_row[idx_0];
51
+ int sh_idx_1 = dst_row[idx_1];
52
+ bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
53
+ bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
54
+
55
+ if ((idx_0_oob ||
56
+ (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
57
+ swap(idx_0, idx_1);
 
 
 
 
58
  }
59
+
60
  barrier();
61
  }
62
  }
63
 
64
  if (col < p.ncols) {
65
+ if (p.order == ASC) {
66
+ data_d[row_offset + col] = dst_row[col];
67
+ } else {
68
+ data_d[row_offset + p.ncols - col - 1] = dst_row[col];
69
+ }
70
+ }
71
+ }
72
+
73
+ void main() {
74
+ if (p.ncols == BLOCK_SIZE) {
75
+ argsort(false);
76
+ } else {
77
+ argsort(true);
78
  }
79
  }