Spaces:
Running
llama : support RWKV v6 models (llama/8980)
Browse files* convert_hf_to_gguf: Add support for RWKV v6
Signed-off-by: Molly Sophia <[email protected]>
* Add RWKV tokenization
* Fix build
Signed-off-by: Molly Sophia <[email protected]>
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build
Signed-off-by: Molly Sophia <[email protected]>
* Load more tensors for rwkv v6
Signed-off-by: Molly Sophia <[email protected]>
* Fix rwkv tokenizer
Signed-off-by: Molly Sophia <[email protected]>
* ggml: Add unary operator Exp
Signed-off-by: Molly Sophia <[email protected]>
* RWKV v6 graph building
Signed-off-by: Molly Sophia <[email protected]>
* Add ``rescale_every_n_layers`` parameter
Signed-off-by: Molly Sophia <[email protected]>
* Add ``wkv.head_size`` key for RWKV
so it doesn't reuse Mamba ssm parameters
Signed-off-by: Molly Sophia <[email protected]>
* Fix offloading layers to CUDA
Signed-off-by: Molly Sophia <[email protected]>
* Fix parallel inferencing for RWKV
Signed-off-by: Molly Sophia <[email protected]>
* Remove trailing whitespaces
Signed-off-by: Molly Sophia <[email protected]>
* build_rwkv: Avoid using inplace operations
Signed-off-by: Molly Sophia <[email protected]>
* convert_hf_to_gguf: rwkv: Avoid using ``eval``
Signed-off-by: Molly Sophia <[email protected]>
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually
Signed-off-by: Molly Sophia <[email protected]>
* Update convert_hf_to_gguf.py
Co-authored-by: compilade <[email protected]>
* ggml: Add backward computation for unary op ``exp``
Signed-off-by: Molly Sophia <[email protected]>
* Update convert_hf_to_gguf.py
Co-authored-by: compilade <[email protected]>
* Update convert_hf_to_gguf.py
Co-authored-by: compilade <[email protected]>
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
Signed-off-by: Molly Sophia <[email protected]>
* build_rwkv6: Simplify graph
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Detect model.type
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Fix tensor loading for 7B/14B models
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Fix group_norm assertion failure with Metal
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Clean up
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Add quantization tensor exclusion
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Use the new advanced batch splits
Signed-off-by: Molly Sophia <[email protected]>
* Update src/llama.cpp
Co-authored-by: compilade <[email protected]>
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``
Co-authored-by: compilade <[email protected]>
* llama: rwkv6: Apply code style and misc changes
Signed-off-by: Molly Sophia <[email protected]>
* converter: Use class name ``Rwkv6Model``
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Make use of key ``feed_forward_length``
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``
Signed-off-by: Molly Sophia <[email protected]>
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Remove unused nodes
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Apply code format changes
Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Add lora for some supported tensors
Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight
Signed-off-by: Molly Sophia <[email protected]>
* rwkv : speed-up tokenization using trie
* minor : style + indentation
* llama: rwkv6: Avoid division by zero
Co-authored-by: compilade <[email protected]>
* ggml: rwkv_wkv: Avoid copying the state
Signed-off-by: Molly Sophia <[email protected]>
---------
Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Layl Bongers <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml/include/ggml.h +19 -0
- ggml/src/ggml.c +225 -3
|
@@ -514,6 +514,7 @@ extern "C" {
|
|
| 514 |
GGML_OP_WIN_UNPART,
|
| 515 |
GGML_OP_GET_REL_POS,
|
| 516 |
GGML_OP_ADD_REL_POS,
|
|
|
|
| 517 |
|
| 518 |
GGML_OP_UNARY,
|
| 519 |
|
|
@@ -548,6 +549,7 @@ extern "C" {
|
|
| 548 |
GGML_UNARY_OP_SILU,
|
| 549 |
GGML_UNARY_OP_HARDSWISH,
|
| 550 |
GGML_UNARY_OP_HARDSIGMOID,
|
|
|
|
| 551 |
|
| 552 |
GGML_UNARY_OP_COUNT,
|
| 553 |
};
|
|
@@ -1165,6 +1167,14 @@ extern "C" {
|
|
| 1165 |
struct ggml_context * ctx,
|
| 1166 |
struct ggml_tensor * a);
|
| 1167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1168 |
// normalize along rows
|
| 1169 |
GGML_API struct ggml_tensor * ggml_norm(
|
| 1170 |
struct ggml_context * ctx,
|
|
@@ -1913,6 +1923,15 @@ extern "C" {
|
|
| 1913 |
struct ggml_tensor * pw,
|
| 1914 |
struct ggml_tensor * ph);
|
| 1915 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1916 |
// custom operators
|
| 1917 |
|
| 1918 |
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
|
|
|
| 514 |
GGML_OP_WIN_UNPART,
|
| 515 |
GGML_OP_GET_REL_POS,
|
| 516 |
GGML_OP_ADD_REL_POS,
|
| 517 |
+
GGML_OP_RWKV_WKV,
|
| 518 |
|
| 519 |
GGML_OP_UNARY,
|
| 520 |
|
|
|
|
| 549 |
GGML_UNARY_OP_SILU,
|
| 550 |
GGML_UNARY_OP_HARDSWISH,
|
| 551 |
GGML_UNARY_OP_HARDSIGMOID,
|
| 552 |
+
GGML_UNARY_OP_EXP,
|
| 553 |
|
| 554 |
GGML_UNARY_OP_COUNT,
|
| 555 |
};
|
|
|
|
| 1167 |
struct ggml_context * ctx,
|
| 1168 |
struct ggml_tensor * a);
|
| 1169 |
|
| 1170 |
+
GGML_API struct ggml_tensor * ggml_exp(
|
| 1171 |
+
struct ggml_context * ctx,
|
| 1172 |
+
struct ggml_tensor * a);
|
| 1173 |
+
|
| 1174 |
+
GGML_API struct ggml_tensor * ggml_exp_inplace(
|
| 1175 |
+
struct ggml_context * ctx,
|
| 1176 |
+
struct ggml_tensor * a);
|
| 1177 |
+
|
| 1178 |
// normalize along rows
|
| 1179 |
GGML_API struct ggml_tensor * ggml_norm(
|
| 1180 |
struct ggml_context * ctx,
|
|
|
|
| 1923 |
struct ggml_tensor * pw,
|
| 1924 |
struct ggml_tensor * ph);
|
| 1925 |
|
| 1926 |
+
GGML_API struct ggml_tensor * ggml_rwkv_wkv(
|
| 1927 |
+
struct ggml_context * ctx,
|
| 1928 |
+
struct ggml_tensor * k,
|
| 1929 |
+
struct ggml_tensor * v,
|
| 1930 |
+
struct ggml_tensor * r,
|
| 1931 |
+
struct ggml_tensor * tf,
|
| 1932 |
+
struct ggml_tensor * td,
|
| 1933 |
+
struct ggml_tensor * state);
|
| 1934 |
+
|
| 1935 |
// custom operators
|
| 1936 |
|
| 1937 |
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
|
@@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
|
|
| 2422 |
// TODO: optimize performance
|
| 2423 |
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
| 2424 |
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
|
|
|
| 2425 |
|
| 2426 |
static const float GELU_COEF_A = 0.044715f;
|
| 2427 |
static const float GELU_QUICK_COEF = -1.702f;
|
|
@@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 2932 |
"WIN_UNPART",
|
| 2933 |
"GET_REL_POS",
|
| 2934 |
"ADD_REL_POS",
|
|
|
|
| 2935 |
|
| 2936 |
"UNARY",
|
| 2937 |
|
|
@@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 2950 |
"CROSS_ENTROPY_LOSS_BACK",
|
| 2951 |
};
|
| 2952 |
|
| 2953 |
-
static_assert(GGML_OP_COUNT ==
|
| 2954 |
|
| 2955 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 2956 |
"none",
|
|
@@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 3024 |
"win_unpart(x)",
|
| 3025 |
"get_rel_pos(x)",
|
| 3026 |
"add_rel_pos(x)",
|
|
|
|
| 3027 |
|
| 3028 |
"unary(x)",
|
| 3029 |
|
|
@@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 3042 |
"cross_entropy_loss_back(x,y)",
|
| 3043 |
};
|
| 3044 |
|
| 3045 |
-
static_assert(GGML_OP_COUNT ==
|
| 3046 |
|
| 3047 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 3048 |
|
|
@@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|
| 3061 |
"SILU",
|
| 3062 |
"HARDSWISH",
|
| 3063 |
"HARDSIGMOID",
|
|
|
|
| 3064 |
};
|
| 3065 |
|
| 3066 |
-
static_assert(GGML_UNARY_OP_COUNT ==
|
| 3067 |
|
| 3068 |
|
| 3069 |
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
|
@@ -5466,6 +5470,19 @@ struct ggml_tensor * ggml_hardsigmoid(
|
|
| 5466 |
return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
|
| 5467 |
}
|
| 5468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5469 |
// ggml_norm
|
| 5470 |
|
| 5471 |
static struct ggml_tensor * ggml_norm_impl(
|
|
@@ -7734,6 +7751,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
|
|
| 7734 |
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
|
| 7735 |
}
|
| 7736 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7737 |
// ggml_unary
|
| 7738 |
|
| 7739 |
static struct ggml_tensor * ggml_unary_impl(
|
|
@@ -12114,6 +12184,48 @@ static void ggml_compute_forward_hardsigmoid(
|
|
| 12114 |
}
|
| 12115 |
}
|
| 12116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12117 |
|
| 12118 |
// ggml_compute_forward_norm
|
| 12119 |
|
|
@@ -16692,6 +16804,10 @@ static void ggml_compute_forward_unary(
|
|
| 16692 |
{
|
| 16693 |
ggml_compute_forward_hardsigmoid(params, dst);
|
| 16694 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16695 |
default:
|
| 16696 |
{
|
| 16697 |
GGML_ABORT("fatal error");
|
|
@@ -16827,6 +16943,96 @@ static void ggml_compute_forward_add_rel_pos(
|
|
| 16827 |
}
|
| 16828 |
}
|
| 16829 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16830 |
// ggml_compute_forward_map_unary
|
| 16831 |
|
| 16832 |
static void ggml_compute_forward_map_unary_f32(
|
|
@@ -17478,6 +17684,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
| 17478 |
{
|
| 17479 |
ggml_compute_forward_add_rel_pos(params, tensor);
|
| 17480 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17481 |
case GGML_OP_MAP_UNARY:
|
| 17482 |
{
|
| 17483 |
ggml_unary_op_f32_t fun;
|
|
@@ -18591,12 +18801,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18591 |
zero_table);
|
| 18592 |
}
|
| 18593 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18594 |
default:
|
| 18595 |
GGML_ABORT("fatal error");
|
| 18596 |
}
|
| 18597 |
} break;
|
| 18598 |
case GGML_OP_GET_REL_POS:
|
| 18599 |
case GGML_OP_ADD_REL_POS:
|
|
|
|
| 18600 |
case GGML_OP_MAP_UNARY:
|
| 18601 |
case GGML_OP_MAP_BINARY:
|
| 18602 |
case GGML_OP_MAP_CUSTOM1_F32:
|
|
@@ -19021,6 +19241,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
| 19021 |
case GGML_UNARY_OP_SIGMOID:
|
| 19022 |
case GGML_UNARY_OP_HARDSWISH:
|
| 19023 |
case GGML_UNARY_OP_HARDSIGMOID:
|
|
|
|
| 19024 |
{
|
| 19025 |
n_tasks = 1;
|
| 19026 |
} break;
|
|
@@ -19112,6 +19333,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
| 19112 |
case GGML_OP_WIN_PART:
|
| 19113 |
case GGML_OP_WIN_UNPART:
|
| 19114 |
case GGML_OP_GET_REL_POS:
|
|
|
|
| 19115 |
case GGML_OP_MAP_UNARY:
|
| 19116 |
case GGML_OP_MAP_BINARY:
|
| 19117 |
case GGML_OP_MAP_CUSTOM1_F32:
|
|
|
|
| 2422 |
// TODO: optimize performance
|
| 2423 |
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
| 2424 |
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
| 2425 |
+
inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
|
| 2426 |
|
| 2427 |
static const float GELU_COEF_A = 0.044715f;
|
| 2428 |
static const float GELU_QUICK_COEF = -1.702f;
|
|
|
|
| 2933 |
"WIN_UNPART",
|
| 2934 |
"GET_REL_POS",
|
| 2935 |
"ADD_REL_POS",
|
| 2936 |
+
"RWKV_WKV",
|
| 2937 |
|
| 2938 |
"UNARY",
|
| 2939 |
|
|
|
|
| 2952 |
"CROSS_ENTROPY_LOSS_BACK",
|
| 2953 |
};
|
| 2954 |
|
| 2955 |
+
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
|
| 2956 |
|
| 2957 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 2958 |
"none",
|
|
|
|
| 3026 |
"win_unpart(x)",
|
| 3027 |
"get_rel_pos(x)",
|
| 3028 |
"add_rel_pos(x)",
|
| 3029 |
+
"rwkv_wkv(k, v, r, tf, td, s)",
|
| 3030 |
|
| 3031 |
"unary(x)",
|
| 3032 |
|
|
|
|
| 3045 |
"cross_entropy_loss_back(x,y)",
|
| 3046 |
};
|
| 3047 |
|
| 3048 |
+
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
|
| 3049 |
|
| 3050 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 3051 |
|
|
|
|
| 3064 |
"SILU",
|
| 3065 |
"HARDSWISH",
|
| 3066 |
"HARDSIGMOID",
|
| 3067 |
+
"EXP",
|
| 3068 |
};
|
| 3069 |
|
| 3070 |
+
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
|
| 3071 |
|
| 3072 |
|
| 3073 |
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
|
|
|
| 5470 |
return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
|
| 5471 |
}
|
| 5472 |
|
| 5473 |
+
// ggml exp
|
| 5474 |
+
struct ggml_tensor * ggml_exp(
|
| 5475 |
+
struct ggml_context * ctx,
|
| 5476 |
+
struct ggml_tensor * a) {
|
| 5477 |
+
return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
|
| 5478 |
+
}
|
| 5479 |
+
|
| 5480 |
+
struct ggml_tensor * ggml_exp_inplace(
|
| 5481 |
+
struct ggml_context * ctx,
|
| 5482 |
+
struct ggml_tensor * a) {
|
| 5483 |
+
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
|
| 5484 |
+
}
|
| 5485 |
+
|
| 5486 |
// ggml_norm
|
| 5487 |
|
| 5488 |
static struct ggml_tensor * ggml_norm_impl(
|
|
|
|
| 7751 |
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
|
| 7752 |
}
|
| 7753 |
|
| 7754 |
+
// ggml_rwkv_wkv
|
| 7755 |
+
|
| 7756 |
+
struct ggml_tensor * ggml_rwkv_wkv(
|
| 7757 |
+
struct ggml_context * ctx,
|
| 7758 |
+
struct ggml_tensor * k,
|
| 7759 |
+
struct ggml_tensor * v,
|
| 7760 |
+
struct ggml_tensor * r,
|
| 7761 |
+
struct ggml_tensor * tf,
|
| 7762 |
+
struct ggml_tensor * td,
|
| 7763 |
+
struct ggml_tensor * state) {
|
| 7764 |
+
GGML_ASSERT(ggml_is_contiguous(k));
|
| 7765 |
+
GGML_ASSERT(ggml_is_contiguous(v));
|
| 7766 |
+
GGML_ASSERT(ggml_is_contiguous(r));
|
| 7767 |
+
GGML_ASSERT(ggml_is_contiguous(tf));
|
| 7768 |
+
GGML_ASSERT(ggml_is_contiguous(td));
|
| 7769 |
+
GGML_ASSERT(ggml_is_contiguous(state));
|
| 7770 |
+
|
| 7771 |
+
const int64_t S = k->ne[0];
|
| 7772 |
+
const int64_t H = k->ne[2];
|
| 7773 |
+
const int64_t n_tokens = k->ne[3];
|
| 7774 |
+
const int64_t n_seqs = state->ne[1];
|
| 7775 |
+
{
|
| 7776 |
+
GGML_ASSERT(k->ne[1] == 1);
|
| 7777 |
+
GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
|
| 7778 |
+
GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
|
| 7779 |
+
// TODO: RWKV v4 and v5
|
| 7780 |
+
GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
|
| 7781 |
+
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
|
| 7782 |
+
}
|
| 7783 |
+
|
| 7784 |
+
bool is_node = false;
|
| 7785 |
+
|
| 7786 |
+
if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
|
| 7787 |
+
GGML_ABORT("fatal error"); // TODO: implement backward
|
| 7788 |
+
is_node = true;
|
| 7789 |
+
}
|
| 7790 |
+
|
| 7791 |
+
// concat output and new_state
|
| 7792 |
+
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
| 7793 |
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 7794 |
+
|
| 7795 |
+
result->op = GGML_OP_RWKV_WKV;
|
| 7796 |
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 7797 |
+
result->src[0] = k;
|
| 7798 |
+
result->src[1] = v;
|
| 7799 |
+
result->src[2] = r;
|
| 7800 |
+
result->src[3] = tf;
|
| 7801 |
+
result->src[4] = td;
|
| 7802 |
+
result->src[5] = state;
|
| 7803 |
+
|
| 7804 |
+
return result;
|
| 7805 |
+
}
|
| 7806 |
+
|
| 7807 |
// ggml_unary
|
| 7808 |
|
| 7809 |
static struct ggml_tensor * ggml_unary_impl(
|
|
|
|
| 12184 |
}
|
| 12185 |
}
|
| 12186 |
|
| 12187 |
+
static void ggml_compute_forward_exp_f32(
|
| 12188 |
+
const struct ggml_compute_params * params,
|
| 12189 |
+
struct ggml_tensor * dst) {
|
| 12190 |
+
|
| 12191 |
+
const struct ggml_tensor * src0 = dst->src[0];
|
| 12192 |
+
|
| 12193 |
+
if (params->ith != 0) {
|
| 12194 |
+
return;
|
| 12195 |
+
}
|
| 12196 |
+
|
| 12197 |
+
assert(ggml_is_contiguous_1(src0));
|
| 12198 |
+
assert(ggml_is_contiguous_1(dst));
|
| 12199 |
+
assert(ggml_are_same_shape(src0, dst));
|
| 12200 |
+
|
| 12201 |
+
const int n = ggml_nrows(src0);
|
| 12202 |
+
const int nc = src0->ne[0];
|
| 12203 |
+
|
| 12204 |
+
for (int i = 0; i < n; i++) {
|
| 12205 |
+
ggml_vec_exp_f32(nc,
|
| 12206 |
+
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
| 12207 |
+
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
| 12208 |
+
}
|
| 12209 |
+
}
|
| 12210 |
+
|
| 12211 |
+
static void ggml_compute_forward_exp(
|
| 12212 |
+
const struct ggml_compute_params * params,
|
| 12213 |
+
struct ggml_tensor * dst) {
|
| 12214 |
+
|
| 12215 |
+
const struct ggml_tensor * src0 = dst->src[0];
|
| 12216 |
+
|
| 12217 |
+
switch (src0->type) {
|
| 12218 |
+
case GGML_TYPE_F32:
|
| 12219 |
+
{
|
| 12220 |
+
ggml_compute_forward_exp_f32(params, dst);
|
| 12221 |
+
} break;
|
| 12222 |
+
default:
|
| 12223 |
+
{
|
| 12224 |
+
GGML_ABORT("fatal error");
|
| 12225 |
+
}
|
| 12226 |
+
}
|
| 12227 |
+
}
|
| 12228 |
+
|
| 12229 |
|
| 12230 |
// ggml_compute_forward_norm
|
| 12231 |
|
|
|
|
| 16804 |
{
|
| 16805 |
ggml_compute_forward_hardsigmoid(params, dst);
|
| 16806 |
} break;
|
| 16807 |
+
case GGML_UNARY_OP_EXP:
|
| 16808 |
+
{
|
| 16809 |
+
ggml_compute_forward_exp(params, dst);
|
| 16810 |
+
} break;
|
| 16811 |
default:
|
| 16812 |
{
|
| 16813 |
GGML_ABORT("fatal error");
|
|
|
|
| 16943 |
}
|
| 16944 |
}
|
| 16945 |
|
| 16946 |
+
// ggml_compute_forward_rwkv_wkv
|
| 16947 |
+
|
| 16948 |
+
static void ggml_compute_forward_rwkv_wkv_f32(
|
| 16949 |
+
const struct ggml_compute_params * params,
|
| 16950 |
+
struct ggml_tensor * dst) {
|
| 16951 |
+
const size_t T = dst->src[1]->ne[3];
|
| 16952 |
+
const size_t C = dst->ne[0];
|
| 16953 |
+
const size_t H = dst->src[1]->ne[2];
|
| 16954 |
+
const size_t n_seqs = dst->src[5]->ne[1];
|
| 16955 |
+
|
| 16956 |
+
float * dst_data = (float *) dst->data;
|
| 16957 |
+
float * state = ((float *) dst->data) + C * T;
|
| 16958 |
+
|
| 16959 |
+
if (params->ith != 0) {
|
| 16960 |
+
return;
|
| 16961 |
+
}
|
| 16962 |
+
|
| 16963 |
+
memset(dst_data, 0, T * C * sizeof(float));
|
| 16964 |
+
|
| 16965 |
+
float * k = (float *) dst->src[0]->data;
|
| 16966 |
+
float * v = (float *) dst->src[1]->data;
|
| 16967 |
+
float * r = (float *) dst->src[2]->data;
|
| 16968 |
+
float * time_faaaa = (float *) dst->src[3]->data;
|
| 16969 |
+
float * time_decay = (float *) dst->src[4]->data;
|
| 16970 |
+
|
| 16971 |
+
size_t t_stride = H * (C / H);
|
| 16972 |
+
|
| 16973 |
+
size_t h_stride = C / H;
|
| 16974 |
+
size_t h_stride_2d = (C / H) * (C / H);
|
| 16975 |
+
|
| 16976 |
+
// basically fused operations:
|
| 16977 |
+
// dst = r @ (time_faaaa * (k @ v) + state),
|
| 16978 |
+
// state = time_decay * state + (k @ v),
|
| 16979 |
+
// recursive through each token
|
| 16980 |
+
for (size_t t = 0; t < T; t++) {
|
| 16981 |
+
size_t t_offset = t * t_stride;
|
| 16982 |
+
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
|
| 16983 |
+
float * state_cur = state + state_offset;
|
| 16984 |
+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
| 16985 |
+
|
| 16986 |
+
for (size_t h = 0; h < H; h++) {
|
| 16987 |
+
size_t h_offset = h * h_stride;
|
| 16988 |
+
size_t t_h_offset = t_offset + h_offset;
|
| 16989 |
+
size_t h_2d_offset = h * h_stride_2d;
|
| 16990 |
+
|
| 16991 |
+
for (size_t i = 0; i < C / H; i++) {
|
| 16992 |
+
size_t t_h_i_offset = t_h_offset + i;
|
| 16993 |
+
size_t h_i_offset = h_offset + i;
|
| 16994 |
+
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
| 16995 |
+
|
| 16996 |
+
float k_val = k[t_h_i_offset];
|
| 16997 |
+
float r_val = r[t_h_i_offset];
|
| 16998 |
+
float time_faaaa_val = time_faaaa[h_i_offset];
|
| 16999 |
+
// RWKV v6: different time_decay for each token.
|
| 17000 |
+
float time_decay_val = time_decay[t_h_i_offset];
|
| 17001 |
+
|
| 17002 |
+
for (size_t j = 0; j < C / H; j ++) {
|
| 17003 |
+
size_t t_h_j_offset = t_h_offset + j;
|
| 17004 |
+
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
| 17005 |
+
|
| 17006 |
+
float v_val = v[t_h_j_offset];
|
| 17007 |
+
float kv_val = v_val * k_val;
|
| 17008 |
+
float prev_state_val = state_prev[h_2d_i_j_offset];
|
| 17009 |
+
float temp_val = kv_val * time_faaaa_val + prev_state_val;
|
| 17010 |
+
dst_data[t_h_j_offset] += temp_val * r_val;
|
| 17011 |
+
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
|
| 17012 |
+
}
|
| 17013 |
+
}
|
| 17014 |
+
}
|
| 17015 |
+
}
|
| 17016 |
+
}
|
| 17017 |
+
|
| 17018 |
+
static void ggml_compute_forward_rwkv_wkv(
|
| 17019 |
+
const struct ggml_compute_params * params,
|
| 17020 |
+
struct ggml_tensor * dst) {
|
| 17021 |
+
|
| 17022 |
+
const struct ggml_tensor * src0 = dst->src[0];
|
| 17023 |
+
|
| 17024 |
+
switch (src0->type) {
|
| 17025 |
+
case GGML_TYPE_F32:
|
| 17026 |
+
{
|
| 17027 |
+
ggml_compute_forward_rwkv_wkv_f32(params, dst);
|
| 17028 |
+
} break;
|
| 17029 |
+
default:
|
| 17030 |
+
{
|
| 17031 |
+
GGML_ABORT("fatal error");
|
| 17032 |
+
}
|
| 17033 |
+
}
|
| 17034 |
+
}
|
| 17035 |
+
|
| 17036 |
// ggml_compute_forward_map_unary
|
| 17037 |
|
| 17038 |
static void ggml_compute_forward_map_unary_f32(
|
|
|
|
| 17684 |
{
|
| 17685 |
ggml_compute_forward_add_rel_pos(params, tensor);
|
| 17686 |
} break;
|
| 17687 |
+
case GGML_OP_RWKV_WKV:
|
| 17688 |
+
{
|
| 17689 |
+
ggml_compute_forward_rwkv_wkv(params, tensor);
|
| 17690 |
+
} break;
|
| 17691 |
case GGML_OP_MAP_UNARY:
|
| 17692 |
{
|
| 17693 |
ggml_unary_op_f32_t fun;
|
|
|
|
| 18801 |
zero_table);
|
| 18802 |
}
|
| 18803 |
} break;
|
| 18804 |
+
case GGML_UNARY_OP_EXP:
|
| 18805 |
+
{
|
| 18806 |
+
if (src0->grad) {
|
| 18807 |
+
src0->grad = ggml_add_or_set(ctx,
|
| 18808 |
+
src0->grad,
|
| 18809 |
+
ggml_mul(ctx, tensor, tensor->grad),
|
| 18810 |
+
zero_table);
|
| 18811 |
+
}
|
| 18812 |
+
} break;
|
| 18813 |
default:
|
| 18814 |
GGML_ABORT("fatal error");
|
| 18815 |
}
|
| 18816 |
} break;
|
| 18817 |
case GGML_OP_GET_REL_POS:
|
| 18818 |
case GGML_OP_ADD_REL_POS:
|
| 18819 |
+
case GGML_OP_RWKV_WKV:
|
| 18820 |
case GGML_OP_MAP_UNARY:
|
| 18821 |
case GGML_OP_MAP_BINARY:
|
| 18822 |
case GGML_OP_MAP_CUSTOM1_F32:
|
|
|
|
| 19241 |
case GGML_UNARY_OP_SIGMOID:
|
| 19242 |
case GGML_UNARY_OP_HARDSWISH:
|
| 19243 |
case GGML_UNARY_OP_HARDSIGMOID:
|
| 19244 |
+
case GGML_UNARY_OP_EXP:
|
| 19245 |
{
|
| 19246 |
n_tasks = 1;
|
| 19247 |
} break;
|
|
|
|
| 19333 |
case GGML_OP_WIN_PART:
|
| 19334 |
case GGML_OP_WIN_UNPART:
|
| 19335 |
case GGML_OP_GET_REL_POS:
|
| 19336 |
+
case GGML_OP_RWKV_WKV:
|
| 19337 |
case GGML_OP_MAP_UNARY:
|
| 19338 |
case GGML_OP_MAP_BINARY:
|
| 19339 |
case GGML_OP_MAP_CUSTOM1_F32:
|