@@ -1142,20 +1142,22 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    bool ff;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
-        ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
         return out;
     }
 
@@ -1169,7 +1171,12 @@ struct test_rope : public test_case {
             }
             ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
         } else {
-            init_tensor_uniform(t);
+            if (t->ne[0] == n_dims/2) {
+                // frequency factors in the range [0.9f, 1.1f]
+                init_tensor_uniform(t, 0.9f, 1.1f);
+            } else {
+                init_tensor_uniform(t);
+            }
         }
     }
 }
@@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512)); // llama 65B
-        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512)); // neox (stablelm)
-        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512)); // neox (phi-2)
+        // TODO: ff not supported yet for !neox
+        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
+
+        for (bool ff : {false, true}) { // freq_factors
+            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+        }
     }
 
     test_cases.emplace_back(new test_concat(GGML_TYPE_F32));