pydantic_models_to_grammar.py 55 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312
  1. import inspect
  2. import json
  3. from copy import copy
  4. from inspect import isclass, getdoc
  5. from types import NoneType
  6. from docstring_parser import parse
  7. from pydantic import BaseModel, create_model, Field
  8. from typing import Any, Type, List, get_args, get_origin, Tuple, Union, Optional, _GenericAlias
  9. from enum import Enum
  10. from typing import get_type_hints, Callable
  11. import re
  12. class PydanticDataType(Enum):
  13. """
  14. Defines the data types supported by the grammar_generator.
  15. Attributes:
  16. STRING (str): Represents a string data type.
  17. BOOLEAN (str): Represents a boolean data type.
  18. INTEGER (str): Represents an integer data type.
  19. FLOAT (str): Represents a float data type.
  20. OBJECT (str): Represents an object data type.
  21. ARRAY (str): Represents an array data type.
  22. ENUM (str): Represents an enum data type.
  23. CUSTOM_CLASS (str): Represents a custom class data type.
  24. """
  25. STRING = "string"
  26. TRIPLE_QUOTED_STRING = "triple_quoted_string"
  27. MARKDOWN_CODE_BLOCK = "markdown_code_block"
  28. BOOLEAN = "boolean"
  29. INTEGER = "integer"
  30. FLOAT = "float"
  31. OBJECT = "object"
  32. ARRAY = "array"
  33. ENUM = "enum"
  34. ANY = "any"
  35. NULL = "null"
  36. CUSTOM_CLASS = "custom-class"
  37. CUSTOM_DICT = "custom-dict"
  38. SET = "set"
  39. def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
  40. if isclass(pydantic_type) and issubclass(pydantic_type, str):
  41. return PydanticDataType.STRING.value
  42. elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
  43. return PydanticDataType.BOOLEAN.value
  44. elif isclass(pydantic_type) and issubclass(pydantic_type, int):
  45. return PydanticDataType.INTEGER.value
  46. elif isclass(pydantic_type) and issubclass(pydantic_type, float):
  47. return PydanticDataType.FLOAT.value
  48. elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
  49. return PydanticDataType.ENUM.value
  50. elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
  51. return format_model_and_field_name(pydantic_type.__name__)
  52. elif get_origin(pydantic_type) == list:
  53. element_type = get_args(pydantic_type)[0]
  54. return f"{map_pydantic_type_to_gbnf(element_type)}-list"
  55. elif get_origin(pydantic_type) == set:
  56. element_type = get_args(pydantic_type)[0]
  57. return f"{map_pydantic_type_to_gbnf(element_type)}-set"
  58. elif get_origin(pydantic_type) == Union:
  59. union_types = get_args(pydantic_type)
  60. union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
  61. return f"union-{'-or-'.join(union_rules)}"
  62. elif get_origin(pydantic_type) == Optional:
  63. element_type = get_args(pydantic_type)[0]
  64. return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
  65. elif isclass(pydantic_type):
  66. return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
  67. elif get_origin(pydantic_type) == dict:
  68. key_type, value_type = get_args(pydantic_type)
  69. return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
  70. else:
  71. return "unknown"
  72. def format_model_and_field_name(model_name: str) -> str:
  73. parts = re.findall("[A-Z][^A-Z]*", model_name)
  74. if not parts: # Check if the list is empty
  75. return model_name.lower().replace("_", "-")
  76. return "-".join(part.lower().replace("_", "-") for part in parts)
  77. def generate_list_rule(element_type):
  78. """
  79. Generate a GBNF rule for a list of a given element type.
  80. :param element_type: The type of the elements in the list (e.g., 'string').
  81. :return: A string representing the GBNF rule for a list of the given type.
  82. """
  83. rule_name = f"{map_pydantic_type_to_gbnf(element_type)}-list"
  84. element_rule = map_pydantic_type_to_gbnf(element_type)
  85. list_rule = rf'{rule_name} ::= "[" {element_rule} ("," {element_rule})* "]"'
  86. return list_rule
  87. def get_members_structure(cls, rule_name):
  88. if issubclass(cls, Enum):
  89. # Handle Enum types
  90. members = [f'"\\"{member.value}\\""' for name, member in cls.__members__.items()]
  91. return f"{cls.__name__.lower()} ::= " + " | ".join(members)
  92. if cls.__annotations__ and cls.__annotations__ != {}:
  93. result = f'{rule_name} ::= "{{"'
  94. type_list_rules = []
  95. # Modify this comprehension
  96. members = [
  97. f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
  98. for name, param_type in cls.__annotations__.items()
  99. if name != "self"
  100. ]
  101. result += '"," '.join(members)
  102. result += ' "}"'
  103. return result, type_list_rules
  104. elif rule_name == "custom-class-any":
  105. result = f"{rule_name} ::= "
  106. result += "value"
  107. type_list_rules = []
  108. return result, type_list_rules
  109. else:
  110. init_signature = inspect.signature(cls.__init__)
  111. parameters = init_signature.parameters
  112. result = f'{rule_name} ::= "{{"'
  113. type_list_rules = []
  114. # Modify this comprehension too
  115. members = [
  116. f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param.annotation)}'
  117. for name, param in parameters.items()
  118. if name != "self" and param.annotation != inspect.Parameter.empty
  119. ]
  120. result += '", "'.join(members)
  121. result += ' "}"'
  122. return result, type_list_rules
  123. def regex_to_gbnf(regex_pattern: str) -> str:
  124. """
  125. Translate a basic regex pattern to a GBNF rule.
  126. Note: This function handles only a subset of simple regex patterns.
  127. """
  128. gbnf_rule = regex_pattern
  129. # Translate common regex components to GBNF
  130. gbnf_rule = gbnf_rule.replace("\\d", "[0-9]")
  131. gbnf_rule = gbnf_rule.replace("\\s", "[ \t\n]")
  132. # Handle quantifiers and other regex syntax that is similar in GBNF
  133. # (e.g., '*', '+', '?', character classes)
  134. return gbnf_rule
  135. def generate_gbnf_integer_rules(max_digit=None, min_digit=None):
  136. """
  137. Generate GBNF Integer Rules
  138. Generates GBNF (Generalized Backus-Naur Form) rules for integers based on the given maximum and minimum digits.
  139. Parameters:
  140. max_digit (int): The maximum number of digits for the integer. Default is None.
  141. min_digit (int): The minimum number of digits for the integer. Default is None.
  142. Returns:
  143. integer_rule (str): The identifier for the integer rule generated.
  144. additional_rules (list): A list of additional rules generated based on the given maximum and minimum digits.
  145. """
  146. additional_rules = []
  147. # Define the rule identifier based on max_digit and min_digit
  148. integer_rule = "integer-part"
  149. if max_digit is not None:
  150. integer_rule += f"-max{max_digit}"
  151. if min_digit is not None:
  152. integer_rule += f"-min{min_digit}"
  153. # Handling Integer Rules
  154. if max_digit is not None or min_digit is not None:
  155. # Start with an empty rule part
  156. integer_rule_part = ""
  157. # Add mandatory digits as per min_digit
  158. if min_digit is not None:
  159. integer_rule_part += "[0-9] " * min_digit
  160. # Add optional digits up to max_digit
  161. if max_digit is not None:
  162. optional_digits = max_digit - (min_digit if min_digit is not None else 0)
  163. integer_rule_part += "".join(["[0-9]? " for _ in range(optional_digits)])
  164. # Trim the rule part and append it to additional rules
  165. integer_rule_part = integer_rule_part.strip()
  166. if integer_rule_part:
  167. additional_rules.append(f"{integer_rule} ::= {integer_rule_part}")
  168. return integer_rule, additional_rules
  169. def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None, min_precision=None):
  170. """
  171. Generate GBNF float rules based on the given constraints.
  172. :param max_digit: Maximum number of digits in the integer part (default: None)
  173. :param min_digit: Minimum number of digits in the integer part (default: None)
  174. :param max_precision: Maximum number of digits in the fractional part (default: None)
  175. :param min_precision: Minimum number of digits in the fractional part (default: None)
  176. :return: A tuple containing the float rule and additional rules as a list
  177. Example Usage:
  178. max_digit = 3
  179. min_digit = 1
  180. max_precision = 2
  181. min_precision = 1
  182. generate_gbnf_float_rules(max_digit, min_digit, max_precision, min_precision)
  183. Output:
  184. ('float-3-1-2-1', ['integer-part-max3-min1 ::= [0-9] [0-9] [0-9]?', 'fractional-part-max2-min1 ::= [0-9] [0-9]?', 'float-3-1-2-1 ::= integer-part-max3-min1 "." fractional-part-max2-min
  185. *1'])
  186. Note:
  187. GBNF stands for Generalized Backus-Naur Form, which is a notation technique to specify the syntax of programming languages or other formal grammars.
  188. """
  189. additional_rules = []
  190. # Define the integer part rule
  191. integer_part_rule = (
  192. "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
  193. f"-min{min_digit}" if min_digit is not None else "")
  194. )
  195. # Define the fractional part rule based on precision constraints
  196. fractional_part_rule = "fractional-part"
  197. fractional_rule_part = ""
  198. if max_precision is not None or min_precision is not None:
  199. fractional_part_rule += (f"-max{max_precision}" if max_precision is not None else "") + (
  200. f"-min{min_precision}" if min_precision is not None else ""
  201. )
  202. # Minimum number of digits
  203. fractional_rule_part = "[0-9]" * (min_precision if min_precision is not None else 1)
  204. # Optional additional digits
  205. fractional_rule_part += "".join(
  206. [" [0-9]?"] * ((max_precision - (
  207. min_precision if min_precision is not None else 1)) if max_precision is not None else 0)
  208. )
  209. additional_rules.append(f"{fractional_part_rule} ::= {fractional_rule_part}")
  210. # Define the float rule
  211. float_rule = f"float-{max_digit if max_digit is not None else 'X'}-{min_digit if min_digit is not None else 'X'}-{max_precision if max_precision is not None else 'X'}-{min_precision if min_precision is not None else 'X'}"
  212. additional_rules.append(f'{float_rule} ::= {integer_part_rule} "." {fractional_part_rule}')
  213. # Generating the integer part rule definition, if necessary
  214. if max_digit is not None or min_digit is not None:
  215. integer_rule_part = "[0-9]"
  216. if min_digit is not None and min_digit > 1:
  217. integer_rule_part += " [0-9]" * (min_digit - 1)
  218. if max_digit is not None:
  219. integer_rule_part += "".join([" [0-9]?"] * (max_digit - (min_digit if min_digit is not None else 1)))
  220. additional_rules.append(f"{integer_part_rule} ::= {integer_rule_part.strip()}")
  221. return float_rule, additional_rules
  222. def generate_gbnf_rule_for_type(
  223. model_name, field_name, field_type, is_optional, processed_models, created_rules, field_info=None
  224. ) -> Tuple[str, list]:
  225. """
  226. Generate GBNF rule for a given field type.
  227. :param model_name: Name of the model.
  228. :param field_name: Name of the field.
  229. :param field_type: Type of the field.
  230. :param is_optional: Whether the field is optional.
  231. :param processed_models: List of processed models.
  232. :param created_rules: List of created rules.
  233. :param field_info: Additional information about the field (optional).
  234. :return: Tuple containing the GBNF type and a list of additional rules.
  235. :rtype: Tuple[str, list]
  236. """
  237. rules = []
  238. field_name = format_model_and_field_name(field_name)
  239. gbnf_type = map_pydantic_type_to_gbnf(field_type)
  240. if isclass(field_type) and issubclass(field_type, BaseModel):
  241. nested_model_name = format_model_and_field_name(field_type.__name__)
  242. nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
  243. rules.extend(nested_model_rules)
  244. gbnf_type, rules = nested_model_name, rules
  245. elif isclass(field_type) and issubclass(field_type, Enum):
  246. enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
  247. enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
  248. rules.append(enum_rule)
  249. gbnf_type, rules = model_name + "-" + field_name, rules
  250. elif get_origin(field_type) == list: # Array
  251. element_type = get_args(field_type)[0]
  252. element_rule_name, additional_rules = generate_gbnf_rule_for_type(
  253. model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
  254. )
  255. rules.extend(additional_rules)
  256. array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """
  257. rules.append(array_rule)
  258. gbnf_type, rules = model_name + "-" + field_name, rules
  259. elif get_origin(field_type) == set or field_type == set: # Array
  260. element_type = get_args(field_type)[0]
  261. element_rule_name, additional_rules = generate_gbnf_rule_for_type(
  262. model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
  263. )
  264. rules.extend(additional_rules)
  265. array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """
  266. rules.append(array_rule)
  267. gbnf_type, rules = model_name + "-" + field_name, rules
  268. elif gbnf_type.startswith("custom-class-"):
  269. nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
  270. rules.append(nested_model_rules)
  271. elif gbnf_type.startswith("custom-dict-"):
  272. key_type, value_type = get_args(field_type)
  273. additional_key_type, additional_key_rules = generate_gbnf_rule_for_type(
  274. model_name, f"{field_name}-key-type", key_type, is_optional, processed_models, created_rules
  275. )
  276. additional_value_type, additional_value_rules = generate_gbnf_rule_for_type(
  277. model_name, f"{field_name}-value-type", value_type, is_optional, processed_models, created_rules
  278. )
  279. gbnf_type = rf'{gbnf_type} ::= "{{" ( {additional_key_type} ": " {additional_value_type} ("," "\n" ws {additional_key_type} ":" {additional_value_type})* )? "}}" '
  280. rules.extend(additional_key_rules)
  281. rules.extend(additional_value_rules)
  282. elif gbnf_type.startswith("union-"):
  283. union_types = get_args(field_type)
  284. union_rules = []
  285. for union_type in union_types:
  286. if isinstance(union_type, _GenericAlias):
  287. union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(
  288. model_name, field_name, union_type, False, processed_models, created_rules
  289. )
  290. union_rules.append(union_gbnf_type)
  291. rules.extend(union_rules_list)
  292. elif not issubclass(union_type, NoneType):
  293. union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(
  294. model_name, field_name, union_type, False, processed_models, created_rules
  295. )
  296. union_rules.append(union_gbnf_type)
  297. rules.extend(union_rules_list)
  298. # Defining the union grammar rule separately
  299. if len(union_rules) == 1:
  300. union_grammar_rule = f"{model_name}-{field_name}-optional ::= {' | '.join(union_rules)} | null"
  301. else:
  302. union_grammar_rule = f"{model_name}-{field_name}-union ::= {' | '.join(union_rules)}"
  303. rules.append(union_grammar_rule)
  304. if len(union_rules) == 1:
  305. gbnf_type = f"{model_name}-{field_name}-optional"
  306. else:
  307. gbnf_type = f"{model_name}-{field_name}-union"
  308. elif isclass(field_type) and issubclass(field_type, str):
  309. if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
  310. triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
  311. markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
  312. gbnf_type = PydanticDataType.TRIPLE_QUOTED_STRING.value if triple_quoted_string else PydanticDataType.STRING.value
  313. gbnf_type = PydanticDataType.MARKDOWN_CODE_BLOCK.value if markdown_string else gbnf_type
  314. elif field_info and hasattr(field_info, "pattern"):
  315. # Convert regex pattern to grammar rule
  316. regex_pattern = field_info.regex.pattern
  317. gbnf_type = f"pattern-{field_name} ::= {regex_to_gbnf(regex_pattern)}"
  318. else:
  319. gbnf_type = PydanticDataType.STRING.value
  320. elif (
  321. isclass(field_type)
  322. and issubclass(field_type, float)
  323. and field_info
  324. and hasattr(field_info, "json_schema_extra")
  325. and field_info.json_schema_extra is not None
  326. ):
  327. # Retrieve precision attributes for floats
  328. max_precision = (
  329. field_info.json_schema_extra.get("max_precision") if field_info and hasattr(field_info,
  330. "json_schema_extra") else None
  331. )
  332. min_precision = (
  333. field_info.json_schema_extra.get("min_precision") if field_info and hasattr(field_info,
  334. "json_schema_extra") else None
  335. )
  336. max_digits = field_info.json_schema_extra.get("max_digit") if field_info and hasattr(field_info,
  337. "json_schema_extra") else None
  338. min_digits = field_info.json_schema_extra.get("min_digit") if field_info and hasattr(field_info,
  339. "json_schema_extra") else None
  340. # Generate GBNF rule for float with given attributes
  341. gbnf_type, rules = generate_gbnf_float_rules(
  342. max_digit=max_digits, min_digit=min_digits, max_precision=max_precision, min_precision=min_precision
  343. )
  344. elif (
  345. isclass(field_type)
  346. and issubclass(field_type, int)
  347. and field_info
  348. and hasattr(field_info, "json_schema_extra")
  349. and field_info.json_schema_extra is not None
  350. ):
  351. # Retrieve digit attributes for integers
  352. max_digits = field_info.json_schema_extra.get("max_digit") if field_info and hasattr(field_info,
  353. "json_schema_extra") else None
  354. min_digits = field_info.json_schema_extra.get("min_digit") if field_info and hasattr(field_info,
  355. "json_schema_extra") else None
  356. # Generate GBNF rule for integer with given attributes
  357. gbnf_type, rules = generate_gbnf_integer_rules(max_digit=max_digits, min_digit=min_digits)
  358. else:
  359. gbnf_type, rules = gbnf_type, []
  360. if gbnf_type not in created_rules:
  361. return gbnf_type, rules
  362. else:
  363. if gbnf_type in created_rules:
  364. return gbnf_type, rules
  365. def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created_rules: dict) -> (list, bool, bool):
  366. """
  367. Generate GBnF Grammar
  368. Generates a GBnF grammar for a given model.
  369. :param model: A Pydantic model class to generate the grammar for. Must be a subclass of BaseModel.
  370. :param processed_models: A set of already processed models to prevent infinite recursion.
  371. :param created_rules: A dict containing already created rules to prevent duplicates.
  372. :return: A list of GBnF grammar rules in string format. And two booleans indicating if an extra markdown or triple quoted string is in the grammar.
  373. Example Usage:
  374. ```
  375. model = MyModel
  376. processed_models = set()
  377. created_rules = dict()
  378. gbnf_grammar = generate_gbnf_grammar(model, processed_models, created_rules)
  379. ```
  380. """
  381. if model in processed_models:
  382. return []
  383. processed_models.add(model)
  384. model_name = format_model_and_field_name(model.__name__)
  385. if not issubclass(model, BaseModel):
  386. # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
  387. if hasattr(model, "__annotations__") and model.__annotations__:
  388. model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
  389. else:
  390. init_signature = inspect.signature(model.__init__)
  391. parameters = init_signature.parameters
  392. model_fields = {name: (param.annotation, param.default) for name, param in parameters.items() if
  393. name != "self"}
  394. else:
  395. # For Pydantic models, use model_fields and check for ellipsis (required fields)
  396. model_fields = model.__annotations__
  397. model_rule_parts = []
  398. nested_rules = []
  399. has_markdown_code_block = False
  400. has_triple_quoted_string = False
  401. look_for_markdown_code_block = False
  402. look_for_triple_quoted_string = False
  403. for field_name, field_info in model_fields.items():
  404. if not issubclass(model, BaseModel):
  405. field_type, default_value = field_info
  406. # Check if the field is optional (not required)
  407. is_optional = (default_value is not inspect.Parameter.empty) and (default_value is not Ellipsis)
  408. else:
  409. field_type = field_info
  410. field_info = model.model_fields[field_name]
  411. is_optional = field_info.is_required is False and get_origin(field_type) is Optional
  412. rule_name, additional_rules = generate_gbnf_rule_for_type(
  413. model_name, format_model_and_field_name(field_name), field_type, is_optional, processed_models,
  414. created_rules, field_info
  415. )
  416. look_for_markdown_code_block = True if rule_name == "markdown_code_block" else False
  417. look_for_triple_quoted_string = True if rule_name == "triple_quoted_string" else False
  418. if not look_for_markdown_code_block and not look_for_triple_quoted_string:
  419. if rule_name not in created_rules:
  420. created_rules[rule_name] = additional_rules
  421. model_rule_parts.append(f' ws "\\"{field_name}\\"" ":" ws {rule_name}') # Adding escaped quotes
  422. nested_rules.extend(additional_rules)
  423. else:
  424. has_triple_quoted_string = look_for_triple_quoted_string
  425. has_markdown_code_block = look_for_markdown_code_block
  426. fields_joined = r' "," "\n" '.join(model_rule_parts)
  427. model_rule = rf'{model_name} ::= "{{" "\n" {fields_joined} "\n" ws "}}"'
  428. has_special_string = False
  429. if has_triple_quoted_string:
  430. model_rule += '"\\n" ws "}"'
  431. model_rule += '"\\n" triple-quoted-string'
  432. has_special_string = True
  433. if has_markdown_code_block:
  434. model_rule += '"\\n" ws "}"'
  435. model_rule += '"\\n" markdown-code-block'
  436. has_special_string = True
  437. all_rules = [model_rule] + nested_rules
  438. return all_rules, has_special_string
  439. def generate_gbnf_grammar_from_pydantic_models(
  440. models: List[Type[BaseModel]], outer_object_name: str = None, outer_object_content: str = None,
  441. list_of_outputs: bool = False
  442. ) -> str:
  443. """
  444. Generate GBNF Grammar from Pydantic Models.
  445. This method takes a list of Pydantic models and uses them to generate a GBNF grammar string. The generated grammar string can be used for parsing and validating data using the generated
  446. * grammar.
  447. Args:
  448. models (List[Type[BaseModel]]): A list of Pydantic models to generate the grammar from.
  449. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  450. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  451. list_of_outputs (str, optional): Allows a list of output objects
  452. Returns:
  453. str: The generated GBNF grammar string.
  454. Examples:
  455. models = [UserModel, PostModel]
  456. grammar = generate_gbnf_grammar_from_pydantic(models)
  457. print(grammar)
  458. # Output:
  459. # root ::= UserModel | PostModel
  460. # ...
  461. """
  462. processed_models = set()
  463. all_rules = []
  464. created_rules = {}
  465. if outer_object_name is None:
  466. for model in models:
  467. model_rules, _ = generate_gbnf_grammar(model, processed_models, created_rules)
  468. all_rules.extend(model_rules)
  469. if list_of_outputs:
  470. root_rule = r'root ::= (" "| "\n") "[" ws grammar-models ("," ws grammar-models)* ws "]"' + "\n"
  471. else:
  472. root_rule = r'root ::= (" "| "\n") grammar-models' + "\n"
  473. root_rule += "grammar-models ::= " + " | ".join(
  474. [format_model_and_field_name(model.__name__) for model in models])
  475. all_rules.insert(0, root_rule)
  476. return "\n".join(all_rules)
  477. elif outer_object_name is not None:
  478. if list_of_outputs:
  479. root_rule = (
  480. rf'root ::= (" "| "\n") "[" ws {format_model_and_field_name(outer_object_name)} ("," ws {format_model_and_field_name(outer_object_name)})* ws "]"'
  481. + "\n"
  482. )
  483. else:
  484. root_rule = f"root ::= {format_model_and_field_name(outer_object_name)}\n"
  485. model_rule = (
  486. rf'{format_model_and_field_name(outer_object_name)} ::= (" "| "\n") "{{" ws "\"{outer_object_name}\"" ":" ws grammar-models'
  487. )
  488. fields_joined = " | ".join(
  489. [rf"{format_model_and_field_name(model.__name__)}-grammar-model" for model in models])
  490. grammar_model_rules = f"\ngrammar-models ::= {fields_joined}"
  491. mod_rules = []
  492. for model in models:
  493. mod_rule = rf"{format_model_and_field_name(model.__name__)}-grammar-model ::= "
  494. mod_rule += (
  495. rf'"\"{model.__name__}\"" "," ws "\"{outer_object_content}\"" ":" ws {format_model_and_field_name(model.__name__)}' + "\n"
  496. )
  497. mod_rules.append(mod_rule)
  498. grammar_model_rules += "\n" + "\n".join(mod_rules)
  499. for model in models:
  500. model_rules, has_special_string = generate_gbnf_grammar(model, processed_models,
  501. created_rules)
  502. if not has_special_string:
  503. model_rules[0] += r'"\n" ws "}"'
  504. all_rules.extend(model_rules)
  505. all_rules.insert(0, root_rule + model_rule + grammar_model_rules)
  506. return "\n".join(all_rules)
  507. def get_primitive_grammar(grammar):
  508. """
  509. Returns the needed GBNF primitive grammar for a given GBNF grammar string.
  510. Args:
  511. grammar (str): The string containing the GBNF grammar.
  512. Returns:
  513. str: GBNF primitive grammar string.
  514. """
  515. type_list = []
  516. if "string-list" in grammar:
  517. type_list.append(str)
  518. if "boolean-list" in grammar:
  519. type_list.append(bool)
  520. if "integer-list" in grammar:
  521. type_list.append(int)
  522. if "float-list" in grammar:
  523. type_list.append(float)
  524. additional_grammar = [generate_list_rule(t) for t in type_list]
  525. primitive_grammar = r"""
  526. boolean ::= "true" | "false"
  527. null ::= "null"
  528. string ::= "\"" (
  529. [^"\\] |
  530. "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
  531. )* "\"" ws
  532. ws ::= ([ \t\n] ws)?
  533. float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
  534. integer ::= [0-9]+"""
  535. any_block = ""
  536. if "custom-class-any" in grammar:
  537. any_block = """
  538. value ::= object | array | string | number | boolean | null
  539. object ::=
  540. "{" ws (
  541. string ":" ws value
  542. ("," ws string ":" ws value)*
  543. )? "}" ws
  544. array ::=
  545. "[" ws (
  546. value
  547. ("," ws value)*
  548. )? "]" ws
  549. number ::= integer | float"""
  550. markdown_code_block_grammar = ""
  551. if "markdown-code-block" in grammar:
  552. markdown_code_block_grammar = r'''
  553. markdown-code-block ::= opening-triple-ticks markdown-code-block-content closing-triple-ticks
  554. markdown-code-block-content ::= ( [^`] | "`" [^`] | "`" "`" [^`] )*
  555. opening-triple-ticks ::= "```" "python" "\n" | "```" "c" "\n" | "```" "cpp" "\n" | "```" "txt" "\n" | "```" "text" "\n" | "```" "json" "\n" | "```" "javascript" "\n" | "```" "css" "\n" | "```" "html" "\n" | "```" "markdown" "\n"
  556. closing-triple-ticks ::= "```" "\n"'''
  557. if "triple-quoted-string" in grammar:
  558. markdown_code_block_grammar = r"""
  559. triple-quoted-string ::= triple-quotes triple-quoted-string-content triple-quotes
  560. triple-quoted-string-content ::= ( [^'] | "'" [^'] | "'" "'" [^'] )*
  561. triple-quotes ::= "'''" """
  562. return "\n" + "\n".join(additional_grammar) + any_block + primitive_grammar + markdown_code_block_grammar
  563. def generate_markdown_documentation(
  564. pydantic_models: List[Type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
  565. documentation_with_field_description=True
  566. ) -> str:
  567. """
  568. Generate markdown documentation for a list of Pydantic models.
  569. Args:
  570. pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
  571. model_prefix (str): Prefix for the model section.
  572. fields_prefix (str): Prefix for the fields section.
  573. documentation_with_field_description (bool): Include field descriptions in the documentation.
  574. Returns:
  575. str: Generated text documentation.
  576. """
  577. documentation = ""
  578. pyd_models = [(model, True) for model in pydantic_models]
  579. for model, add_prefix in pyd_models:
  580. if add_prefix:
  581. documentation += f"{model_prefix}: {model.__name__}\n"
  582. else:
  583. documentation += f"Model: {model.__name__}\n"
  584. # Handling multi-line model description with proper indentation
  585. class_doc = getdoc(model)
  586. base_class_doc = getdoc(BaseModel)
  587. class_description = class_doc if class_doc and class_doc != base_class_doc else ""
  588. if class_description != "":
  589. documentation += " Description: "
  590. documentation += format_multiline_description(class_description, 0) + "\n"
  591. if add_prefix:
  592. # Indenting the fields section
  593. documentation += f" {fields_prefix}:\n"
  594. else:
  595. documentation += f" Fields:\n"
  596. if isclass(model) and issubclass(model, BaseModel):
  597. for name, field_type in model.__annotations__.items():
  598. # if name == "markdown_code_block":
  599. # continue
  600. if get_origin(field_type) == list:
  601. element_type = get_args(field_type)[0]
  602. if isclass(element_type) and issubclass(element_type, BaseModel):
  603. pyd_models.append((element_type, False))
  604. if get_origin(field_type) == Union:
  605. element_types = get_args(field_type)
  606. for element_type in element_types:
  607. if isclass(element_type) and issubclass(element_type, BaseModel):
  608. pyd_models.append((element_type, False))
  609. documentation += generate_field_markdown(
  610. name, field_type, model, documentation_with_field_description=documentation_with_field_description
  611. )
  612. documentation += "\n"
  613. if hasattr(model, "Config") and hasattr(model.Config,
  614. "json_schema_extra") and "example" in model.Config.json_schema_extra:
  615. documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n"
  616. json_example = json.dumps(model.Config.json_schema_extra["example"])
  617. documentation += format_multiline_description(json_example, 2) + "\n"
  618. return documentation
  619. def generate_field_markdown(
  620. field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
  621. documentation_with_field_description=True
  622. ) -> str:
  623. """
  624. Generate markdown documentation for a Pydantic model field.
  625. Args:
  626. field_name (str): Name of the field.
  627. field_type (Type[Any]): Type of the field.
  628. model (Type[BaseModel]): Pydantic model class.
  629. depth (int): Indentation depth in the documentation.
  630. documentation_with_field_description (bool): Include field descriptions in the documentation.
  631. Returns:
  632. str: Generated text documentation for the field.
  633. """
  634. indent = " " * depth
  635. field_info = model.model_fields.get(field_name)
  636. field_description = field_info.description if field_info and field_info.description else ""
  637. if get_origin(field_type) == list:
  638. element_type = get_args(field_type)[0]
  639. field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
  640. if field_description != "":
  641. field_text += ":\n"
  642. else:
  643. field_text += "\n"
  644. elif get_origin(field_type) == Union:
  645. element_types = get_args(field_type)
  646. types = []
  647. for element_type in element_types:
  648. types.append(format_model_and_field_name(element_type.__name__))
  649. field_text = f"{indent}{field_name} ({' or '.join(types)})"
  650. if field_description != "":
  651. field_text += ":\n"
  652. else:
  653. field_text += "\n"
  654. else:
  655. field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})"
  656. if field_description != "":
  657. field_text += ":\n"
  658. else:
  659. field_text += "\n"
  660. if not documentation_with_field_description:
  661. return field_text
  662. if field_description != "":
  663. field_text += f" Description: " + field_description + "\n"
  664. # Check for and include field-specific examples if available
  665. if hasattr(model, "Config") and hasattr(model.Config,
  666. "json_schema_extra") and "example" in model.Config.json_schema_extra:
  667. field_example = model.Config.json_schema_extra["example"].get(field_name)
  668. if field_example is not None:
  669. example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
  670. field_text += f"{indent} Example: {example_text}\n"
  671. if isclass(field_type) and issubclass(field_type, BaseModel):
  672. field_text += f"{indent} Details:\n"
  673. for name, type_ in field_type.__annotations__.items():
  674. field_text += generate_field_markdown(name, type_, field_type, depth + 2)
  675. return field_text
  676. def format_json_example(example: dict, depth: int) -> str:
  677. """
  678. Format a JSON example into a readable string with indentation.
  679. Args:
  680. example (dict): JSON example to be formatted.
  681. depth (int): Indentation depth.
  682. Returns:
  683. str: Formatted JSON example string.
  684. """
  685. indent = " " * depth
  686. formatted_example = "{\n"
  687. for key, value in example.items():
  688. value_text = f"'{value}'" if isinstance(value, str) else value
  689. formatted_example += f"{indent}{key}: {value_text},\n"
  690. formatted_example = formatted_example.rstrip(",\n") + "\n" + indent + "}"
  691. return formatted_example
  692. def generate_text_documentation(
  693. pydantic_models: List[Type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
  694. documentation_with_field_description=True
  695. ) -> str:
  696. """
  697. Generate text documentation for a list of Pydantic models.
  698. Args:
  699. pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
  700. model_prefix (str): Prefix for the model section.
  701. fields_prefix (str): Prefix for the fields section.
  702. documentation_with_field_description (bool): Include field descriptions in the documentation.
  703. Returns:
  704. str: Generated text documentation.
  705. """
  706. documentation = ""
  707. pyd_models = [(model, True) for model in pydantic_models]
  708. for model, add_prefix in pyd_models:
  709. if add_prefix:
  710. documentation += f"{model_prefix}: {model.__name__}\n"
  711. else:
  712. documentation += f"Model: {model.__name__}\n"
  713. # Handling multi-line model description with proper indentation
  714. class_doc = getdoc(model)
  715. base_class_doc = getdoc(BaseModel)
  716. class_description = class_doc if class_doc and class_doc != base_class_doc else ""
  717. if class_description != "":
  718. documentation += " Description: "
  719. documentation += "\n" + format_multiline_description(class_description, 2) + "\n"
  720. if isclass(model) and issubclass(model, BaseModel):
  721. documentation_fields = ""
  722. for name, field_type in model.__annotations__.items():
  723. # if name == "markdown_code_block":
  724. # continue
  725. if get_origin(field_type) == list:
  726. element_type = get_args(field_type)[0]
  727. if isclass(element_type) and issubclass(element_type, BaseModel):
  728. pyd_models.append((element_type, False))
  729. if get_origin(field_type) == Union:
  730. element_types = get_args(field_type)
  731. for element_type in element_types:
  732. if isclass(element_type) and issubclass(element_type, BaseModel):
  733. pyd_models.append((element_type, False))
  734. documentation_fields += generate_field_text(
  735. name, field_type, model, documentation_with_field_description=documentation_with_field_description
  736. )
  737. if documentation_fields != "":
  738. if add_prefix:
  739. documentation += f" {fields_prefix}:\n{documentation_fields}"
  740. else:
  741. documentation += f" Fields:\n{documentation_fields}"
  742. documentation += "\n"
  743. if hasattr(model, "Config") and hasattr(model.Config,
  744. "json_schema_extra") and "example" in model.Config.json_schema_extra:
  745. documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n"
  746. json_example = json.dumps(model.Config.json_schema_extra["example"])
  747. documentation += format_multiline_description(json_example, 2) + "\n"
  748. return documentation
  749. def generate_field_text(
  750. field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
  751. documentation_with_field_description=True
  752. ) -> str:
  753. """
  754. Generate text documentation for a Pydantic model field.
  755. Args:
  756. field_name (str): Name of the field.
  757. field_type (Type[Any]): Type of the field.
  758. model (Type[BaseModel]): Pydantic model class.
  759. depth (int): Indentation depth in the documentation.
  760. documentation_with_field_description (bool): Include field descriptions in the documentation.
  761. Returns:
  762. str: Generated text documentation for the field.
  763. """
  764. indent = " " * depth
  765. field_info = model.model_fields.get(field_name)
  766. field_description = field_info.description if field_info and field_info.description else ""
  767. if get_origin(field_type) == list:
  768. element_type = get_args(field_type)[0]
  769. field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
  770. if field_description != "":
  771. field_text += ":\n"
  772. else:
  773. field_text += "\n"
  774. elif get_origin(field_type) == Union:
  775. element_types = get_args(field_type)
  776. types = []
  777. for element_type in element_types:
  778. types.append(format_model_and_field_name(element_type.__name__))
  779. field_text = f"{indent}{field_name} ({' or '.join(types)})"
  780. if field_description != "":
  781. field_text += ":\n"
  782. else:
  783. field_text += "\n"
  784. else:
  785. field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})"
  786. if field_description != "":
  787. field_text += ":\n"
  788. else:
  789. field_text += "\n"
  790. if not documentation_with_field_description:
  791. return field_text
  792. if field_description != "":
  793. field_text += f"{indent} Description: " + field_description + "\n"
  794. # Check for and include field-specific examples if available
  795. if hasattr(model, "Config") and hasattr(model.Config,
  796. "json_schema_extra") and "example" in model.Config.json_schema_extra:
  797. field_example = model.Config.json_schema_extra["example"].get(field_name)
  798. if field_example is not None:
  799. example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
  800. field_text += f"{indent} Example: {example_text}\n"
  801. if isclass(field_type) and issubclass(field_type, BaseModel):
  802. field_text += f"{indent} Details:\n"
  803. for name, type_ in field_type.__annotations__.items():
  804. field_text += generate_field_text(name, type_, field_type, depth + 2)
  805. return field_text
  806. def format_multiline_description(description: str, indent_level: int) -> str:
  807. """
  808. Format a multiline description with proper indentation.
  809. Args:
  810. description (str): Multiline description.
  811. indent_level (int): Indentation level.
  812. Returns:
  813. str: Formatted multiline description.
  814. """
  815. indent = " " * indent_level
  816. return indent + description.replace("\n", "\n" + indent)
  817. def save_gbnf_grammar_and_documentation(
  818. grammar, documentation, grammar_file_path="./grammar.gbnf", documentation_file_path="./grammar_documentation.md"
  819. ):
  820. """
  821. Save GBNF grammar and documentation to specified files.
  822. Args:
  823. grammar (str): GBNF grammar string.
  824. documentation (str): Documentation string.
  825. grammar_file_path (str): File path to save the GBNF grammar.
  826. documentation_file_path (str): File path to save the documentation.
  827. Returns:
  828. None
  829. """
  830. try:
  831. with open(grammar_file_path, "w") as file:
  832. file.write(grammar + get_primitive_grammar(grammar))
  833. print(f"Grammar successfully saved to {grammar_file_path}")
  834. except IOError as e:
  835. print(f"An error occurred while saving the grammar file: {e}")
  836. try:
  837. with open(documentation_file_path, "w") as file:
  838. file.write(documentation)
  839. print(f"Documentation successfully saved to {documentation_file_path}")
  840. except IOError as e:
  841. print(f"An error occurred while saving the documentation file: {e}")
  842. def remove_empty_lines(string):
  843. """
  844. Remove empty lines from a string.
  845. Args:
  846. string (str): Input string.
  847. Returns:
  848. str: String with empty lines removed.
  849. """
  850. lines = string.splitlines()
  851. non_empty_lines = [line for line in lines if line.strip() != ""]
  852. string_no_empty_lines = "\n".join(non_empty_lines)
  853. return string_no_empty_lines
  854. def generate_and_save_gbnf_grammar_and_documentation(
  855. pydantic_model_list,
  856. grammar_file_path="./generated_grammar.gbnf",
  857. documentation_file_path="./generated_grammar_documentation.md",
  858. outer_object_name: str = None,
  859. outer_object_content: str = None,
  860. model_prefix: str = "Output Model",
  861. fields_prefix: str = "Output Fields",
  862. list_of_outputs: bool = False,
  863. documentation_with_field_description=True,
  864. ):
  865. """
  866. Generate GBNF grammar and documentation, and save them to specified files.
  867. Args:
  868. pydantic_model_list: List of Pydantic model classes.
  869. grammar_file_path (str): File path to save the generated GBNF grammar.
  870. documentation_file_path (str): File path to save the generated documentation.
  871. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  872. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  873. model_prefix (str): Prefix for the model section in the documentation.
  874. fields_prefix (str): Prefix for the fields section in the documentation.
  875. list_of_outputs (bool): Whether the output is a list of items.
  876. documentation_with_field_description (bool): Include field descriptions in the documentation.
  877. Returns:
  878. None
  879. """
  880. documentation = generate_markdown_documentation(
  881. pydantic_model_list, model_prefix, fields_prefix,
  882. documentation_with_field_description=documentation_with_field_description
  883. )
  884. grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content,
  885. list_of_outputs)
  886. grammar = remove_empty_lines(grammar)
  887. save_gbnf_grammar_and_documentation(grammar, documentation, grammar_file_path, documentation_file_path)
  888. def generate_gbnf_grammar_and_documentation(
  889. pydantic_model_list,
  890. outer_object_name: str = None,
  891. outer_object_content: str = None,
  892. model_prefix: str = "Output Model",
  893. fields_prefix: str = "Output Fields",
  894. list_of_outputs: bool = False,
  895. documentation_with_field_description=True,
  896. ):
  897. """
  898. Generate GBNF grammar and documentation for a list of Pydantic models.
  899. Args:
  900. pydantic_model_list: List of Pydantic model classes.
  901. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  902. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  903. model_prefix (str): Prefix for the model section in the documentation.
  904. fields_prefix (str): Prefix for the fields section in the documentation.
  905. list_of_outputs (bool): Whether the output is a list of items.
  906. documentation_with_field_description (bool): Include field descriptions in the documentation.
  907. Returns:
  908. tuple: GBNF grammar string, documentation string.
  909. """
  910. documentation = generate_markdown_documentation(
  911. copy(pydantic_model_list), model_prefix, fields_prefix,
  912. documentation_with_field_description=documentation_with_field_description
  913. )
  914. grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content,
  915. list_of_outputs)
  916. grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar))
  917. return grammar, documentation
  918. def generate_gbnf_grammar_and_documentation_from_dictionaries(
  919. dictionaries: List[dict],
  920. outer_object_name: str = None,
  921. outer_object_content: str = None,
  922. model_prefix: str = "Output Model",
  923. fields_prefix: str = "Output Fields",
  924. list_of_outputs: bool = False,
  925. documentation_with_field_description=True,
  926. ):
  927. """
  928. Generate GBNF grammar and documentation from a list of dictionaries.
  929. Args:
  930. dictionaries (List[dict]): List of dictionaries representing Pydantic models.
  931. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  932. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  933. model_prefix (str): Prefix for the model section in the documentation.
  934. fields_prefix (str): Prefix for the fields section in the documentation.
  935. list_of_outputs (bool): Whether the output is a list of items.
  936. documentation_with_field_description (bool): Include field descriptions in the documentation.
  937. Returns:
  938. tuple: GBNF grammar string, documentation string.
  939. """
  940. pydantic_model_list = create_dynamic_models_from_dictionaries(dictionaries)
  941. documentation = generate_markdown_documentation(
  942. copy(pydantic_model_list), model_prefix, fields_prefix,
  943. documentation_with_field_description=documentation_with_field_description
  944. )
  945. grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content,
  946. list_of_outputs)
  947. grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar))
  948. return grammar, documentation
  949. def create_dynamic_model_from_function(func: Callable):
  950. """
  951. Creates a dynamic Pydantic model from a given function's type hints and adds the function as a 'run' method.
  952. Args:
  953. func (Callable): A function with type hints from which to create the model.
  954. Returns:
  955. A dynamic Pydantic model class with the provided function as a 'run' method.
  956. """
  957. # Get the signature of the function
  958. sig = inspect.signature(func)
  959. # Parse the docstring
  960. docstring = parse(func.__doc__)
  961. dynamic_fields = {}
  962. param_docs = []
  963. for param in sig.parameters.values():
  964. # Exclude 'self' parameter
  965. if param.name == "self":
  966. continue
  967. # Assert that the parameter has a type annotation
  968. if param.annotation == inspect.Parameter.empty:
  969. raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation")
  970. # Find the parameter's description in the docstring
  971. param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
  972. # Assert that the parameter has a description
  973. if not param_doc or not param_doc.description:
  974. raise ValueError(
  975. f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
  976. # Add parameter details to the schema
  977. param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
  978. param_docs.append((param.name, param_doc))
  979. if param.default == inspect.Parameter.empty:
  980. default_value = ...
  981. else:
  982. default_value = param.default
  983. dynamic_fields[param.name] = (
  984. param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
  985. # Creating the dynamic model
  986. dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
  987. for param_doc in param_docs:
  988. dynamic_model.model_fields[param_doc[0]].description = param_doc[1].description
  989. dynamic_model.__doc__ = docstring.short_description
  990. def run_method_wrapper(self):
  991. func_args = {name: getattr(self, name) for name, _ in dynamic_fields.items()}
  992. return func(**func_args)
  993. # Adding the wrapped function as a 'run' method
  994. setattr(dynamic_model, "run", run_method_wrapper)
  995. return dynamic_model
  996. def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
  997. """
  998. Add a 'run' method to a dynamic Pydantic model, using the provided function.
  999. Args:
  1000. model (Type[BaseModel]): Dynamic Pydantic model class.
  1001. func (Callable): Function to be added as a 'run' method to the model.
  1002. Returns:
  1003. Type[BaseModel]: Pydantic model class with the added 'run' method.
  1004. """
  1005. def run_method_wrapper(self):
  1006. func_args = {name: getattr(self, name) for name in model.model_fields}
  1007. return func(**func_args)
  1008. # Adding the wrapped function as a 'run' method
  1009. setattr(model, "run", run_method_wrapper)
  1010. return model
  1011. def create_dynamic_models_from_dictionaries(dictionaries: List[dict]):
  1012. """
  1013. Create a list of dynamic Pydantic model classes from a list of dictionaries.
  1014. Args:
  1015. dictionaries (List[dict]): List of dictionaries representing model structures.
  1016. Returns:
  1017. List[Type[BaseModel]]: List of generated dynamic Pydantic model classes.
  1018. """
  1019. dynamic_models = []
  1020. for func in dictionaries:
  1021. model_name = format_model_and_field_name(func.get("name", ""))
  1022. dyn_model = convert_dictionary_to_pydantic_model(func, model_name)
  1023. dynamic_models.append(dyn_model)
  1024. return dynamic_models
  1025. def map_grammar_names_to_pydantic_model_class(pydantic_model_list):
  1026. output = {}
  1027. for model in pydantic_model_list:
  1028. output[format_model_and_field_name(model.__name__)] = model
  1029. return output
  1030. from enum import Enum
  1031. def json_schema_to_python_types(schema):
  1032. type_map = {
  1033. "any": Any,
  1034. "string": str,
  1035. "number": float,
  1036. "integer": int,
  1037. "boolean": bool,
  1038. "array": list,
  1039. }
  1040. return type_map[schema]
  1041. def list_to_enum(enum_name, values):
  1042. return Enum(enum_name, {value: value for value in values})
  1043. def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "CustomModel") -> Type[BaseModel]:
  1044. """
  1045. Convert a dictionary to a Pydantic model class.
  1046. Args:
  1047. dictionary (dict): Dictionary representing the model structure.
  1048. model_name (str): Name of the generated Pydantic model.
  1049. Returns:
  1050. Type[BaseModel]: Generated Pydantic model class.
  1051. """
  1052. fields = {}
  1053. if "properties" in dictionary:
  1054. for field_name, field_data in dictionary.get("properties", {}).items():
  1055. if field_data == "object":
  1056. submodel = convert_dictionary_to_pydantic_model(dictionary, f"{model_name}_{field_name}")
  1057. fields[field_name] = (submodel, ...)
  1058. else:
  1059. field_type = field_data.get("type", "str")
  1060. if field_data.get("enum", []):
  1061. fields[field_name] = (list_to_enum(field_name, field_data.get("enum", [])), ...)
  1062. elif field_type == "array":
  1063. items = field_data.get("items", {})
  1064. if items != {}:
  1065. array = {"properties": items}
  1066. array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
  1067. fields[field_name] = (List[array_type], ...)
  1068. else:
  1069. fields[field_name] = (list, ...)
  1070. elif field_type == "object":
  1071. submodel = convert_dictionary_to_pydantic_model(field_data, f"{model_name}_{field_name}")
  1072. fields[field_name] = (submodel, ...)
  1073. elif field_type == "required":
  1074. required = field_data.get("enum", [])
  1075. for key, field in fields.items():
  1076. if key not in required:
  1077. fields[key] = (Optional[fields[key][0]], ...)
  1078. else:
  1079. field_type = json_schema_to_python_types(field_type)
  1080. fields[field_name] = (field_type, ...)
  1081. if "function" in dictionary:
  1082. for field_name, field_data in dictionary.get("function", {}).items():
  1083. if field_name == "name":
  1084. model_name = field_data
  1085. elif field_name == "description":
  1086. fields["__doc__"] = field_data
  1087. elif field_name == "parameters":
  1088. return convert_dictionary_to_pydantic_model(field_data, f"{model_name}")
  1089. if "parameters" in dictionary:
  1090. field_data = {"function": dictionary}
  1091. return convert_dictionary_to_pydantic_model(field_data, f"{model_name}")
  1092. if "required" in dictionary:
  1093. required = dictionary.get("required", [])
  1094. for key, field in fields.items():
  1095. if key not in required:
  1096. fields[key] = (Optional[fields[key][0]], ...)
  1097. custom_model = create_model(model_name, **fields)
  1098. return custom_model