pydantic_models_to_grammar.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151
import inspect
import json
import re
from copy import copy
from enum import Enum
from inspect import isclass, getdoc
from types import NoneType, UnionType
from typing import Any, Type, List, get_args, get_origin, Tuple, Union, Optional, _GenericAlias
from typing import get_type_hints, Callable

from pydantic import BaseModel, create_model, Field
  11. class PydanticDataType(Enum):
  12. """
  13. Defines the data types supported by the grammar_generator.
  14. Attributes:
  15. STRING (str): Represents a string data type.
  16. BOOLEAN (str): Represents a boolean data type.
  17. INTEGER (str): Represents an integer data type.
  18. FLOAT (str): Represents a float data type.
  19. OBJECT (str): Represents an object data type.
  20. ARRAY (str): Represents an array data type.
  21. ENUM (str): Represents an enum data type.
  22. CUSTOM_CLASS (str): Represents a custom class data type.
  23. """
  24. STRING = "string"
  25. TRIPLE_QUOTED_STRING = "triple_quoted_string"
  26. MARKDOWN_STRING = "markdown_string"
  27. BOOLEAN = "boolean"
  28. INTEGER = "integer"
  29. FLOAT = "float"
  30. OBJECT = "object"
  31. ARRAY = "array"
  32. ENUM = "enum"
  33. ANY = "any"
  34. NULL = "null"
  35. CUSTOM_CLASS = "custom-class"
  36. CUSTOM_DICT = "custom-dict"
  37. SET = "set"
  38. def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
  39. if isclass(pydantic_type) and issubclass(pydantic_type, str):
  40. return PydanticDataType.STRING.value
  41. elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
  42. return PydanticDataType.BOOLEAN.value
  43. elif isclass(pydantic_type) and issubclass(pydantic_type, int):
  44. return PydanticDataType.INTEGER.value
  45. elif isclass(pydantic_type) and issubclass(pydantic_type, float):
  46. return PydanticDataType.FLOAT.value
  47. elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
  48. return PydanticDataType.ENUM.value
  49. elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
  50. return format_model_and_field_name(pydantic_type.__name__)
  51. elif get_origin(pydantic_type) == list:
  52. element_type = get_args(pydantic_type)[0]
  53. return f"{map_pydantic_type_to_gbnf(element_type)}-list"
  54. elif get_origin(pydantic_type) == set:
  55. element_type = get_args(pydantic_type)[0]
  56. return f"{map_pydantic_type_to_gbnf(element_type)}-set"
  57. elif get_origin(pydantic_type) == Union:
  58. union_types = get_args(pydantic_type)
  59. union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
  60. return f"union-{'-or-'.join(union_rules)}"
  61. elif get_origin(pydantic_type) == Optional:
  62. element_type = get_args(pydantic_type)[0]
  63. return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
  64. elif isclass(pydantic_type):
  65. return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
  66. elif get_origin(pydantic_type) == dict:
  67. key_type, value_type = get_args(pydantic_type)
  68. return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
  69. else:
  70. return "unknown"
  71. def format_model_and_field_name(model_name: str) -> str:
  72. parts = re.findall('[A-Z][^A-Z]*', model_name)
  73. if not parts: # Check if the list is empty
  74. return model_name.lower().replace("_", "-")
  75. return '-'.join(part.lower().replace("_", "-") for part in parts)
  76. def generate_list_rule(element_type):
  77. """
  78. Generate a GBNF rule for a list of a given element type.
  79. :param element_type: The type of the elements in the list (e.g., 'string').
  80. :return: A string representing the GBNF rule for a list of the given type.
  81. """
  82. rule_name = f"{map_pydantic_type_to_gbnf(element_type)}-list"
  83. element_rule = map_pydantic_type_to_gbnf(element_type)
  84. list_rule = fr'{rule_name} ::= "[" {element_rule} ("," {element_rule})* "]"'
  85. return list_rule
  86. def get_members_structure(cls, rule_name):
  87. if issubclass(cls, Enum):
  88. # Handle Enum types
  89. members = [f'\"\\\"{member.value}\\\"\"' for name, member in cls.__members__.items()]
  90. return f"{cls.__name__.lower()} ::= " + " | ".join(members)
  91. if cls.__annotations__ and cls.__annotations__ != {}:
  92. result = f'{rule_name} ::= "{{"'
  93. type_list_rules = []
  94. # Modify this comprehension
  95. members = [f' \"\\\"{name}\\\"\" ":" {map_pydantic_type_to_gbnf(param_type)}'
  96. for name, param_type in cls.__annotations__.items()
  97. if name != 'self']
  98. result += '"," '.join(members)
  99. result += ' "}"'
  100. return result, type_list_rules
  101. elif rule_name == "custom-class-any":
  102. result = f'{rule_name} ::= '
  103. result += 'value'
  104. type_list_rules = []
  105. return result, type_list_rules
  106. else:
  107. init_signature = inspect.signature(cls.__init__)
  108. parameters = init_signature.parameters
  109. result = f'{rule_name} ::= "{{"'
  110. type_list_rules = []
  111. # Modify this comprehension too
  112. members = [f' \"\\\"{name}\\\"\" ":" {map_pydantic_type_to_gbnf(param.annotation)}'
  113. for name, param in parameters.items()
  114. if name != 'self' and param.annotation != inspect.Parameter.empty]
  115. result += '", "'.join(members)
  116. result += ' "}"'
  117. return result, type_list_rules
  118. def regex_to_gbnf(regex_pattern: str) -> str:
  119. """
  120. Translate a basic regex pattern to a GBNF rule.
  121. Note: This function handles only a subset of simple regex patterns.
  122. """
  123. gbnf_rule = regex_pattern
  124. # Translate common regex components to GBNF
  125. gbnf_rule = gbnf_rule.replace('\\d', '[0-9]')
  126. gbnf_rule = gbnf_rule.replace('\\s', '[ \t\n]')
  127. # Handle quantifiers and other regex syntax that is similar in GBNF
  128. # (e.g., '*', '+', '?', character classes)
  129. return gbnf_rule
  130. def generate_gbnf_integer_rules(max_digit=None, min_digit=None):
  131. """
  132. Generate GBNF Integer Rules
  133. Generates GBNF (Generalized Backus-Naur Form) rules for integers based on the given maximum and minimum digits.
  134. Parameters:
  135. max_digit (int): The maximum number of digits for the integer. Default is None.
  136. min_digit (int): The minimum number of digits for the integer. Default is None.
  137. Returns:
  138. integer_rule (str): The identifier for the integer rule generated.
  139. additional_rules (list): A list of additional rules generated based on the given maximum and minimum digits.
  140. """
  141. additional_rules = []
  142. # Define the rule identifier based on max_digit and min_digit
  143. integer_rule = "integer-part"
  144. if max_digit is not None:
  145. integer_rule += f"-max{max_digit}"
  146. if min_digit is not None:
  147. integer_rule += f"-min{min_digit}"
  148. # Handling Integer Rules
  149. if max_digit is not None or min_digit is not None:
  150. # Start with an empty rule part
  151. integer_rule_part = ''
  152. # Add mandatory digits as per min_digit
  153. if min_digit is not None:
  154. integer_rule_part += '[0-9] ' * min_digit
  155. # Add optional digits up to max_digit
  156. if max_digit is not None:
  157. optional_digits = max_digit - (min_digit if min_digit is not None else 0)
  158. integer_rule_part += ''.join(['[0-9]? ' for _ in range(optional_digits)])
  159. # Trim the rule part and append it to additional rules
  160. integer_rule_part = integer_rule_part.strip()
  161. if integer_rule_part:
  162. additional_rules.append(f'{integer_rule} ::= {integer_rule_part}')
  163. return integer_rule, additional_rules
  164. def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None, min_precision=None):
  165. """
  166. Generate GBNF float rules based on the given constraints.
  167. :param max_digit: Maximum number of digits in the integer part (default: None)
  168. :param min_digit: Minimum number of digits in the integer part (default: None)
  169. :param max_precision: Maximum number of digits in the fractional part (default: None)
  170. :param min_precision: Minimum number of digits in the fractional part (default: None)
  171. :return: A tuple containing the float rule and additional rules as a list
  172. Example Usage:
  173. max_digit = 3
  174. min_digit = 1
  175. max_precision = 2
  176. min_precision = 1
  177. generate_gbnf_float_rules(max_digit, min_digit, max_precision, min_precision)
  178. Output:
  179. ('float-3-1-2-1', ['integer-part-max3-min1 ::= [0-9] [0-9] [0-9]?', 'fractional-part-max2-min1 ::= [0-9] [0-9]?', 'float-3-1-2-1 ::= integer-part-max3-min1 "." fractional-part-max2-min
  180. *1'])
  181. Note:
  182. GBNF stands for Generalized Backus-Naur Form, which is a notation technique to specify the syntax of programming languages or other formal grammars.
  183. """
  184. additional_rules = []
  185. # Define the integer part rule
  186. integer_part_rule = "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
  187. f"-min{min_digit}" if min_digit is not None else "")
  188. # Define the fractional part rule based on precision constraints
  189. fractional_part_rule = "fractional-part"
  190. fractional_rule_part = ''
  191. if max_precision is not None or min_precision is not None:
  192. fractional_part_rule += (f"-max{max_precision}" if max_precision is not None else "") + (
  193. f"-min{min_precision}" if min_precision is not None else "")
  194. # Minimum number of digits
  195. fractional_rule_part = '[0-9]' * (min_precision if min_precision is not None else 1)
  196. # Optional additional digits
  197. fractional_rule_part += ''.join([' [0-9]?'] * (
  198. (max_precision - (min_precision if min_precision is not None else 1)) if max_precision is not None else 0))
  199. additional_rules.append(f'{fractional_part_rule} ::= {fractional_rule_part}')
  200. # Define the float rule
  201. float_rule = f"float-{max_digit if max_digit is not None else 'X'}-{min_digit if min_digit is not None else 'X'}-{max_precision if max_precision is not None else 'X'}-{min_precision if min_precision is not None else 'X'}"
  202. additional_rules.append(f'{float_rule} ::= {integer_part_rule} "." {fractional_part_rule}')
  203. # Generating the integer part rule definition, if necessary
  204. if max_digit is not None or min_digit is not None:
  205. integer_rule_part = '[0-9]'
  206. if min_digit is not None and min_digit > 1:
  207. integer_rule_part += ' [0-9]' * (min_digit - 1)
  208. if max_digit is not None:
  209. integer_rule_part += ''.join([' [0-9]?'] * (max_digit - (min_digit if min_digit is not None else 1)))
  210. additional_rules.append(f'{integer_part_rule} ::= {integer_rule_part.strip()}')
  211. return float_rule, additional_rules
  212. def generate_gbnf_rule_for_type(model_name, field_name,
  213. field_type, is_optional, processed_models, created_rules,
  214. field_info=None) -> \
  215. Tuple[str, list]:
  216. """
  217. Generate GBNF rule for a given field type.
  218. :param model_name: Name of the model.
  219. :param field_name: Name of the field.
  220. :param field_type: Type of the field.
  221. :param is_optional: Whether the field is optional.
  222. :param processed_models: List of processed models.
  223. :param created_rules: List of created rules.
  224. :param field_info: Additional information about the field (optional).
  225. :return: Tuple containing the GBNF type and a list of additional rules.
  226. :rtype: Tuple[str, list]
  227. """
  228. rules = []
  229. field_name = format_model_and_field_name(field_name)
  230. gbnf_type = map_pydantic_type_to_gbnf(field_type)
  231. if isclass(field_type) and issubclass(field_type, BaseModel):
  232. nested_model_name = format_model_and_field_name(field_type.__name__)
  233. nested_model_rules = generate_gbnf_grammar(field_type, processed_models, created_rules)
  234. rules.extend(nested_model_rules)
  235. gbnf_type, rules = nested_model_name, rules
  236. elif isclass(field_type) and issubclass(field_type, Enum):
  237. enum_values = [f'\"\\\"{e.value}\\\"\"' for e in field_type] # Adding escaped quotes
  238. enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
  239. rules.append(enum_rule)
  240. gbnf_type, rules = model_name + "-" + field_name, rules
  241. elif get_origin(field_type) == list or field_type == list: # Array
  242. element_type = get_args(field_type)[0]
  243. element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
  244. f"{field_name}-element",
  245. element_type, is_optional, processed_models,
  246. created_rules)
  247. rules.extend(additional_rules)
  248. array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """
  249. rules.append(array_rule)
  250. gbnf_type, rules = model_name + "-" + field_name, rules
  251. elif get_origin(field_type) == set or field_type == set: # Array
  252. element_type = get_args(field_type)[0]
  253. element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
  254. f"{field_name}-element",
  255. element_type, is_optional, processed_models,
  256. created_rules)
  257. rules.extend(additional_rules)
  258. array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """
  259. rules.append(array_rule)
  260. gbnf_type, rules = model_name + "-" + field_name, rules
  261. elif gbnf_type.startswith("custom-class-"):
  262. nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
  263. rules.append(nested_model_rules)
  264. elif gbnf_type.startswith("custom-dict-"):
  265. key_type, value_type = get_args(field_type)
  266. additional_key_type, additional_key_rules = generate_gbnf_rule_for_type(model_name,
  267. f"{field_name}-key-type",
  268. key_type, is_optional, processed_models,
  269. created_rules)
  270. additional_value_type, additional_value_rules = generate_gbnf_rule_for_type(model_name,
  271. f"{field_name}-value-type",
  272. value_type, is_optional,
  273. processed_models, created_rules)
  274. gbnf_type = fr'{gbnf_type} ::= "{{" ( {additional_key_type} ":" {additional_value_type} ("," {additional_key_type} ":" {additional_value_type})* )? "}}" '
  275. rules.extend(additional_key_rules)
  276. rules.extend(additional_value_rules)
  277. elif gbnf_type.startswith("union-"):
  278. union_types = get_args(field_type)
  279. union_rules = []
  280. for union_type in union_types:
  281. if isinstance(union_type, _GenericAlias):
  282. union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(model_name,
  283. field_name, union_type,
  284. False,
  285. processed_models, created_rules)
  286. union_rules.append(union_gbnf_type)
  287. rules.extend(union_rules_list)
  288. elif not issubclass(union_type, NoneType):
  289. union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(model_name,
  290. field_name, union_type,
  291. False,
  292. processed_models, created_rules)
  293. union_rules.append(union_gbnf_type)
  294. rules.extend(union_rules_list)
  295. # Defining the union grammar rule separately
  296. if len(union_rules) == 1:
  297. union_grammar_rule = f"{model_name}-{field_name}-optional ::= {' | '.join(union_rules)} | null"
  298. else:
  299. union_grammar_rule = f"{model_name}-{field_name}-union ::= {' | '.join(union_rules)}"
  300. rules.append(union_grammar_rule)
  301. if len(union_rules) == 1:
  302. gbnf_type = f"{model_name}-{field_name}-optional"
  303. else:
  304. gbnf_type = f"{model_name}-{field_name}-union"
  305. elif isclass(field_type) and issubclass(field_type, str):
  306. if field_info and hasattr(field_info, 'json_schema_extra') and field_info.json_schema_extra is not None:
  307. triple_quoted_string = field_info.json_schema_extra.get('triple_quoted_string', False)
  308. markdown_string = field_info.json_schema_extra.get('markdown_string', False)
  309. gbnf_type = PydanticDataType.TRIPLE_QUOTED_STRING.value if triple_quoted_string else PydanticDataType.STRING.value
  310. gbnf_type = PydanticDataType.MARKDOWN_STRING.value if markdown_string else gbnf_type
  311. elif field_info and hasattr(field_info, 'pattern'):
  312. # Convert regex pattern to grammar rule
  313. regex_pattern = field_info.regex.pattern
  314. gbnf_type = f"pattern-{field_name} ::= {regex_to_gbnf(regex_pattern)}"
  315. else:
  316. gbnf_type = PydanticDataType.STRING.value
  317. elif isclass(field_type) and issubclass(field_type, float) and field_info and hasattr(field_info,
  318. 'json_schema_extra') and field_info.json_schema_extra is not None:
  319. # Retrieve precision attributes for floats
  320. max_precision = field_info.json_schema_extra.get('max_precision') if field_info and hasattr(field_info,
  321. 'json_schema_extra') else None
  322. min_precision = field_info.json_schema_extra.get('min_precision') if field_info and hasattr(field_info,
  323. 'json_schema_extra') else None
  324. max_digits = field_info.json_schema_extra.get('max_digit') if field_info and hasattr(field_info,
  325. 'json_schema_extra') else None
  326. min_digits = field_info.json_schema_extra.get('min_digit') if field_info and hasattr(field_info,
  327. 'json_schema_extra') else None
  328. # Generate GBNF rule for float with given attributes
  329. gbnf_type, rules = generate_gbnf_float_rules(max_digit=max_digits, min_digit=min_digits,
  330. max_precision=max_precision,
  331. min_precision=min_precision)
  332. elif isclass(field_type) and issubclass(field_type, int) and field_info and hasattr(field_info,
  333. 'json_schema_extra') and field_info.json_schema_extra is not None:
  334. # Retrieve digit attributes for integers
  335. max_digits = field_info.json_schema_extra.get('max_digit') if field_info and hasattr(field_info,
  336. 'json_schema_extra') else None
  337. min_digits = field_info.json_schema_extra.get('min_digit') if field_info and hasattr(field_info,
  338. 'json_schema_extra') else None
  339. # Generate GBNF rule for integer with given attributes
  340. gbnf_type, rules = generate_gbnf_integer_rules(max_digit=max_digits, min_digit=min_digits)
  341. else:
  342. gbnf_type, rules = gbnf_type, []
  343. if gbnf_type not in created_rules:
  344. return gbnf_type, rules
  345. else:
  346. if gbnf_type in created_rules:
  347. return gbnf_type, rules
  348. def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created_rules: dict) -> (list, bool, bool):
  349. """
  350. Generate GBnF Grammar
  351. Generates a GBnF grammar for a given model.
  352. :param model: A Pydantic model class to generate the grammar for. Must be a subclass of BaseModel.
  353. :param processed_models: A set of already processed models to prevent infinite recursion.
  354. :param created_rules: A dict containing already created rules to prevent duplicates.
  355. :return: A list of GBnF grammar rules in string format. And two booleans indicating if an extra markdown or triple quoted string is in the grammar.
  356. Example Usage:
  357. ```
  358. model = MyModel
  359. processed_models = set()
  360. created_rules = dict()
  361. gbnf_grammar = generate_gbnf_grammar(model, processed_models, created_rules)
  362. ```
  363. """
  364. if model in processed_models:
  365. return []
  366. processed_models.add(model)
  367. model_name = format_model_and_field_name(model.__name__)
  368. if not issubclass(model, BaseModel):
  369. # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
  370. if hasattr(model, '__annotations__') and model.__annotations__:
  371. model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
  372. else:
  373. init_signature = inspect.signature(model.__init__)
  374. parameters = init_signature.parameters
  375. model_fields = {name: (param.annotation, param.default) for name, param in parameters.items()
  376. if name != 'self'}
  377. else:
  378. # For Pydantic models, use model_fields and check for ellipsis (required fields)
  379. model_fields = model.__annotations__
  380. model_rule_parts = []
  381. nested_rules = []
  382. has_markdown_code_block = False
  383. has_triple_quoted_string = False
  384. look_for_markdown_code_block = False
  385. look_for_triple_quoted_string = False
  386. for field_name, field_info in model_fields.items():
  387. if not issubclass(model, BaseModel):
  388. field_type, default_value = field_info
  389. # Check if the field is optional (not required)
  390. is_optional = (default_value is not inspect.Parameter.empty) and (default_value is not Ellipsis)
  391. else:
  392. field_type = field_info
  393. field_info = model.model_fields[field_name]
  394. is_optional = field_info.is_required is False and get_origin(field_type) is Optional
  395. rule_name, additional_rules = generate_gbnf_rule_for_type(model_name,
  396. format_model_and_field_name(field_name),
  397. field_type, is_optional,
  398. processed_models, created_rules, field_info)
  399. look_for_markdown_code_block = True if rule_name == "markdown_string" else False
  400. look_for_triple_quoted_string = True if rule_name == "triple_quoted_string" else False
  401. if not look_for_markdown_code_block and not look_for_triple_quoted_string:
  402. if rule_name not in created_rules:
  403. created_rules[rule_name] = additional_rules
  404. model_rule_parts.append(f' ws \"\\\"{field_name}\\\"\" ": " {rule_name}') # Adding escaped quotes
  405. nested_rules.extend(additional_rules)
  406. else:
  407. has_triple_quoted_string = look_for_markdown_code_block
  408. has_markdown_code_block = look_for_triple_quoted_string
  409. fields_joined = r' "," "\n" '.join(model_rule_parts)
  410. model_rule = fr'{model_name} ::= "{{" "\n" {fields_joined} "\n" ws "}}"'
  411. if look_for_markdown_code_block or look_for_triple_quoted_string:
  412. model_rule += ' ws "}"'
  413. if has_triple_quoted_string:
  414. model_rule += '"\\n" triple-quoted-string'
  415. if has_markdown_code_block:
  416. model_rule += '"\\n" markdown-code-block'
  417. all_rules = [model_rule] + nested_rules
  418. return all_rules, has_markdown_code_block, has_triple_quoted_string
  419. def generate_gbnf_grammar_from_pydantic_models(models: List[Type[BaseModel]], outer_object_name: str = None,
  420. outer_object_content: str = None, list_of_outputs: bool = False) -> str:
  421. """
  422. Generate GBNF Grammar from Pydantic Models.
  423. This method takes a list of Pydantic models and uses them to generate a GBNF grammar string. The generated grammar string can be used for parsing and validating data using the generated
  424. * grammar.
  425. Parameters:
  426. models (List[Type[BaseModel]]): A list of Pydantic models to generate the grammar from.
  427. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  428. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  429. list_of_outputs (str, optional): Allows a list of output objects
  430. Returns:
  431. str: The generated GBNF grammar string.
  432. Examples:
  433. models = [UserModel, PostModel]
  434. grammar = generate_gbnf_grammar_from_pydantic(models)
  435. print(grammar)
  436. # Output:
  437. # root ::= UserModel | PostModel
  438. # ...
  439. """
  440. processed_models = set()
  441. all_rules = []
  442. created_rules = {}
  443. if outer_object_name is None:
  444. for model in models:
  445. model_rules, _, _ = generate_gbnf_grammar(model,
  446. processed_models, created_rules)
  447. all_rules.extend(model_rules)
  448. if list_of_outputs:
  449. root_rule = r'root ::= ws "[" grammar-models ("," grammar-models)* "]"' + "\n"
  450. else:
  451. root_rule = r'root ::= ws grammar-models' + "\n"
  452. root_rule += "grammar-models ::= " + " | ".join(
  453. [format_model_and_field_name(model.__name__) for model in models])
  454. all_rules.insert(0, root_rule)
  455. return "\n".join(all_rules)
  456. elif outer_object_name is not None:
  457. if list_of_outputs:
  458. root_rule = fr'root ::= ws "[" {format_model_and_field_name(outer_object_name)} ("," {format_model_and_field_name(outer_object_name)})* "]"' + "\n"
  459. else:
  460. root_rule = f"root ::= {format_model_and_field_name(outer_object_name)}\n"
  461. model_rule = fr'{format_model_and_field_name(outer_object_name)} ::= ws "{{" ws "\"{outer_object_name}\"" ": " grammar-models'
  462. fields_joined = " | ".join(
  463. [fr'{format_model_and_field_name(model.__name__)}-grammar-model' for model in models])
  464. grammar_model_rules = f'\ngrammar-models ::= {fields_joined}'
  465. mod_rules = []
  466. for model in models:
  467. mod_rule = fr'{format_model_and_field_name(model.__name__)}-grammar-model ::= ws'
  468. mod_rule += fr'"\"{format_model_and_field_name(model.__name__)}\"" "," ws "\"{outer_object_content}\"" ws ":" ws {format_model_and_field_name(model.__name__)}' + '\n'
  469. mod_rules.append(mod_rule)
  470. grammar_model_rules += "\n" + "\n".join(mod_rules)
  471. look_for_markdown_code_block = False
  472. look_for_triple_quoted_string = False
  473. for model in models:
  474. model_rules, markdown_block, triple_quoted_string = generate_gbnf_grammar(model,
  475. processed_models, created_rules)
  476. all_rules.extend(model_rules)
  477. if markdown_block:
  478. look_for_markdown_code_block = True
  479. if triple_quoted_string:
  480. look_for_triple_quoted_string = True
  481. if not look_for_markdown_code_block and not look_for_triple_quoted_string:
  482. model_rule += ' ws "}"'
  483. all_rules.insert(0, root_rule + model_rule + grammar_model_rules)
  484. return "\n".join(all_rules)
  485. def get_primitive_grammar(grammar):
  486. """
  487. Returns the needed GBNF primitive grammar for a given GBNF grammar string.
  488. Args:
  489. grammar (str): The string containing the GBNF grammar.
  490. Returns:
  491. str: GBNF primitive grammar string.
  492. """
  493. type_list = []
  494. if "string-list" in grammar:
  495. type_list.append(str)
  496. if "boolean-list" in grammar:
  497. type_list.append(bool)
  498. if "integer-list" in grammar:
  499. type_list.append(int)
  500. if "float-list" in grammar:
  501. type_list.append(float)
  502. additional_grammar = [generate_list_rule(t) for t in type_list]
  503. primitive_grammar = r"""
  504. boolean ::= "true" | "false"
  505. null ::= "null"
  506. string ::= "\"" (
  507. [^"\\] |
  508. "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
  509. )* "\"" ws
  510. ws ::= ([ \t\n] ws)?
  511. float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
  512. integer ::= [0-9]+"""
  513. any_block = ""
  514. if "custom-class-any" in grammar:
  515. any_block = '''
  516. value ::= object | array | string | number | boolean | null
  517. object ::=
  518. "{" ws (
  519. string ":" ws value
  520. ("," ws string ":" ws value)*
  521. )? "}" ws
  522. array ::=
  523. "[" ws (
  524. value
  525. ("," ws value)*
  526. )? "]" ws
  527. number ::= integer | float'''
  528. markdown_code_block_grammar = ""
  529. if "markdown-code-block" in grammar:
  530. markdown_code_block_grammar = r'''
  531. markdown-code-block ::= opening-triple-ticks markdown-code-block-content closing-triple-ticks
  532. markdown-code-block-content ::= ( [^`] | "`" [^`] | "`" "`" [^`] )*
  533. opening-triple-ticks ::= "```" "python" "\n" | "```" "c" "\n" | "```" "cpp" "\n" | "```" "txt" "\n" | "```" "text" "\n" | "```" "json" "\n" | "```" "javascript" "\n" | "```" "css" "\n" | "```" "html" "\n" | "```" "markdown" "\n"
  534. closing-triple-ticks ::= "```" "\n"'''
  535. if "triple-quoted-string" in grammar:
  536. markdown_code_block_grammar = r"""
  537. triple-quoted-string ::= triple-quotes triple-quoted-string-content triple-quotes
  538. triple-quoted-string-content ::= ( [^'] | "'" [^'] | "'" "'" [^'] )*
  539. triple-quotes ::= "'''" """
  540. return "\n" + '\n'.join(additional_grammar) + any_block + primitive_grammar + markdown_code_block_grammar
  541. def generate_field_markdown(field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1) -> str:
  542. indent = ' ' * depth
  543. field_markdown = f"{indent}- **{field_name}** (`{field_type.__name__}`): "
  544. # Extracting field description from Pydantic Field using __model_fields__
  545. field_info = model.model_fields.get(field_name)
  546. field_description = field_info.description if field_info and field_info.description else "No description available."
  547. field_markdown += field_description + '\n'
  548. # Handling nested BaseModel fields
  549. if isclass(field_type) and issubclass(field_type, BaseModel):
  550. field_markdown += f"{indent} - Details:\n"
  551. for name, type_ in field_type.__annotations__.items():
  552. field_markdown += generate_field_markdown(name, type_, field_type, depth + 2)
  553. return field_markdown
  554. def generate_markdown_report(pydantic_models: List[Type[BaseModel]]) -> str:
  555. markdown = ""
  556. for model in pydantic_models:
  557. markdown += f"### {format_model_and_field_name(model.__name__)}\n"
  558. # Check if the model's docstring is different from BaseModel's docstring
  559. class_doc = getdoc(model)
  560. base_class_doc = getdoc(BaseModel)
  561. class_description = class_doc if class_doc and class_doc != base_class_doc else "No specific description available."
  562. markdown += f"{class_description}\n\n"
  563. markdown += "#### Fields\n"
  564. if isclass(model) and issubclass(model, BaseModel):
  565. for name, field_type in model.__annotations__.items():
  566. markdown += generate_field_markdown(format_model_and_field_name(name), field_type, model)
  567. markdown += "\n"
  568. return markdown
  569. def format_json_example(example: dict, depth: int) -> str:
  570. """
  571. Format a JSON example into a readable string with indentation.
  572. Args:
  573. example (dict): JSON example to be formatted.
  574. depth (int): Indentation depth.
  575. Returns:
  576. str: Formatted JSON example string.
  577. """
  578. indent = ' ' * depth
  579. formatted_example = '{\n'
  580. for key, value in example.items():
  581. value_text = f"'{value}'" if isinstance(value, str) else value
  582. formatted_example += f"{indent}{key}: {value_text},\n"
  583. formatted_example = formatted_example.rstrip(',\n') + '\n' + indent + '}'
  584. return formatted_example
  585. def generate_text_documentation(pydantic_models: List[Type[BaseModel]], model_prefix="Model",
  586. fields_prefix="Fields", documentation_with_field_description=True) -> str:
  587. """
  588. Generate text documentation for a list of Pydantic models.
  589. Args:
  590. pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
  591. model_prefix (str): Prefix for the model section.
  592. fields_prefix (str): Prefix for the fields section.
  593. documentation_with_field_description (bool): Include field descriptions in the documentation.
  594. Returns:
  595. str: Generated text documentation.
  596. """
  597. documentation = ""
  598. pyd_models = [(model, True) for model in pydantic_models]
  599. for model, add_prefix in pyd_models:
  600. if add_prefix:
  601. documentation += f"{model_prefix}: {format_model_and_field_name(model.__name__)}\n"
  602. else:
  603. documentation += f"Model: {format_model_and_field_name(model.__name__)}\n"
  604. # Handling multi-line model description with proper indentation
  605. class_doc = getdoc(model)
  606. base_class_doc = getdoc(BaseModel)
  607. class_description = class_doc if class_doc and class_doc != base_class_doc else ""
  608. if class_description != "":
  609. documentation += " Description: "
  610. documentation += "\n" + format_multiline_description(class_description, 2) + "\n"
  611. if add_prefix:
  612. # Indenting the fields section
  613. documentation += f" {fields_prefix}:\n"
  614. else:
  615. documentation += f" Fields:\n"
  616. if isclass(model) and issubclass(model, BaseModel):
  617. for name, field_type in model.__annotations__.items():
  618. # if name == "markdown_code_block":
  619. # continue
  620. if get_origin(field_type) == list:
  621. element_type = get_args(field_type)[0]
  622. if isclass(element_type) and issubclass(element_type, BaseModel):
  623. pyd_models.append((element_type, False))
  624. if get_origin(field_type) == Union:
  625. element_types = get_args(field_type)
  626. for element_type in element_types:
  627. if isclass(element_type) and issubclass(element_type, BaseModel):
  628. pyd_models.append((element_type, False))
  629. documentation += generate_field_text(name, field_type, model,
  630. documentation_with_field_description=documentation_with_field_description)
  631. documentation += "\n"
  632. if hasattr(model, 'Config') and hasattr(model.Config,
  633. 'json_schema_extra') and 'example' in model.Config.json_schema_extra:
  634. documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n"
  635. json_example = json.dumps(model.Config.json_schema_extra['example'])
  636. documentation += format_multiline_description(json_example, 2) + "\n"
  637. return documentation
  638. def generate_field_text(field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
  639. documentation_with_field_description=True) -> str:
  640. """
  641. Generate text documentation for a Pydantic model field.
  642. Args:
  643. field_name (str): Name of the field.
  644. field_type (Type[Any]): Type of the field.
  645. model (Type[BaseModel]): Pydantic model class.
  646. depth (int): Indentation depth in the documentation.
  647. documentation_with_field_description (bool): Include field descriptions in the documentation.
  648. Returns:
  649. str: Generated text documentation for the field.
  650. """
  651. indent = ' ' * depth
  652. field_info = model.model_fields.get(field_name)
  653. field_description = field_info.description if field_info and field_info.description else ""
  654. if get_origin(field_type) == list:
  655. element_type = get_args(field_type)[0]
  656. field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
  657. if field_description != "":
  658. field_text += ":\n"
  659. else:
  660. field_text += "\n"
  661. elif get_origin(field_type) == Union:
  662. element_types = get_args(field_type)
  663. types = []
  664. for element_type in element_types:
  665. types.append(format_model_and_field_name(element_type.__name__))
  666. field_text = f"{indent}{field_name} ({' or '.join(types)})"
  667. if field_description != "":
  668. field_text += ":\n"
  669. else:
  670. field_text += "\n"
  671. else:
  672. field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})"
  673. if field_description != "":
  674. field_text += ":\n"
  675. else:
  676. field_text += "\n"
  677. if not documentation_with_field_description:
  678. return field_text
  679. if field_description != "":
  680. field_text += f"{indent} Description: " + field_description + "\n"
  681. # Check for and include field-specific examples if available
  682. if hasattr(model, 'Config') and hasattr(model.Config,
  683. 'json_schema_extra') and 'example' in model.Config.json_schema_extra:
  684. field_example = model.Config.json_schema_extra['example'].get(field_name)
  685. if field_example is not None:
  686. example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
  687. field_text += f"{indent} Example: {example_text}\n"
  688. if isclass(field_type) and issubclass(field_type, BaseModel):
  689. field_text += f"{indent} Details:\n"
  690. for name, type_ in field_type.__annotations__.items():
  691. field_text += generate_field_text(name, type_, field_type, depth + 2)
  692. return field_text
  693. def format_multiline_description(description: str, indent_level: int) -> str:
  694. """
  695. Format a multiline description with proper indentation.
  696. Args:
  697. description (str): Multiline description.
  698. indent_level (int): Indentation level.
  699. Returns:
  700. str: Formatted multiline description.
  701. """
  702. indent = ' ' * indent_level
  703. return indent + description.replace('\n', '\n' + indent)
  704. def save_gbnf_grammar_and_documentation(grammar, documentation, grammar_file_path="./grammar.gbnf",
  705. documentation_file_path="./grammar_documentation.md"):
  706. """
  707. Save GBNF grammar and documentation to specified files.
  708. Args:
  709. grammar (str): GBNF grammar string.
  710. documentation (str): Documentation string.
  711. grammar_file_path (str): File path to save the GBNF grammar.
  712. documentation_file_path (str): File path to save the documentation.
  713. Returns:
  714. None
  715. """
  716. try:
  717. with open(grammar_file_path, 'w') as file:
  718. file.write(grammar + get_primitive_grammar(grammar))
  719. print(f"Grammar successfully saved to {grammar_file_path}")
  720. except IOError as e:
  721. print(f"An error occurred while saving the grammar file: {e}")
  722. try:
  723. with open(documentation_file_path, 'w') as file:
  724. file.write(documentation)
  725. print(f"Documentation successfully saved to {documentation_file_path}")
  726. except IOError as e:
  727. print(f"An error occurred while saving the documentation file: {e}")
  728. def remove_empty_lines(string):
  729. """
  730. Remove empty lines from a string.
  731. Args:
  732. string (str): Input string.
  733. Returns:
  734. str: String with empty lines removed.
  735. """
  736. lines = string.splitlines()
  737. non_empty_lines = [line for line in lines if line.strip() != ""]
  738. string_no_empty_lines = "\n".join(non_empty_lines)
  739. return string_no_empty_lines
  740. def generate_and_save_gbnf_grammar_and_documentation(pydantic_model_list,
  741. grammar_file_path="./generated_grammar.gbnf",
  742. documentation_file_path="./generated_grammar_documentation.md",
  743. outer_object_name: str = None,
  744. outer_object_content: str = None,
  745. model_prefix: str = "Output Model",
  746. fields_prefix: str = "Output Fields",
  747. list_of_outputs: bool = False,
  748. documentation_with_field_description=True):
  749. """
  750. Generate GBNF grammar and documentation, and save them to specified files.
  751. Args:
  752. pydantic_model_list: List of Pydantic model classes.
  753. grammar_file_path (str): File path to save the generated GBNF grammar.
  754. documentation_file_path (str): File path to save the generated documentation.
  755. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  756. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  757. model_prefix (str): Prefix for the model section in the documentation.
  758. fields_prefix (str): Prefix for the fields section in the documentation.
  759. list_of_outputs (bool): Whether the output is a list of items.
  760. documentation_with_field_description (bool): Include field descriptions in the documentation.
  761. Returns:
  762. None
  763. """
  764. documentation = generate_text_documentation(pydantic_model_list, model_prefix, fields_prefix,
  765. documentation_with_field_description=documentation_with_field_description)
  766. grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name,
  767. outer_object_content, list_of_outputs)
  768. grammar = remove_empty_lines(grammar)
  769. save_gbnf_grammar_and_documentation(grammar, documentation, grammar_file_path, documentation_file_path)
  770. def generate_gbnf_grammar_and_documentation(pydantic_model_list, outer_object_name: str = None,
  771. outer_object_content: str = None,
  772. model_prefix: str = "Output Model",
  773. fields_prefix: str = "Output Fields", list_of_outputs: bool = False,
  774. documentation_with_field_description=True):
  775. """
  776. Generate GBNF grammar and documentation for a list of Pydantic models.
  777. Args:
  778. pydantic_model_list: List of Pydantic model classes.
  779. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  780. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  781. model_prefix (str): Prefix for the model section in the documentation.
  782. fields_prefix (str): Prefix for the fields section in the documentation.
  783. list_of_outputs (bool): Whether the output is a list of items.
  784. documentation_with_field_description (bool): Include field descriptions in the documentation.
  785. Returns:
  786. tuple: GBNF grammar string, documentation string.
  787. """
  788. documentation = generate_text_documentation(copy(pydantic_model_list), model_prefix, fields_prefix,
  789. documentation_with_field_description=documentation_with_field_description)
  790. grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name,
  791. outer_object_content, list_of_outputs)
  792. grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar))
  793. return grammar, documentation
  794. def generate_gbnf_grammar_and_documentation_from_dictionaries(dictionaries: List[dict],
  795. outer_object_name: str = None,
  796. outer_object_content: str = None,
  797. model_prefix: str = "Output Model",
  798. fields_prefix: str = "Output Fields",
  799. list_of_outputs: bool = False,
  800. documentation_with_field_description=True):
  801. """
  802. Generate GBNF grammar and documentation from a list of dictionaries.
  803. Args:
  804. dictionaries (List[dict]): List of dictionaries representing Pydantic models.
  805. outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
  806. outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
  807. model_prefix (str): Prefix for the model section in the documentation.
  808. fields_prefix (str): Prefix for the fields section in the documentation.
  809. list_of_outputs (bool): Whether the output is a list of items.
  810. documentation_with_field_description (bool): Include field descriptions in the documentation.
  811. Returns:
  812. tuple: GBNF grammar string, documentation string.
  813. """
  814. pydantic_model_list = create_dynamic_models_from_dictionaries(dictionaries)
  815. documentation = generate_text_documentation(copy(pydantic_model_list), model_prefix, fields_prefix,
  816. documentation_with_field_description=documentation_with_field_description)
  817. grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name,
  818. outer_object_content, list_of_outputs)
  819. grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar))
  820. return grammar, documentation
  821. def create_dynamic_model_from_function(func: Callable):
  822. """
  823. Creates a dynamic Pydantic model from a given function's type hints and adds the function as a 'run' method.
  824. Args:
  825. func (Callable): A function with type hints from which to create the model.
  826. Returns:
  827. A dynamic Pydantic model class with the provided function as a 'run' method.
  828. """
  829. # Extracting type hints from the provided function
  830. type_hints = get_type_hints(func)
  831. type_hints.pop('return', None)
  832. # Handling default values and annotations
  833. dynamic_fields = {}
  834. defaults = getattr(func, '__defaults__', ()) or ()
  835. defaults_index = len(type_hints) - len(defaults)
  836. for index, (name, typ) in enumerate(type_hints.items()):
  837. if index >= defaults_index:
  838. default_value = defaults[index - defaults_index]
  839. dynamic_fields[name] = (typ, default_value)
  840. else:
  841. dynamic_fields[name] = (typ, ...)
  842. # Creating the dynamic model
  843. dynamicModel = create_model(f'{func.__name__}', **dynamic_fields)
  844. dynamicModel.__doc__ = getdoc(func)
  845. # Wrapping the original function to handle instance 'self'
  846. def run_method_wrapper(self):
  847. func_args = {name: getattr(self, name) for name in type_hints}
  848. return func(**func_args)
  849. # Adding the wrapped function as a 'run' method
  850. setattr(dynamicModel, 'run', run_method_wrapper)
  851. return dynamicModel
  852. def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
  853. """
  854. Add a 'run' method to a dynamic Pydantic model, using the provided function.
  855. Args:
  856. - model (Type[BaseModel]): Dynamic Pydantic model class.
  857. - func (Callable): Function to be added as a 'run' method to the model.
  858. Returns:
  859. - Type[BaseModel]: Pydantic model class with the added 'run' method.
  860. """
  861. def run_method_wrapper(self):
  862. func_args = {name: getattr(self, name) for name in model.model_fields}
  863. return func(**func_args)
  864. # Adding the wrapped function as a 'run' method
  865. setattr(model, 'run', run_method_wrapper)
  866. return model
  867. def create_dynamic_models_from_dictionaries(dictionaries: List[dict]):
  868. """
  869. Create a list of dynamic Pydantic model classes from a list of dictionaries.
  870. Args:
  871. - dictionaries (List[dict]): List of dictionaries representing model structures.
  872. Returns:
  873. - List[Type[BaseModel]]: List of generated dynamic Pydantic model classes.
  874. """
  875. dynamic_models = []
  876. for func in dictionaries:
  877. model_name = format_model_and_field_name(func.get("name", ""))
  878. dyn_model = convert_dictionary_to_to_pydantic_model(func, model_name)
  879. dynamic_models.append(dyn_model)
  880. return dynamic_models
  881. def map_grammar_names_to_pydantic_model_class(pydantic_model_list):
  882. output = {}
  883. for model in pydantic_model_list:
  884. output[format_model_and_field_name(model.__name__)] = model
  885. return output
  886. from enum import Enum
  887. def json_schema_to_python_types(schema):
  888. type_map = {
  889. 'any': Any,
  890. 'string': str,
  891. 'number': float,
  892. 'integer': int,
  893. 'boolean': bool,
  894. 'array': list,
  895. }
  896. return type_map[schema]
  897. def list_to_enum(enum_name, values):
  898. return Enum(enum_name, {value: value for value in values})
  899. def convert_dictionary_to_to_pydantic_model(dictionary: dict, model_name: str = 'CustomModel') -> Type[BaseModel]:
  900. """
  901. Convert a dictionary to a Pydantic model class.
  902. Args:
  903. - dictionary (dict): Dictionary representing the model structure.
  904. - model_name (str): Name of the generated Pydantic model.
  905. Returns:
  906. - Type[BaseModel]: Generated Pydantic model class.
  907. """
  908. fields = {}
  909. if "properties" in dictionary:
  910. for field_name, field_data in dictionary.get("properties", {}).items():
  911. if field_data == 'object':
  912. submodel = convert_dictionary_to_to_pydantic_model(dictionary, f'{model_name}_{field_name}')
  913. fields[field_name] = (submodel, ...)
  914. else:
  915. field_type = field_data.get('type', 'str')
  916. if field_data.get("enum", []):
  917. fields[field_name] = (list_to_enum(field_name, field_data.get("enum", [])), ...)
  918. if field_type == "array":
  919. items = field_data.get("items", {})
  920. if items != {}:
  921. array = {"properties": items}
  922. array_type = convert_dictionary_to_to_pydantic_model(array, f'{model_name}_{field_name}_items')
  923. fields[field_name] = (List[array_type], ...)
  924. else:
  925. fields[field_name] = (list, ...)
  926. elif field_type == 'object':
  927. submodel = convert_dictionary_to_to_pydantic_model(field_data, f'{model_name}_{field_name}')
  928. fields[field_name] = (submodel, ...)
  929. else:
  930. field_type = json_schema_to_python_types(field_type)
  931. fields[field_name] = (field_type, ...)
  932. if "function" in dictionary:
  933. for field_name, field_data in dictionary.get("function", {}).items():
  934. if field_name == "name":
  935. model_name = field_data
  936. elif field_name == "description":
  937. fields["__doc__"] = field_data
  938. elif field_name == "parameters":
  939. return convert_dictionary_to_to_pydantic_model(field_data, f'{model_name}')
  940. if "parameters" in dictionary:
  941. field_data = {"function": dictionary}
  942. return convert_dictionary_to_to_pydantic_model(field_data, f'{model_name}')
  943. custom_model = create_model(model_name, **fields)
  944. return custom_model