test_ctx_shift.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
# Integration tests for the server's context-shift behaviour
# (what happens when prompt + generation exceed a slot's context window).
import pytest
from utils import *

# Shared server handle; re-created for every test by the autouse fixture below.
server = ServerPreset.tinyllama2()

# A long prompt used to overflow a single slot's context window
# (per the notes in the tests below it tokenizes to ~301 tokens — model-dependent).
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
  10. @pytest.fixture(autouse=True)
  11. def create_server():
  12. global server
  13. server = ServerPreset.tinyllama2()
  14. server.n_ctx = 256
  15. server.n_slots = 2
  16. def test_ctx_shift_enabled():
  17. # the prompt is 301 tokens
  18. # the slot context is 256/2 = 128 tokens
  19. # the prompt is truncated to keep the last 109 tokens
  20. # 64 tokens are generated thanks to shifting the context when it gets full
  21. global server
  22. server.enable_ctx_shift = True
  23. server.start()
  24. res = server.make_request("POST", "/completion", data={
  25. "n_predict": 64,
  26. "prompt": LONG_TEXT,
  27. })
  28. assert res.status_code == 200
  29. assert res.body["timings"]["prompt_n"] == 109
  30. assert res.body["timings"]["predicted_n"] == 64
  31. assert res.body["truncated"] is True
  32. @pytest.mark.parametrize("n_predict,n_token_output,truncated", [
  33. (64, 64, False),
  34. (-1, 120, True),
  35. ])
  36. def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
  37. global server
  38. server.n_predict = -1
  39. server.start()
  40. res = server.make_request("POST", "/completion", data={
  41. "n_predict": n_predict,
  42. "prompt": "Hi how are you",
  43. })
  44. assert res.status_code == 200
  45. assert res.body["timings"]["predicted_n"] == n_token_output
  46. assert res.body["truncated"] == truncated
  47. def test_ctx_shift_disabled_long_prompt():
  48. global server
  49. server.start()
  50. res = server.make_request("POST", "/completion", data={
  51. "n_predict": 64,
  52. "prompt": LONG_TEXT,
  53. })
  54. assert res.status_code != 200
  55. assert "error" in res.body
  56. assert "exceeds the available context size" in res.body["error"]["message"]
  57. def test_ctx_shift_disabled_stream():
  58. global server
  59. server.start()
  60. res = server.make_stream_request("POST", "/v1/completions", data={
  61. "n_predict": 256,
  62. "prompt": "Once",
  63. "stream": True,
  64. })
  65. content = ""
  66. for data in res:
  67. choice = data["choices"][0]
  68. if choice["finish_reason"] == "length":
  69. assert len(content) > 0
  70. else:
  71. assert choice["finish_reason"] is None
  72. content += choice["text"]