test_infill.py

import pytest
from utils import *

server = ServerPreset.tinyllama_infill()


# start a fresh tinyllama infill server once for the whole test module
@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama_infill()

# basic infill request built only from prefix/prompt/suffix, with no extra context
def test_infill_without_input_extra():
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": " int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    # the tiny test model is not a code model, so only check for tokens it is known to emit
    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])


# infill request that also passes additional file context via "input_extra"
def test_infill_with_input_extra():
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": " int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert match_regex("(Dad|excited|park)+", res.body["content"])


# malformed "input_extra" entries must be rejected with HTTP 400
@pytest.mark.parametrize("input_extra", [
    {},
    {"filename": "ok"},
    {"filename": 123},
    {"filename": 123, "text": "abc"},
    {"filename": 123, "text": 456},
])
def test_invalid_input_extra_req(input_extra):
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "input_extra": [input_extra],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": " int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 400
    assert "error" in res.body


# end-to-end check against a real coder model; only runs when slow tests are enabled
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_qwen_model():
    global server
    server.model_file = None
    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
    server.start(timeout_seconds=600)
    res = server.make_request("POST", "/infill", data={
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
        "prompt": " int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"