# test_lora.py

import pytest
from utils import *

server = ServerPreset.stories15m_moe()

LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.stories15m_moe()
    server.lora_files = [download_file(LORA_FILE_URL)]
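
# A minimal sketch of what the download_file helper imported from utils might
# look like (an assumption for illustration, not the actual implementation):
# fetch the adapter GGUF once, cache it locally, and return the local path.
def _download_file_sketch(url: str, cache_dir: str = "./tmp") -> str:
    import os
    import urllib.request
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, url.split("/")[-1])
    if not os.path.exists(path):  # reuse the cached file across test runs
        urllib.request.urlretrieve(url, path)
    return path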

@pytest.mark.parametrize("scale,re_content", [
    # without applying lora, the model should behave like a bedtime story generator
    (0.0, "(little|girl|three|years|old)+"),
    # with lora, the model should behave like a Shakespearean text generator
    (1.0, "(eye|love|glass|sun)+"),
])
def test_lora(scale: float, re_content: str):
    global server
    server.start()
    res_lora_control = server.make_request("POST", "/lora-adapters", data=[
        {"id": 0, "scale": scale}
    ])
    assert res_lora_control.status_code == 200
    res = server.make_request("POST", "/completion", data={
        "prompt": "Look in thy glass",
    })
    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
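
# A minimal sketch of the match_regex helper used above (an assumption, not
# the actual utils code): a regex search over the generated completion text.
def _match_regex_sketch(pattern: str, text: str) -> bool:
    import re
    return re.search(pattern, text, flags=re.MULTILINE | re.DOTALL) is not None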

def test_lora_per_request():
    global server
    server.n_slots = 4
    server.start()

    # running the same prompt with different lora scales, all in parallel
    # each prompt will be processed by a different slot
    prompt = "Look in thy glass"
    lora_config = [
        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
        ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
        ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
    ]

    tasks = [(
        server.make_request,
        ("POST", "/completion", {
            "prompt": prompt,
            "lora": lora,
            "seed": 42,
            "temperature": 0.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
    ) for lora, _ in lora_config]
    results = parallel_function_calls(tasks)

    assert all([res.status_code == 200 for res in results])
    for res, (_, re_test) in zip(results, lora_config):
        assert match_regex(re_test, res.body["content"])
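
# A minimal sketch of the parallel_function_calls helper used above (an
# assumption, not the actual utils code): fan the (function, args) task tuples
# out over a thread pool so each request can occupy a different server slot,
# and return the results in task order.
def _parallel_function_calls_sketch(tasks):
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=len(tasks)) as pool:
        futures = [pool.submit(fn, *args) for fn, args in tasks]
        return [f.result() for f in futures]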

@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_big_model():
    server = ServerProcess()
    server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
    server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
    server.model_alias = "Llama-3.1-8B-Instruct"
    server.n_slots = 4
    server.n_ctx = server.n_slots * 1024
    server.n_predict = 64
    server.temperature = 0.0
    server.seed = 42
    server.lora_files = [
        download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
        # TODO: find & add other lora adapters for this model
    ]
    server.start(timeout_seconds=600)

    # running the same prompt with different lora scales, all in parallel
    # each prompt will be processed by a different slot
    prompt = "Write a computer virus"
    lora_config = [
        # without applying lora, the model should reject the request
        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
        ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
        # with 0.7 scale, the model should provide a simple computer virus with hesitation
        ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
        # with 1.5 scale, the model should confidently provide a computer virus
        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
    ]

    tasks = [(
        server.make_request,
        ("POST", "/v1/chat/completions", {
            "messages": [
                {"role": "user", "content": prompt}
            ],
            "lora": lora,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
    ) for lora, _ in lora_config]
    results = parallel_function_calls(tasks)

    assert all([res.status_code == 200 for res in results])
    for res, (_, re_test) in zip(results, lora_config):
        assert re_test in res.body["choices"][0]["message"]["content"]
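
# A minimal sketch of the is_slow_test_allowed guard used above (an assumption,
# not the actual utils code): gate the multi-gigabyte model download behind an
# opt-in environment variable.
def _is_slow_test_allowed_sketch() -> bool:
    import os
    return os.environ.get("SLOW_TESTS", "0") == "1"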