|
@@ -235,6 +235,8 @@ int main(void) {
|
|
|
|
|
|
|
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
|
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
|
|
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
|
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
|
|
|
|
|
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
|
|
|
|
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
|
|
|
|
|
|
|
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
|
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
|
|
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
|
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
|