|
|
@@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+// GGML_OP_RWKV_WKV
|
|
|
+struct test_rwkv_wkv : public test_case {
|
|
|
+ const ggml_type type;
|
|
|
+
|
|
|
+ const int64_t head_count;
|
|
|
+ const int64_t head_size;
|
|
|
+ const int64_t n_seq_tokens;
|
|
|
+ const int64_t n_seqs;
|
|
|
+
|
|
|
+ std::string vars() override {
|
|
|
+ return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
|
|
+ }
|
|
|
+
|
|
|
+ test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
|
|
|
+ int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
|
|
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
|
|
+
|
|
|
+ ggml_tensor * build_graph(ggml_context * ctx) override {
|
|
|
+ const int64_t n_tokens = n_seq_tokens * n_seqs;
|
|
|
+ ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
|
|
+ ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
|
|
|
+ ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
|
|
+ ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
|
|
|
+ ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
|
|
+ ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
|
|
+ ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
|
|
|
+ return out;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
// GGML_OP_MUL_MAT
|
|
|
struct test_mul_mat : public test_case {
|
|
|
const ggml_type type_a;
|
|
|
@@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|
|
|
|
|
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
|
|
|
|
|
+ test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
|
|
|
+ test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
|
|
|
+ test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
|
|
|
+ test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
|
|
|
+
|
|
|
#if 1
|
|
|
for (ggml_type type_a : base_types) {
|
|
|
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|