1
0

json_schema_pydantic_example.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Usage:
  2. #! ./llama-server -m some-model.gguf &
  3. #! pip install pydantic
  4. #! python json_schema_pydantic_example.py
  5. from pydantic import BaseModel, Field, TypeAdapter
  6. from annotated_types import MinLen
  7. from typing import Annotated, List, Optional
  8. import json, requests
  9. if True:
  10. def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
  11. '''
  12. Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
  13. (llama.cpp server, llama-cpp-python, Anyscale / Together...)
  14. The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
  15. '''
  16. response_format = None
  17. type_adapter = None
  18. if response_model:
  19. type_adapter = TypeAdapter(response_model)
  20. schema = type_adapter.json_schema()
  21. messages = [{
  22. "role": "system",
  23. "content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
  24. }] + messages
  25. response_format={"type": "json_object", "schema": schema}
  26. data = requests.post(endpoint, headers={"Content-Type": "application/json"},
  27. json=dict(messages=messages, response_format=response_format, **kwargs)).json()
  28. if 'error' in data:
  29. raise Exception(data['error']['message'])
  30. content = data["choices"][0]["message"]["content"]
  31. return type_adapter.validate_json(content) if type_adapter else content
  32. else:
  33. # This alternative branch uses Instructor + OpenAI client lib.
  34. # Instructor support streamed iterable responses, retry & more.
  35. # (see https://python.useinstructor.com/)
  36. #! pip install instructor openai
  37. import instructor, openai
  38. client = instructor.patch(
  39. openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
  40. mode=instructor.Mode.JSON_SCHEMA)
  41. create_completion = client.chat.completions.create
  42. if __name__ == '__main__':
  43. class QAPair(BaseModel):
  44. class Config:
  45. extra = 'forbid' # triggers additionalProperties: false in the JSON schema
  46. question: str
  47. concise_answer: str
  48. justification: str
  49. stars: Annotated[int, Field(ge=1, le=5)]
  50. class PyramidalSummary(BaseModel):
  51. class Config:
  52. extra = 'forbid' # triggers additionalProperties: false in the JSON schema
  53. title: str
  54. summary: str
  55. question_answers: Annotated[List[QAPair], MinLen(2)]
  56. sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
  57. print("# Summary\n", create_completion(
  58. model="...",
  59. response_model=PyramidalSummary,
  60. messages=[{
  61. "role": "user",
  62. "content": f"""
  63. You are a highly efficient corporate document summarizer.
  64. Create a pyramidal summary of an imaginary internal document about our company processes
  65. (starting high-level, going down to each sub sections).
  66. Keep questions short, and answers even shorter (trivia / quizz style).
  67. """
  68. }]))