Coverage for src\funcall\params_to_schema.py: 77%
118 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-28 01:17 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-28 01:17 +0900
1import dataclasses
2import types
3from dataclasses import fields, is_dataclass
4from typing import Any, Literal, get_args, get_origin
5from typing import Any as TypingAny
6from typing import Union as TypingUnion
8from pydantic import BaseModel, create_model
def _create_union_type(union_types: tuple) -> object:
    """Build a ``Union`` out of the given tuple of member types.

    Plain subscription is tried first. If that raises ``TypeError`` the same
    tuple is handed to an explicit ``__getitem__`` call instead.
    """
    try:
        combined = TypingUnion[union_types]  # noqa: UP007
    except TypeError:
        combined = TypingUnion.__getitem__(union_types)  # type: ignore
    return combined
def _handle_tuple_type(args: tuple) -> type:
    """Map the type arguments of a tuple annotation onto a ``list`` type.

    * no arguments            -> ``list[Any]``
    * ``(T, ...)``            -> ``list[T]``
    * a single ``(T,)``       -> ``list[T]``
    * ``(T1, T2, ...)``       -> ``list[Union[T1, T2, ...]]``

    Every member type is passed through ``to_field_type`` first.
    """
    # Bare tuple: nothing is known about the items
    if not args:
        return list[TypingAny]

    # Variable-length homogeneous tuple
    is_homogeneous = len(args) == 2 and args[1] is Ellipsis
    if is_homogeneous:
        return list[to_field_type(args[0])]

    # Fixed-shape tuple: the list items may be any of the member types
    converted = tuple(to_field_type(member) for member in args)
    if len(converted) == 1:
        return list[converted[0]]
    return list[_create_union_type(converted)]
def _dataclass_to_pydantic_model(dataclass_type: type) -> type:
    """Convert a dataclass to a Pydantic Model.

    Each dataclass field becomes a model field with the same name and the
    same annotation. Defaults are carried over as follows:

    * a plain ``default`` is used as the model field's default;
    * a ``default_factory`` is wrapped in ``Field(default_factory=...)``;
    * a field with neither is declared required (``...``).

    Args:
        dataclass_type: The dataclass class to convert.

    Returns:
        The model class produced by ``create_model``, named after the
        dataclass, after ``_add_field_descriptions`` has been applied to it.
    """
    # Local import: Field is not part of the module-level pydantic import.
    from pydantic import Field  # noqa: PLC0415

    model_fields = {}

    for field in fields(dataclass_type):
        # Determine the default value of the field
        if field.default is not dataclasses.MISSING:
            default_value = field.default
        elif field.default_factory is not dataclasses.MISSING:
            # The factory must be registered as a factory. Passing the
            # callable directly would make the callable itself the default.
            default_value = Field(default_factory=field.default_factory)
        else:
            default_value = ...

        model_fields[field.name] = (field.type, default_value)

    # Create Pydantic Model
    model = create_model(dataclass_type.__name__, **model_fields)

    # Add field descriptions
    _add_field_descriptions(model, dataclass_type)

    return model
def _add_field_descriptions(model: type, dataclass_type: type) -> None:
    """Copy ``metadata["description"]`` of dataclass fields onto the model.

    For every dataclass field whose metadata has a ``"description"`` entry,
    the ``description`` attribute of the same-named entry in
    ``model.model_fields`` is set to it. Fields without that metadata entry,
    and fields the model does not have, are skipped.
    """
    for dc_field in fields(dataclass_type):
        if not hasattr(dc_field, "metadata") or "description" not in dc_field.metadata:
            continue

        text = dc_field.metadata["description"]

        if not hasattr(model, "model_fields"):
            continue
        if dc_field.name not in model.model_fields:
            continue

        model.model_fields[dc_field.name].description = text
def to_field_type(param: type) -> type:  # noqa: C901, PLR0911
    """
    Convert various type annotations to field types.

    Args:
        param: The annotation to convert: ``None``, a plain class, a Pydantic
            model class, a dataclass class, or a ``Union`` / ``list`` /
            ``tuple`` generic.

    Returns:
        ``type(None)`` for ``None``; the class itself for Pydantic models and
        other plain classes; the result of ``_dataclass_to_pydantic_model``
        for dataclass classes; a ``Union`` or ``list[...]`` whose members
        were converted recursively for unions and lists; the result of
        ``_handle_tuple_type`` for tuples.

    Raises:
        TypeError: For ``dict`` (bare or parameterized) and for anything that
            matches none of the cases above, including dataclass instances.
    """
    if param is None:
        return type(None)

    origin = get_origin(param)
    args = get_args(param)
    is_class = isinstance(param, type)

    # Check if it's a Pydantic BaseModel
    if is_class and issubclass(param, BaseModel):
        return param

    # Check if it's a dataclass. is_dataclass() is also true for dataclass
    # *instances*, which cannot be converted (they have no __name__), so only
    # classes are accepted here; instances fall through to the final TypeError.
    if is_class and is_dataclass(param):
        return _dataclass_to_pydantic_model(param)

    # Handle generic types
    if origin is not None:
        # Union/Optional (compatible with 3.10+ X | Y)
        if origin is TypingUnion or (hasattr(types, "UnionType") and origin is types.UnionType):
            union_types = tuple(to_field_type(a) for a in args)
            return _create_union_type(union_types)  # type: ignore

        # List
        if origin is list:
            item_type = to_field_type(args[0]) if args else TypingAny
            return list[item_type]

        # Dict - provide clearer error message
        if origin is dict:
            msg = f"Dict type {param} is not supported directly, use pydantic BaseModel or dataclass instead."
            raise TypeError(msg)

        # Tuple
        if origin is tuple:
            return _handle_tuple_type(args)

    # Basic type handling
    if is_class:
        if param is dict:
            msg = "Dict type is not supported directly, use pydantic BaseModel or dataclass instead."
            raise TypeError(msg)
        return param

    # If none match, raise error
    msg = f"Unsupported param type: {param} (type: {type(param)})"
    raise TypeError(msg)
def params_to_schema(params: list[Any], target: Literal["openai", "litellm"] = "openai") -> dict[str, Any]:
    """
    Build a JSON schema describing a list of parameter types.

    Each entry of ``params`` is converted with ``to_field_type`` and becomes a
    required field ``param_<index>`` of a model named ``ParamsModel``; the
    schema of that model is what gets returned.

    Args:
        params: List of parameter types
        target: Target platform ("openai" or "litellm"), defaults to "openai"

    Raises:
        TypeError: If ``params`` is not a list.
    """
    if not isinstance(params, list):
        msg = "params must be a list"
        raise TypeError(msg)

    # One required field per parameter; an empty list yields a model with no fields
    field_specs = {}
    for index, annotation in enumerate(params):
        field_specs[f"param_{index}"] = (to_field_type(annotation), ...)
    model = create_model("ParamsModel", **field_specs)

    # Serialization-mode schema of the parameter model
    schema = model.model_json_schema(mode="serialization")

    # Normalization is applied for the OpenAI target only
    if target == "openai":
        _normalize_schema(schema)

    # Inline referenced definitions when the schema carries a $defs section
    return _inline_definitions(schema) if "$defs" in schema else schema
def _inline_definitions(schema: dict) -> dict:
    """
    Inline all $ref definitions to avoid using references

    Every ``{"$ref": "#/$defs/<name>"}`` whose name exists in the schema's
    ``$defs`` is replaced by a copy of that definition. Keywords that sit next
    to ``$ref`` (for example a field-level ``description``) are kept and take
    precedence over the same keyword in the definition. A definition that
    refers back to itself, directly or through other definitions, is not
    expanded again inside its own expansion: that inner ``$ref`` is left in
    place so the expansion terminates. References that cannot be resolved are
    returned unchanged. The ``$defs`` section itself stays in the result.

    Args:
        schema: JSON schema dict, possibly with a top-level ``$defs``.

    Returns:
        ``schema`` itself when it has no ``$defs``; otherwise a new top-level
        dict with the resolvable references inlined.
    """
    if "$defs" not in schema:
        return schema

    definitions = schema["$defs"]
    prefix = "#/$defs/"

    def replace_refs(obj: Any, active: frozenset = frozenset()) -> Any:  # noqa: ANN401
        # `active` holds the definition names currently being expanded.
        if isinstance(obj, list):
            return [replace_refs(item, active) for item in obj]
        if not isinstance(obj, dict):
            return obj

        ref_path = obj.get("$ref")
        if not isinstance(ref_path, str):
            # Not a reference (or "$ref" is merely a key name): process values
            return {k: replace_refs(v, active) for k, v in obj.items()}

        def_name = ref_path[len(prefix):] if ref_path.startswith(prefix) else None
        if def_name is None or def_name not in definitions or def_name in active:
            # Unresolvable, or already being expanded (cycle): keep the $ref
            return obj

        inlined = replace_refs(definitions[def_name], active | {def_name})
        if not isinstance(inlined, dict):
            return inlined

        # Keep keywords written next to $ref; they override the definition's
        siblings = {k: replace_refs(v, active) for k, v in obj.items() if k != "$ref"}
        return {**inlined, **siblings}

    # Replace all $refs in the schema
    return replace_refs(schema)
193def _normalize_schema(schema: dict | list) -> None:
194 """
195 Normalize schema for OpenAI, add additionalProperties: false and fix required fields
197 Args:
198 schema: The schema to normalize
199 """
200 if isinstance(schema, dict):
201 if schema.get("type") == "object":
202 schema.setdefault("additionalProperties", False)
203 # OpenAI Function Calling: required must contain all properties
204 if "properties" in schema: 204 ↛ 208line 204 didn't jump to line 208 because the condition on line 204 was always true
205 schema["required"] = list(schema["properties"].keys())
207 # Recursively handle nested objects
208 for value in schema.values():
209 if isinstance(value, (dict, list)):
210 _normalize_schema(value)
212 elif isinstance(schema, list): 212 ↛ exitline 212 didn't return from function '_normalize_schema' because the condition on line 212 was always true
213 for item in schema:
214 if isinstance(item, (dict, list)):
215 _normalize_schema(item)