Coverage for src/chat_limiter/adapters.py: 94%
158 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 20:26 +0100
1"""
2Provider-specific adapters for converting between our unified types and provider APIs.
3"""
5import time
6import warnings
7from abc import ABC, abstractmethod
8from typing import Any
10from .providers import Provider
11from .types import (
12 ChatCompletionRequest,
13 ChatCompletionResponse,
14 Choice,
15 Message,
16 MessageRole,
17 Usage,
18)
class ProviderAdapter(ABC):
    """Abstract base class for provider-specific adapters.

    A concrete adapter translates between the library's unified chat types
    and a single provider's wire format, and knows that provider's endpoint.
    """

    @abstractmethod
    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our request format to provider-specific format."""
        ...

    @abstractmethod
    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Convert provider response to our unified format."""
        ...

    @abstractmethod
    def get_endpoint(self) -> str:
        """Get the API endpoint for this provider."""
        ...
class OpenAIAdapter(ProviderAdapter):
    """Adapter for the OpenAI Chat Completions API."""

    def is_reasoning_model(self, model_name: str) -> bool:
        """Return True if *model_name* is a reasoning model (o1/o3/o4 family).

        Reasoning models require ``max_completion_tokens`` instead of
        ``max_tokens`` and only accept ``temperature=1``.
        """
        return model_name.startswith(("o1", "o3", "o4"))

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our unified request into the OpenAI wire format.

        Optional fields are only included when set on *request*; reasoning
        models get special handling for token limits and temperature.
        """
        messages: list[dict[str, Any]] = [
            {"role": msg.role.value, "content": msg.content}
            for msg in request.messages
        ]

        openai_request: dict[str, Any] = {
            "model": request.model,
            "messages": messages,
        }

        if request.max_tokens is not None:
            # Reasoning models reject `max_tokens`; they use
            # `max_completion_tokens` instead.
            if self.is_reasoning_model(request.model):
                openai_request["max_completion_tokens"] = request.max_tokens
            else:
                openai_request["max_tokens"] = request.max_tokens

        if self.is_reasoning_model(request.model):
            # Reasoning models only support temperature=1. Warn the caller if
            # they asked for something else.
            # FIX: the previous code emitted this message twice — once via
            # warnings.warn and again via a bare print() to stdout. Library
            # code should warn exactly once through the configurable
            # `warnings` channel, so the duplicate print was removed.
            default_temperature = 1.0
            if (
                request.temperature is not None
                and request.temperature != default_temperature
            ):
                warnings.warn(
                    f"WARNING: Model '{request.model}' is a reasoning model that requires temperature=1. "
                    f"Your specified temperature={request.temperature} will be overridden to temperature=1.",
                    UserWarning
                )
            # Always pin temperature=1 for reasoning models.
            openai_request["temperature"] = default_temperature
        elif request.temperature is not None:
            # Non-reasoning models: pass the caller's temperature through.
            openai_request["temperature"] = request.temperature

        if request.top_p is not None:
            openai_request["top_p"] = request.top_p
        if request.stop is not None:
            openai_request["stop"] = request.stop
        if request.stream:
            openai_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openai_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openai_request["presence_penalty"] = request.presence_penalty

        return openai_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Parse an OpenAI response into our unified format.

        Missing fields fall back to defaults (empty strings, index 0, the
        originally requested model name). An "error" key in the payload marks
        the response as unsuccessful.
        """
        success = True
        error_message = None
        if "error" in response_data:
            success = False
            error_message = response_data["error"].get("message", "Unknown error")

        choices = []
        for choice_data in response_data.get("choices", []):
            message_data = choice_data.get("message", {})
            choices.append(
                Choice(
                    index=choice_data.get("index", 0),
                    message=Message(
                        role=MessageRole(message_data.get("role", "assistant")),
                        content=message_data.get("content", ""),
                    ),
                    finish_reason=choice_data.get("finish_reason"),
                )
            )

        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openai",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the OpenAI chat completions endpoint path."""
        return "/chat/completions"
class AnthropicAdapter(ProviderAdapter):
    """Adapter for the Anthropic Messages API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our unified request into the Anthropic wire format.

        Anthropic takes the system prompt as a top-level ``system`` field
        rather than as a message, and requires ``max_tokens``.
        """
        system_message: str | None = None
        converted: list[dict[str, Any]] = []

        for message in request.messages:
            if message.role == MessageRole.SYSTEM:
                # Pulled out of the message list; sent as the `system` field.
                system_message = message.content
            else:
                converted.append(
                    {"role": message.role.value, "content": message.content}
                )

        payload: dict[str, Any] = {
            "model": request.model,
            "messages": converted,
            # max_tokens is mandatory for Anthropic; default to 1024.
            "max_tokens": request.max_tokens or 1024,
        }

        if system_message:
            payload["system"] = system_message

        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.top_p is not None:
            payload["top_p"] = request.top_p
        if request.stop is not None:
            # Anthropic expects a list of stop sequences.
            stop = request.stop
            payload["stop_sequences"] = [stop] if isinstance(stop, str) else stop
        if request.stream:
            payload["stream"] = request.stream
        if request.top_k is not None:
            payload["top_k"] = request.top_k

        return payload

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Parse an Anthropic response into our unified format.

        Text is assembled from the "text"-type content blocks; token usage is
        mapped from input/output token counts.
        """
        success = "error" not in response_data
        error_message = (
            None
            if success
            else response_data["error"].get("message", "Unknown error")
        )

        # Concatenate the text of all text-type content blocks.
        text = "".join(
            block.get("text", "")
            for block in response_data.get("content", [])
            if block.get("type") == "text"
        )

        choice = Choice(
            index=0,
            message=Message(role=MessageRole.ASSISTANT, content=text),
            finish_reason=response_data.get("stop_reason"),
        )

        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            prompt = usage_data.get("input_tokens", 0)
            completion = usage_data.get("output_tokens", 0)
            usage = Usage(
                prompt_tokens=prompt,
                completion_tokens=completion,
                total_tokens=prompt + completion,
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=[choice],
            usage=usage,
            # Anthropic doesn't return a created timestamp; use "now".
            created=int(time.time()),
            success=success,
            error_message=error_message,
            provider="anthropic",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the Anthropic messages endpoint path."""
        return "/messages"
class OpenRouterAdapter(ProviderAdapter):
    """Adapter for the OpenRouter API (OpenAI-compatible wire format)."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our unified request into the OpenRouter wire format.

        OpenRouter mirrors the OpenAI chat schema and additionally accepts
        ``top_k``.
        """
        payload: dict[str, Any] = {
            "model": request.model,
            "messages": [
                {"role": message.role.value, "content": message.content}
                for message in request.messages
            ],
        }

        # Optional knobs: include only when set on the request.
        optional_fields = (
            ("max_tokens", request.max_tokens),
            ("temperature", request.temperature),
            ("top_p", request.top_p),
            ("stop", request.stop),
            ("frequency_penalty", request.frequency_penalty),
            ("presence_penalty", request.presence_penalty),
            ("top_k", request.top_k),
        )
        for key, value in optional_fields:
            if value is not None:
                payload[key] = value
        if request.stream:
            payload["stream"] = request.stream

        return payload

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Parse an OpenRouter response (OpenAI-shaped) into our unified format."""
        success = "error" not in response_data
        error_message = (
            None
            if success
            else response_data["error"].get("message", "Unknown error")
        )

        choices = []
        for entry in response_data.get("choices", []):
            raw_message = entry.get("message", {})
            choices.append(
                Choice(
                    index=entry.get("index", 0),
                    message=Message(
                        role=MessageRole(raw_message.get("role", "assistant")),
                        content=raw_message.get("content", ""),
                    ),
                    finish_reason=entry.get("finish_reason"),
                )
            )

        usage = None
        if "usage" in response_data:
            raw_usage = response_data["usage"]
            usage = Usage(
                prompt_tokens=raw_usage.get("prompt_tokens", 0),
                completion_tokens=raw_usage.get("completion_tokens", 0),
                total_tokens=raw_usage.get("total_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openrouter",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the OpenRouter chat completions endpoint path."""
        return "/chat/completions"
# Provider adapter registry: one shared, stateless adapter instance per
# supported provider. Look up via get_adapter() below.
PROVIDER_ADAPTERS: dict[Provider, ProviderAdapter] = {
    Provider.OPENAI: OpenAIAdapter(),
    Provider.ANTHROPIC: AnthropicAdapter(),
    Provider.OPENROUTER: OpenRouterAdapter(),
}
def get_adapter(provider: Provider) -> ProviderAdapter:
    """Get the appropriate adapter for a provider.

    Args:
        provider: The provider whose adapter should be returned.

    Returns:
        The shared adapter instance registered for *provider*.

    Raises:
        KeyError: If no adapter is registered for *provider*. (Same exception
            type as the previous bare dict lookup, but with a descriptive
            message instead of just the key.)
    """
    try:
        return PROVIDER_ADAPTERS[provider]
    except KeyError:
        raise KeyError(f"No adapter registered for provider: {provider!r}") from None