test_infill.py

import pytest
from utils import *

server = ServerPreset.tinyllama_infill()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    # recreate a fresh server preset for this test module
    global server
    server = ServerPreset.tinyllama_infill()


def test_infill_without_input_extra():
    # plain infill request: only an input prefix and suffix, no extra context chunks
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "prompt": "Complete this",
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])


def test_infill_with_input_extra():
    # infill request with extra repository-level context passed via input_extra
    global server
    server.start()
    res = server.make_request("POST", "/infill", data={
        "prompt": "Complete this",
        "input_extra": [{
            "filename": "llama.h",
            "text": "LLAMA_API int32_t llama_n_threads();\n"
        }],
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
        "input_suffix": "}\n",
    })
    assert res.status_code == 200
    assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
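
For reference, the payload shape these tests exercise can also be sent straight to a running server over HTTP. Below is a minimal sketch using the requests library, mirroring the input_extra case above; the host, port, and anything beyond the "content" field of the response are assumptions, not part of the test suite.

import requests

# Sketch only: issue the same /infill request as test_infill_with_input_extra,
# assuming a server is already listening on localhost:8080 (URL is an assumption).
payload = {
    "prompt": "Complete this",
    "input_extra": [{
        "filename": "llama.h",
        "text": "LLAMA_API int32_t llama_n_threads();\n",
    }],
    "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
    "input_suffix": "}\n",
}

resp = requests.post("http://localhost:8080/infill", json=payload)
resp.raise_for_status()
print(resp.json()["content"])  # the generated text that fills the gap between prefix and suffix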