Coverage for src\funcall\params_to_schema.py: 77%

118 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-28 01:17 +0900

1import dataclasses 

2import types 

3from dataclasses import fields, is_dataclass 

4from typing import Any, Literal, get_args, get_origin 

5from typing import Any as TypingAny 

6from typing import Union as TypingUnion 

7 

8from pydantic import BaseModel, create_model 

9 

10 

def _create_union_type(union_types: tuple) -> object:
    """
    Build a ``Union`` out of a tuple of member types.

    Args:
        union_types: The member types of the union.

    Returns:
        The object produced by subscripting ``typing.Union`` with the tuple.
    """
    try:
        combined = TypingUnion[union_types]  # noqa: UP007
    except TypeError:
        # Compatibility fallback: retry through the explicit dunder call.
        combined = TypingUnion.__getitem__(union_types)  # type: ignore
    return combined

17 

18 

def _handle_tuple_type(args: tuple) -> type:
    """
    Map the type arguments of a tuple annotation onto a list type.

    Args:
        args: The type arguments of the tuple annotation (possibly empty).

    Returns:
        ``list[Any]`` when there are no arguments, ``list[T]`` for
        ``Tuple[T, ...]`` or a single-element tuple, and otherwise a list
        whose item type is the union of all converted element types.
    """
    if not args:
        return list[TypingAny]

    # Variable-length homogeneous form: Tuple[T, ...] -> list[T]
    is_homogeneous = len(args) == 2 and args[1] is Ellipsis
    if is_homogeneous:
        return list[to_field_type(args[0])]

    # Fixed-length form: Tuple[T1, T2, ...] -> list[Union[T1, T2, ...]]
    converted = tuple(map(to_field_type, args))
    if len(converted) == 1:
        (only_type,) = converted
        return list[only_type]

    return list[_create_union_type(converted)]

36 

37 

def _dataclass_to_pydantic_model(dataclass_type: type) -> type:
    """
    Convert a dataclass to a Pydantic model with the same fields.

    Each dataclass field becomes a model field with the same name and
    annotation. Defaults are carried over as follows:

    * a plain ``default`` is used as the model field default;
    * a ``default_factory`` is passed through ``Field(default_factory=...)``;
    * a field with neither is required (``...``).

    Args:
        dataclass_type: The dataclass to convert.

    Returns:
        A model created with ``create_model`` and named after the dataclass,
        with field descriptions copied from the dataclass field metadata.
    """
    # Local import: keeps this function self-contained with respect to the
    # module-level import list.
    from pydantic import Field  # noqa: PLC0415

    model_fields = {}

    for field in fields(dataclass_type):
        # Determine the default value of the field
        if field.default is not dataclasses.MISSING:
            default_value = field.default
        elif field.default_factory is not dataclasses.MISSING:
            # The factory must be registered as a factory; handing the callable
            # over as a plain default would make the callable itself the
            # default value of the model field.
            default_value = Field(default_factory=field.default_factory)
        else:
            default_value = ...

        model_fields[field.name] = (field.type, default_value)

    # Create Pydantic Model
    model = create_model(dataclass_type.__name__, **model_fields)

    # Add field descriptions
    _add_field_descriptions(model, dataclass_type)

    return model

60 

61 

def _add_field_descriptions(model: type, dataclass_type: type) -> None:
    """
    Copy ``description`` metadata from dataclass fields onto model fields.

    For every dataclass field whose metadata contains a ``"description"``
    entry, the matching entry in ``model.model_fields`` (if any) gets its
    ``description`` attribute set to that value. Fields without such
    metadata, or without a counterpart on the model, are left untouched.

    Args:
        model: Object exposing a ``model_fields`` mapping, mutated in place.
        dataclass_type: The dataclass whose field metadata is read.
    """
    for dc_field in fields(dataclass_type):
        if not hasattr(dc_field, "metadata") or "description" not in dc_field.metadata:
            continue
        text = dc_field.metadata["description"]
        if not hasattr(model, "model_fields"):
            continue
        if dc_field.name not in model.model_fields:
            continue
        model.model_fields[dc_field.name].description = text

69 

70 

def to_field_type(param: type) -> type:  # noqa: C901, PLR0911
    """
    Convert various type annotations to field types.

    Handling, in order:

    * ``None`` becomes ``NoneType``.
    * ``Union`` / ``Optional`` / ``X | Y``: every member is converted and the
      union is rebuilt.
    * ``list[T]``: the item type is converted (``Any`` when unparameterized).
    * ``tuple[...]``: converted to a list type by ``_handle_tuple_type``.
    * Pydantic ``BaseModel`` subclasses are returned unchanged.
    * Dataclass classes are converted to Pydantic models.
    * Any other plain class is returned unchanged.

    Args:
        param: The annotation to convert.

    Returns:
        The converted field type.

    Raises:
        TypeError: For ``dict`` (bare or parameterized) and for anything that
            matches none of the cases above.
    """
    if param is None:
        return type(None)

    origin = get_origin(param)
    args = get_args(param)

    if origin is not None:
        # Parameterized generics are handled before any class checks: on
        # interpreters where e.g. ``list[int]`` still passes
        # ``isinstance(x, type)`` (Python 3.10 and earlier), passing it to
        # ``issubclass`` raises a confusing TypeError.

        # Union/Optional (compatible with 3.10+ X | Y)
        if origin is TypingUnion or (hasattr(types, "UnionType") and origin is types.UnionType):
            union_types = tuple(to_field_type(a) for a in args)
            return _create_union_type(union_types)  # type: ignore

        # List
        if origin is list:
            item_type = to_field_type(args[0]) if args else TypingAny
            return list[item_type]

        # Dict - provide clearer error message
        if origin is dict:
            msg = f"Dict type {param} is not supported directly, use pydantic BaseModel or dataclass instead."
            raise TypeError(msg)

        # Tuple
        if origin is tuple:
            return _handle_tuple_type(args)
    elif isinstance(param, type):
        # Check if it's a Pydantic BaseModel
        if issubclass(param, BaseModel):
            return param

        # Check if it's a dataclass. Only classes qualify: a dataclass
        # *instance* has no ``__name__`` to build a model from, so it falls
        # through to the "unsupported" error below instead.
        if is_dataclass(param):
            return _dataclass_to_pydantic_model(param)

    # Basic type handling
    if isinstance(param, type):
        if param is dict:
            msg = "Dict type is not supported directly, use pydantic BaseModel or dataclass instead."
            raise TypeError(msg)
        return param

    # If none match, raise error
    msg = f"Unsupported param type: {param} (type: {type(param)})"
    raise TypeError(msg)

120 

121 

def params_to_schema(params: list[Any], target: Literal["openai", "litellm"] = "openai") -> dict[str, Any]:
    """
    Build a JSON schema describing a positional list of parameter types.

    Each entry of ``params`` (basic type, dataclass, Pydantic model, or a
    supported container of those) becomes a required field named
    ``param_<index>`` on a generated ``ParamsModel``; the schema of that
    model is returned.

    Args:
        params: List of parameter types
        target: Target platform ("openai" or "litellm"), defaults to "openai"

    Returns:
        The JSON schema as a dict. For the "openai" target it is normalized
        first; ``$ref`` entries are inlined whenever the schema has ``$defs``.

    Raises:
        TypeError: If ``params`` is not a list.
    """
    if not isinstance(params, list):
        msg = "params must be a list"
        raise TypeError(msg)

    # One required field per positional parameter; an empty list simply
    # yields a model without fields.
    field_specs = {f"param_{index}": (to_field_type(annotation), ...) for index, annotation in enumerate(params)}
    model = create_model("ParamsModel", **field_specs)

    schema = model.model_json_schema(mode="serialization")

    # Normalization is an OpenAI-specific requirement
    if target == "openai":
        _normalize_schema(schema)

    # Expand references in place of $ref pointers
    if "$defs" in schema:
        schema = _inline_definitions(schema)

    return schema

158 

159 

def _inline_definitions(schema: dict) -> dict:
    """
    Inline ``$ref`` definitions to avoid using references.

    Every ``{"$ref": "#/$defs/<name>"}`` whose name exists in the schema's
    ``$defs`` is replaced by a copy of that definition, recursively. Keys
    that sit next to ``$ref`` (for example ``description``) are kept and take
    precedence over the same keys in the inlined definition.

    A definition that refers back to itself, directly or through other
    definitions, cannot be expanded without end: the reference that would
    re-enter a definition already being expanded is left as a ``$ref``.
    The ``$defs`` section is retained in the result, so such references
    still resolve.

    Args:
        schema: The JSON schema to process. It is not mutated.

    Returns:
        ``schema`` itself when it has no ``$defs``; otherwise a new dict with
        the resolvable references inlined.
    """
    if "$defs" not in schema:
        return schema

    definitions = schema["$defs"]
    prefix = "#/$defs/"

    def replace_refs(obj: Any, active: frozenset) -> Any:  # noqa: ANN401
        # ``active`` holds the names of the definitions currently being
        # expanded on this path; it is what breaks reference cycles.
        if isinstance(obj, dict):
            if "$ref" in obj:
                ref_path = obj["$ref"]
                if isinstance(ref_path, str) and ref_path.startswith(prefix):
                    def_name = ref_path.removeprefix(prefix)
                    if def_name in definitions and def_name not in active:
                        inlined = replace_refs(definitions[def_name], active | {def_name})
                        if not isinstance(inlined, dict):
                            return inlined
                        # Preserve annotations written alongside the $ref.
                        siblings = {k: replace_refs(v, active) for k, v in obj.items() if k != "$ref"}
                        return {**inlined, **siblings}
                # Unknown, external, or cyclic reference: keep it untouched.
                return obj
            # Recursively process all dict values
            return {k: replace_refs(v, active) for k, v in obj.items()}
        if isinstance(obj, list):
            # Recursively process all list items
            return [replace_refs(item, active) for item in obj]
        return obj

    # Replace all $refs in the schema
    return replace_refs(schema, frozenset())

191 

192 

193def _normalize_schema(schema: dict | list) -> None: 

194 """ 

195 Normalize schema for OpenAI, add additionalProperties: false and fix required fields 

196 

197 Args: 

198 schema: The schema to normalize 

199 """ 

200 if isinstance(schema, dict): 

201 if schema.get("type") == "object": 

202 schema.setdefault("additionalProperties", False) 

203 # OpenAI Function Calling: required must contain all properties 

204 if "properties" in schema: 204 ↛ 208line 204 didn't jump to line 208 because the condition on line 204 was always true

205 schema["required"] = list(schema["properties"].keys()) 

206 

207 # Recursively handle nested objects 

208 for value in schema.values(): 

209 if isinstance(value, (dict, list)): 

210 _normalize_schema(value) 

211 

212 elif isinstance(schema, list): 212 ↛ exitline 212 didn't return from function '_normalize_schema' because the condition on line 212 was always true

213 for item in schema: 

214 if isinstance(item, (dict, list)): 

215 _normalize_schema(item)