mollysama, Layl Bongers, compilade, ggerganov committed
Commit bd4f5ec · 1 parent: e3e9ca4

llama : support RWKV v6 models (llama/8980)

* convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <[email protected]>

* Add RWKV tokenization

* Fix build

Signed-off-by: Molly Sophia <[email protected]>

* Do not use special tokens when matching in RWKV tokenizer

* Fix model loading

* Add (broken) placeholder graph builder for RWKV

* Add workaround for kv cache

* Add logits conversion to rwkv5

* Add rwkv5 layer norms

* Add time mix KVRG & correct merge mistake

* Add remaining time mix parameters

* Add time mix output loading

* Add placeholder llm_build_time_mix

* Fix build

Signed-off-by: Molly Sophia <[email protected]>

* Load more tensors for rwkv v6

Signed-off-by: Molly Sophia <[email protected]>

* Fix rwkv tokenizer

Signed-off-by: Molly Sophia <[email protected]>

* ggml: Add unary operator Exp

Signed-off-by: Molly Sophia <[email protected]>

* RWKV v6 graph building

Signed-off-by: Molly Sophia <[email protected]>

* Add ``rescale_every_n_layers`` parameter

Signed-off-by: Molly Sophia <[email protected]>

* Add ``wkv.head_size`` key for RWKV

so that it doesn't reuse Mamba SSM parameters

Signed-off-by: Molly Sophia <[email protected]>

* Fix offloading layers to CUDA

Signed-off-by: Molly Sophia <[email protected]>

* Fix parallel inference for RWKV

Signed-off-by: Molly Sophia <[email protected]>

* Remove trailing whitespaces

Signed-off-by: Molly Sophia <[email protected]>

* build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <[email protected]>

* convert_hf_to_gguf: rwkv: Avoid using ``eval``

Signed-off-by: Molly Sophia <[email protected]>

* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually

Signed-off-by: Molly Sophia <[email protected]>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <[email protected]>

* ggml: Add backward computation for unary op ``exp``

Signed-off-by: Molly Sophia <[email protected]>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <[email protected]>

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <[email protected]>

* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <[email protected]>

* build_rwkv6: Simplify graph

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Detect model.type

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Fix tensor loading for 7B/14B models

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Clean up

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Add quantization tensor exclusion

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <[email protected]>

* Update src/llama.cpp

Co-authored-by: compilade <[email protected]>

* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``

Co-authored-by: compilade <[email protected]>

* llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <[email protected]>

* converter: Use class name ``Rwkv6Model``

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Make use of key ``feed_forward_length``

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``

Signed-off-by: Molly Sophia <[email protected]>

* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Remove unused nodes

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Apply code format changes

Signed-off-by: Molly Sophia <[email protected]>

* llama: rwkv6: Add lora for some supported tensors

Currently supported: att.key/receptance/value/gate/output, ffn.receptance/key/value, and head.weight

Signed-off-by: Molly Sophia <[email protected]>

* rwkv : speed-up tokenization using trie

* minor : style + indentation

* llama: rwkv6: Avoid division by zero

Co-authored-by: compilade <[email protected]>

* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <[email protected]>

---------

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Layl Bongers <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
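
For reference, the fused `GGML_OP_RWKV_WKV` operator added in the diff below computes, per head and per token t, the RWKV v6 time-mix recurrence sketched here (read off the loops in `ggml_compute_forward_rwkv_wkv_f32`; u is `time_faaaa` (tf), d_t is the per-token `time_decay` (td), and the S×S per-head state is re-seeded from the `state` operand at the first token of each sequence):

```latex
o_{t,j} = \sum_{i} r_{t,i}\,\bigl(u_{i}\,k_{t,i}\,v_{t,j} + \mathrm{state}_{i,j}\bigr)
\qquad
\mathrm{state}_{i,j} \leftarrow d_{t,i}\,\mathrm{state}_{i,j} + k_{t,i}\,v_{t,j}
```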

Files changed (2)
  1. ggml/include/ggml.h  +19 -0
  2. ggml/src/ggml.c     +225 -3
ggml/include/ggml.h CHANGED

@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
 
         GGML_OP_UNARY,
 
@@ -548,6 +549,7 @@ extern "C" {
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
+        GGML_UNARY_OP_EXP,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1165,6 +1167,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_exp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1913,6 +1923,15 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
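
The header only declares the new entry points; below is a minimal, hypothetical usage sketch (not part of this commit) of how a graph builder could call them, assuming the operand shapes asserted by `ggml_rwkv_wkv()` in ggml/src/ggml.c further down: k is [S, 1, H, n_tokens], v/r/td are [1, S, H, n_tokens], tf carries one value per channel, and the state holds S*S*H floats per sequence. The function name and concrete shapes are illustrative only.

```c
#include "ggml.h"

// Hypothetical sketch: wire the new ops into a graph with the shapes that
// ggml_rwkv_wkv() asserts. S = head size (wkv.head_size), H = head count.
static struct ggml_tensor * build_wkv_example(struct ggml_context * ctx,
                                              int64_t S, int64_t H,
                                              int64_t n_tokens, int64_t n_seqs) {
    struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
    struct ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
    struct ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
    struct ggml_tensor * tf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);             // time_faaaa
    struct ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); // per-token time_decay
    struct ggml_tensor * s  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs); // recurrent state

    // In RWKV v6 the decay td is data-dependent; the new unary ggml_exp()
    // lets the model graph compute it per token (see the exp additions below).

    // The result concatenates the per-token wkv output with the updated state.
    struct ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);

    // rows [0, n_tokens) hold the wkv output ...
    struct ggml_tensor * wkv = ggml_view_2d(ctx, out, S * H, n_tokens,
                                            out->nb[1], 0);
    // ... rows [n_tokens, n_tokens + S*n_seqs) hold the new states,
    // which would be copied back into the recurrent-state cache.
    struct ggml_tensor * new_state = ggml_view_2d(ctx, out, S * H, S * n_seqs,
                                                  out->nb[1], n_tokens * out->nb[1]);
    (void) new_state;
    return wkv;
}
```

The single concatenated output matches the `ne = { S * H, n_tokens + S * n_seqs, 1, 1 }` shape chosen in `ggml_rwkv_wkv()` below, which is why the two views are taken.
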
ggml/src/ggml.c CHANGED

@@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
 
     "UNARY",
 
@@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
+    "EXP",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5466,6 +5470,19 @@ struct ggml_tensor * ggml_hardsigmoid(
     return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
 }
 
+// ggml exp
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
@@ -7734,6 +7751,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * tf,
+        struct ggml_tensor  * td,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -12114,6 +12184,48 @@ static void ggml_compute_forward_hardsigmoid(
     }
 }
 
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 
 // ggml_compute_forward_norm
 
@@ -16692,6 +16804,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_hardsigmoid(params, dst);
             } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -16827,6 +16943,96 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->src[1]->ne[3];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k          = (float *) dst->src[0]->data;
+    float * v          = (float *) dst->src[1]->data;
+    float * r          = (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j ++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -17478,6 +17684,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
            } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
        case GGML_OP_MAP_UNARY:
            {
                ggml_unary_op_f32_t fun;
@@ -18591,12 +18801,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                zero_table);
                    }
                } break;
+            case GGML_UNARY_OP_EXP:
+                {
+                    if (src0->grad) {
+                        src0->grad = ggml_add_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx, tensor, tensor->grad),
+                                zero_table);
+                    }
+                } break;
            default:
                GGML_ABORT("fatal error");
            }
        } break;
    case GGML_OP_GET_REL_POS:
    case GGML_OP_ADD_REL_POS:
+    case GGML_OP_RWKV_WKV:
    case GGML_OP_MAP_UNARY:
    case GGML_OP_MAP_BINARY:
    case GGML_OP_MAP_CUSTOM1_F32:
@@ -19021,6 +19241,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_UNARY_OP_SIGMOID:
        case GGML_UNARY_OP_HARDSWISH:
        case GGML_UNARY_OP_HARDSIGMOID:
+        case GGML_UNARY_OP_EXP:
            {
                n_tasks = 1;
            } break;
@@ -19112,6 +19333,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
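
A note on the new `GGML_UNARY_OP_EXP` backward case above: exp is its own derivative, so the gradient accumulated into src0->grad is simply the forward output times the incoming gradient, which is exactly what `ggml_mul(ctx, tensor, tensor->grad)` builds:

```latex
y = e^{x} \;\Rightarrow\; \frac{\partial y}{\partial x} = e^{x} = y,
\qquad \nabla_{x} L = y \odot \nabla_{y} L
```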