Coverage for src\funcall\metadata.py: 90%
81 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-28 01:17 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-28 01:17 +0900
1import dataclasses
2import inspect
3from asyncio.log import logger
4from collections.abc import Callable
5from typing import Literal, get_type_hints
7from openai.types.responses import FunctionToolParam
8from pydantic import BaseModel
10from funcall.params_to_schema import params_to_schema
12from .types import LiteLLMFunctionToolParam, is_context_type, is_optional_type
def generate_function_metadata(
    func: Callable,
    target: Literal["openai", "litellm"] = "openai",
) -> FunctionToolParam | LiteLLMFunctionToolParam:
    """
    Generate function metadata for OpenAI or LiteLLM function calling.

    Parameters whose type hint is recognised by ``is_context_type`` are left out
    of the generated schema; if more than one is present a warning is logged.
    When exactly one non-context parameter remains and it is a dataclass or
    pydantic ``BaseModel`` class, that model's fields become the tool's
    top-level parameters. Otherwise every non-context parameter becomes a
    top-level property.

    Args:
        func: The function to generate metadata for
        target: Target platform ("openai" or "litellm")

    Returns:
        Function metadata in the appropriate format
    """
    signature = inspect.signature(func)
    type_hints = get_type_hints(func)
    description = func.__doc__.strip() if func.__doc__ else ""

    # Extract non-context parameters
    param_names, param_types, context_count = _extract_parameters(signature, type_hints)

    if context_count > 1:
        # Log through this module's own logger. The file-level
        # `asyncio.log.logger` would attribute the warning to "asyncio".
        import logging

        logging.getLogger(__name__).warning(
            "Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.",
            func.__name__,
        )

    schema = params_to_schema(param_types)

    # Single dataclass/BaseModel parameter: flatten its fields into the tool.
    # The helper returns None for any other single parameter type.
    if len(param_names) == 1:
        metadata = _generate_single_param_metadata(
            func,
            param_types[0],
            schema,
            description,
            target,
        )
        if metadata:
            return metadata

    # Zero, several, or a single non-model parameter.
    return _generate_multi_param_metadata(func, param_names, schema, description, target)
60def _extract_parameters(signature: inspect.Signature, type_hints: dict) -> tuple[list[str], list[type], int]:
61 """Extract parameter information from function signature."""
62 param_names = []
63 param_types = []
64 context_count = 0
66 for name in signature.parameters:
67 hint = type_hints.get(name, str)
69 # Skip Context-type parameters
70 if is_context_type(hint):
71 context_count += 1
72 continue
74 param_names.append(name)
75 param_types.append(hint)
77 return param_names, param_types, context_count
def _generate_single_param_metadata(
    func: Callable,
    param_type: type,
    schema: dict,
    description: str,
    target: str,
) -> FunctionToolParam | LiteLLMFunctionToolParam | None:
    """
    Generate metadata for functions with a single dataclass/BaseModel parameter.

    The model's fields are promoted to the tool's top-level parameters rather
    than being nested under one argument.

    Args:
        func: The function being described; its ``__name__`` becomes the tool name.
        param_type: Type hint of the single non-context parameter.
        schema: Output of ``params_to_schema`` for that one parameter. The
            model's object schema is read from ``schema["properties"]["param_0"]``.
        description: Tool description text.
        target: ``"litellm"`` selects the nested LiteLLM format. Any other value
            yields the OpenAI format with ``"strict": True``.

    Returns:
        The tool metadata, or ``None`` when ``param_type`` is not a dataclass or
        ``BaseModel`` class. The caller then falls back to the multi-parameter path.
    """
    if not (
        isinstance(param_type, type)
        and (dataclasses.is_dataclass(param_type) or issubclass(param_type, BaseModel))
    ):
        return None

    prop = schema["properties"]["param_0"]
    properties = prop["properties"]
    required = prop.get("required", [])

    base_params = {
        "type": "object",
        "properties": properties,
        "additionalProperties": prop.get("additionalProperties", False),
    }

    if target == "litellm":
        # Prefer the model's own field information. Schema keys the model does
        # not declare under the same name fall back to the schema's "required" list.
        optional_fields = _field_optionality(param_type)
        litellm_required = [
            key for key in properties if not optional_fields.get(key, key not in required)
        ]
        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": description,
                "parameters": {
                    **base_params,
                    "required": litellm_required,
                },
            },
        }  # type: ignore

    # OpenAI format: strict mode lists every property as required.
    metadata: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": description,
        "parameters": {
            **base_params,
            "required": list(properties.keys()),
        },
        "strict": True,
    }
    return metadata


def _field_optionality(param_type: type) -> dict[str, bool]:
    """Map each declared field of a BaseModel/dataclass class to whether callers may omit it."""
    if issubclass(param_type, BaseModel):
        # FieldInfo.is_required is a method and must be called. Comparing the
        # bound method itself with False is always False.
        return {
            name: is_optional_type(info.annotation) or not info.is_required()
            for name, info in param_type.model_fields.items()
        }
    # Dataclass: optional if the hint is Optional or a default/default_factory exists.
    # NOTE(review): f.type may be a string under `from __future__ import annotations`;
    # whether is_optional_type handles that is not visible here.
    return {
        f.name: (
            is_optional_type(f.type)
            or f.default is not dataclasses.MISSING
            or f.default_factory is not dataclasses.MISSING
        )
        for f in dataclasses.fields(param_type)
    }
def _generate_multi_param_metadata(
    func: Callable,
    param_names: list[str],
    schema: dict,
    description: str,
    target: str,
) -> FunctionToolParam | LiteLLMFunctionToolParam:
    """
    Generate metadata for functions with multiple parameters.

    Each name in ``param_names`` is paired positionally with
    ``schema["properties"]["param_<i>"]`` to form the tool's top-level properties.

    Args:
        func: The function being described. Its signature and type hints are
            re-read to decide LiteLLM optionality.
        param_names: Non-context parameter names, in the same order used to build ``schema``.
        schema: Output of ``params_to_schema`` with one ``param_<i>`` entry per name.
        description: Tool description text.
        target: ``"litellm"`` selects the nested LiteLLM format. Any other value
            yields the OpenAI format with ``"strict": True``.

    Returns:
        The tool metadata in the requested format.
    """
    properties = {name: schema["properties"][f"param_{i}"] for i, name in enumerate(param_names)}

    base_params = {
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }

    if target == "litellm":
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        litellm_required = []
        for name in param_names:
            hint = type_hints.get(name, str)
            # Identity check against the sentinel: `!=` would invoke the
            # default value's own __eq__, which may not return a bool.
            has_default = sig.parameters[name].default is not inspect.Parameter.empty
            if not (is_optional_type(hint) or has_default):
                litellm_required.append(name)
        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": description,
                "parameters": {
                    **base_params,
                    "required": litellm_required,
                },
            },
        }  # type: ignore

    # OpenAI format: strict mode lists every property as required.
    metadata: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": description,
        "parameters": {
            **base_params,
            "required": list(param_names),
        },
        "strict": True,
    }

    return metadata