test_security.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import pytest
  2. from openai import OpenAI
  3. from utils import *
  4. server = ServerPreset.tinyllama2()
  5. TEST_API_KEY = "sk-this-is-the-secret-key"
  6. @pytest.fixture(autouse=True)
  7. def create_server():
  8. global server
  9. server = ServerPreset.tinyllama2()
  10. server.api_key = TEST_API_KEY
  11. @pytest.mark.parametrize("endpoint", ["/health", "/models"])
  12. def test_access_public_endpoint(endpoint: str):
  13. global server
  14. server.start()
  15. res = server.make_request("GET", endpoint)
  16. assert res.status_code == 200
  17. assert "error" not in res.body
  18. @pytest.mark.parametrize("api_key", [None, "invalid-key"])
  19. def test_incorrect_api_key(api_key: str):
  20. global server
  21. server.start()
  22. res = server.make_request("POST", "/completions", data={
  23. "prompt": "I believe the meaning of life is",
  24. }, headers={
  25. "Authorization": f"Bearer {api_key}" if api_key else None,
  26. })
  27. assert res.status_code == 401
  28. assert "error" in res.body
  29. assert res.body["error"]["type"] == "authentication_error"
  30. def test_correct_api_key():
  31. global server
  32. server.start()
  33. res = server.make_request("POST", "/completions", data={
  34. "prompt": "I believe the meaning of life is",
  35. }, headers={
  36. "Authorization": f"Bearer {TEST_API_KEY}",
  37. })
  38. assert res.status_code == 200
  39. assert "error" not in res.body
  40. assert "content" in res.body
  41. def test_openai_library_correct_api_key():
  42. global server
  43. server.start()
  44. client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
  45. res = client.chat.completions.create(
  46. model="gpt-3.5-turbo",
  47. messages=[
  48. {"role": "system", "content": "You are a chatbot."},
  49. {"role": "user", "content": "What is the meaning of life?"},
  50. ],
  51. )
  52. assert len(res.choices) == 1
  53. @pytest.mark.parametrize("origin,cors_header,cors_header_value", [
  54. ("localhost", "Access-Control-Allow-Origin", "localhost"),
  55. ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
  56. ("origin", "Access-Control-Allow-Credentials", "true"),
  57. ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
  58. ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
  59. ])
  60. def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
  61. global server
  62. server.start()
  63. res = server.make_request("OPTIONS", "/completions", headers={
  64. "Origin": origin,
  65. "Access-Control-Request-Method": "POST",
  66. "Access-Control-Request-Headers": "Authorization",
  67. })
  68. assert res.status_code == 200
  69. assert cors_header in res.headers
  70. assert res.headers[cors_header] == cors_header_value