1
0

test_security.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import pytest
  2. from openai import OpenAI
  3. from utils import *
  # Secret the server under test is configured with; clients must present it.
  TEST_API_KEY = "sk-this-is-the-secret-key"
  # Module-level placeholder; the autouse `create_server` fixture rebinds it to a
  # fresh, key-protected preset before every test.
  server = ServerPreset.tinyllama2()
  @pytest.fixture(autouse=True)
  def create_server():
      """Give every test its own tinyllama2 server preset protected by ``TEST_API_KEY``.

      Rebinds the module-level ``server`` so each test works on a fresh instance.
      The fixture does not start the server; each test calls ``server.start()``.
      """
      global server
      preset = ServerPreset.tinyllama2()
      preset.api_key = TEST_API_KEY
      server = preset
  @pytest.mark.parametrize("endpoint", ["/health", "/models"])
  def test_access_public_endpoint(endpoint: str):
      """A plain GET on a public endpoint succeeds even though no credentials are sent."""
      server.start()
      response = server.make_request("GET", endpoint)
      assert response.status_code == 200
      assert "error" not in response.body
  @pytest.mark.parametrize("api_key", [None, "invalid-key"])
  def test_incorrect_api_key(api_key: "str | None"):
      """A completion request with a missing or wrong API key is rejected with 401.

      Args:
          api_key: The bearer token to send. ``None`` means "no credentials".
              In that case the ``Authorization`` header is omitted entirely,
              rather than sent with a ``None`` value in the hope that the HTTP
              client drops it.
      """
      server.start()
      # Build the headers explicitly so the "no key" case never carries a
      # None-valued header through to the HTTP layer.
      headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
      res = server.make_request("POST", "/completions", data={
          "prompt": "I believe the meaning of life is",
      }, headers=headers)
      assert res.status_code == 401
      assert "error" in res.body
      assert res.body["error"]["type"] == "authentication_error"
  def test_correct_api_key():
      """A completion request with the configured key as a Bearer token succeeds."""
      server.start()
      payload = {"prompt": "I believe the meaning of life is"}
      auth_headers = {"Authorization": f"Bearer {TEST_API_KEY}"}
      res = server.make_request("POST", "/completions", data=payload, headers=auth_headers)
      assert res.status_code == 200
      assert "error" not in res.body
      assert "content" in res.body
  def test_correct_api_key_anthropic_header():
      """A completion request with the configured key in ``X-Api-Key`` succeeds.

      This covers the alternative (Anthropic-style, per the test name) key header
      instead of ``Authorization: Bearer``.
      """
      server.start()
      payload = {"prompt": "I believe the meaning of life is"}
      key_headers = {"X-Api-Key": TEST_API_KEY}
      res = server.make_request("POST", "/completions", data=payload, headers=key_headers)
      assert res.status_code == 200
      assert "error" not in res.body
      assert "content" in res.body
  def test_openai_library_correct_api_key():
      """The official OpenAI client, given the configured key, gets one chat choice back."""
      server.start()
      base_url = f"http://{server.server_host}:{server.server_port}"
      client = OpenAI(api_key=TEST_API_KEY, base_url=base_url)
      conversation = [
          {"role": "system", "content": "You are a chatbot."},
          {"role": "user", "content": "What is the meaning of life?"},
      ]
      res = client.chat.completions.create(model="gpt-3.5-turbo", messages=conversation)
      assert len(res.choices) == 1
  @pytest.mark.parametrize(("origin", "cors_header", "cors_header_value"), [
      ("localhost", "Access-Control-Allow-Origin", "localhost"),
      ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
      ("origin", "Access-Control-Allow-Credentials", "true"),
      ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
      ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
  ])
  def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
      """A CORS preflight (OPTIONS) on /completions returns the expected header value.

      The first two cases expect ``Access-Control-Allow-Origin`` to equal the
      request's ``Origin``; the rest pin fixed values for the other CORS headers.
      """
      server.start()
      preflight_headers = {
          "Origin": origin,
          "Access-Control-Request-Method": "POST",
          "Access-Control-Request-Headers": "Authorization",
      }
      res = server.make_request("OPTIONS", "/completions", headers=preflight_headers)
      assert res.status_code == 200
      assert cors_header in res.headers
      assert res.headers[cors_header] == cors_header_value