# test_infill.py
  1. import pytest
  2. from utils import *
  3. server = ServerPreset.tinyllama_infill()
  4. @pytest.fixture(scope="module", autouse=True)
  5. def create_server():
  6. global server
  7. server = ServerPreset.tinyllama_infill()
  8. def test_infill_without_input_extra():
  9. global server
  10. server.start()
  11. res = server.make_request("POST", "/infill", data={
  12. "prompt": "Complete this",
  13. "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
  14. "input_suffix": "}\n",
  15. })
  16. assert res.status_code == 200
  17. assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
  18. def test_infill_with_input_extra():
  19. global server
  20. server.start()
  21. res = server.make_request("POST", "/infill", data={
  22. "prompt": "Complete this",
  23. "input_extra": [{
  24. "filename": "llama.h",
  25. "text": "LLAMA_API int32_t llama_n_threads();\n"
  26. }],
  27. "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
  28. "input_suffix": "}\n",
  29. })
  30. assert res.status_code == 200
  31. assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
  32. @pytest.mark.parametrize("input_extra", [
  33. {},
  34. {"filename": "ok"},
  35. {"filename": 123},
  36. {"filename": 123, "text": "abc"},
  37. {"filename": 123, "text": 456},
  38. ])
  39. def test_invalid_input_extra_req(input_extra):
  40. global server
  41. server.start()
  42. res = server.make_request("POST", "/infill", data={
  43. "prompt": "Complete this",
  44. "input_extra": [input_extra],
  45. "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
  46. "input_suffix": "}\n",
  47. })
  48. assert res.status_code == 400
  49. assert "error" in res.body