Coverage for src/chat_limiter/adapters.py: 94%

158 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 20:26 +0100

1""" 

2Provider-specific adapters for converting between our unified types and provider APIs. 

3""" 

4 

5import time 

6import warnings 

7from abc import ABC, abstractmethod 

8from typing import Any 

9 

10from .providers import Provider 

11from .types import ( 

12 ChatCompletionRequest, 

13 ChatCompletionResponse, 

14 Choice, 

15 Message, 

16 MessageRole, 

17 Usage, 

18) 

19 

20 

class ProviderAdapter(ABC):
    """Interface for translating between unified chat types and one provider's API.

    Concrete adapters implement request formatting, response parsing, and
    endpoint selection for a single provider.
    """

    @abstractmethod
    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our request format to provider-specific format."""
        ...

    @abstractmethod
    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert provider response to our unified format."""
        ...

    @abstractmethod
    def get_endpoint(self) -> str:
        """Get the API endpoint for this provider."""
        ...

42 

43 

class OpenAIAdapter(ProviderAdapter):
    """Adapter for the OpenAI chat completions API."""

    def is_reasoning_model(self, model_name: str) -> bool:
        """Return True for reasoning models (o1/o3/o4 families).

        These models require ``max_completion_tokens`` instead of
        ``max_tokens`` and only support ``temperature=1``.
        """
        return model_name.startswith(("o1", "o3", "o4"))

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our unified request to the OpenAI wire format.

        Reasoning models get ``max_completion_tokens`` and a forced
        ``temperature=1`` (a UserWarning is issued if the caller asked for a
        different temperature). Optional parameters are only included when set.
        """
        messages: list[dict[str, Any]] = [
            {"role": msg.role.value, "content": msg.content}
            for msg in request.messages
        ]

        openai_request: dict[str, Any] = {
            "model": request.model,
            "messages": messages,
        }

        if request.max_tokens is not None:
            # Reasoning models (o1, o3, o4) reject max_tokens and require
            # max_completion_tokens instead.
            if self.is_reasoning_model(request.model):
                openai_request["max_completion_tokens"] = request.max_tokens
            else:
                openai_request["max_tokens"] = request.max_tokens

        if self.is_reasoning_model(request.model):
            # Reasoning models only accept temperature=1; always send that,
            # and warn the caller if their requested value is being overridden.
            # (Fixed: the warning was previously also print()ed to stdout,
            # duplicating the warnings.warn output and bypassing warning
            # filters.)
            default_temperature = 1.0
            if (
                request.temperature is not None
                and request.temperature != default_temperature
            ):
                warnings.warn(
                    f"WARNING: Model '{request.model}' is a reasoning model that requires temperature=1. "
                    f"Your specified temperature={request.temperature} will be overridden to temperature=1.",
                    UserWarning,
                    stacklevel=2,
                )
            openai_request["temperature"] = default_temperature
        elif request.temperature is not None:
            openai_request["temperature"] = request.temperature

        if request.top_p is not None:
            openai_request["top_p"] = request.top_p
        if request.stop is not None:
            openai_request["stop"] = request.stop
        if request.stream:
            openai_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openai_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openai_request["presence_penalty"] = request.presence_penalty

        return openai_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert an OpenAI response payload to our unified response type.

        Missing fields fall back to sensible defaults; an ``error`` key in the
        payload marks the response unsuccessful and captures its message.
        """
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        choices = []
        for choice_data in response_data.get("choices", []):
            message_data = choice_data.get("message", {})
            message = Message(
                role=MessageRole(message_data.get("role", "assistant")),
                # The API may return content: null (e.g. for tool calls);
                # normalize that to an empty string.
                content=message_data.get("content") or ""
            )
            choice = Choice(
                index=choice_data.get("index", 0),
                message=message,
                finish_reason=choice_data.get("finish_reason")
            )
            choices.append(choice)

        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0)
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openai",
            raw_response=response_data
        )

    def get_endpoint(self) -> str:
        """Return the OpenAI chat completions endpoint path."""
        return "/chat/completions"

164 

165 

class AnthropicAdapter(ProviderAdapter):
    """Adapter for the Anthropic Messages API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert our unified request to the Anthropic Messages format.

        Anthropic takes system prompts as a top-level ``system`` field rather
        than as chat messages, and ``max_tokens`` is mandatory.
        """
        messages: list[dict[str, Any]] = []
        system_parts: list[str] = []

        for msg in request.messages:
            if msg.role == MessageRole.SYSTEM:
                # Collect ALL system messages. (Fixed: previously each system
                # message overwrote the last, silently dropping earlier ones.)
                system_parts.append(msg.content)
            else:
                messages.append({
                    "role": msg.role.value,
                    "content": msg.content
                })

        anthropic_request: dict[str, Any] = {
            "model": request.model,
            "messages": messages,
            # max_tokens is required by Anthropic; default to 1024 when unset.
            "max_tokens": request.max_tokens or 1024,
        }

        # Join multiple system prompts into a single system field.
        if system_parts:
            anthropic_request["system"] = "\n\n".join(system_parts)

        if request.temperature is not None:
            anthropic_request["temperature"] = request.temperature
        if request.top_p is not None:
            anthropic_request["top_p"] = request.top_p
        if request.stop is not None:
            # Anthropic expects a list of stop sequences.
            anthropic_request["stop_sequences"] = (
                [request.stop] if isinstance(request.stop, str) else request.stop
            )
        if request.stream:
            anthropic_request["stream"] = request.stream
        if request.top_k is not None:
            anthropic_request["top_k"] = request.top_k

        return anthropic_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert an Anthropic response payload to our unified response type.

        Text is extracted from Anthropic's content blocks and flattened into a
        single assistant message; usage totals are derived from input/output
        token counts.
        """
        success = True
        error_message = None

        if "error" in response_data:
            success = False
            error_data = response_data["error"]
            error_message = error_data.get("message", "Unknown error")

        # Anthropic returns a list of typed content blocks; concatenate the
        # text blocks into one string.
        content = ""
        for block in response_data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")

        message = Message(
            role=MessageRole.ASSISTANT,
            content=content
        )

        choice = Choice(
            index=0,
            message=message,
            finish_reason=response_data.get("stop_reason")
        )

        usage = None
        if "usage" in response_data:
            usage_data = response_data["usage"]
            usage = Usage(
                prompt_tokens=usage_data.get("input_tokens", 0),
                completion_tokens=usage_data.get("output_tokens", 0),
                # Anthropic does not report a total; derive it.
                total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0)
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=[choice],
            usage=usage,
            created=int(time.time()),  # Anthropic doesn't provide a created timestamp
            success=success,
            error_message=error_message,
            provider="anthropic",
            raw_response=response_data
        )

    def get_endpoint(self) -> str:
        """Return the Anthropic Messages endpoint path."""
        return "/messages"

271 

272 

class OpenRouterAdapter(ProviderAdapter):
    """Adapter for OpenRouter, which exposes an OpenAI-compatible API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Build an OpenRouter (OpenAI-style) request payload.

        Optional parameters are included only when the caller set them.
        """
        payload: dict[str, Any] = {
            "model": request.model,
            "messages": [
                {"role": m.role.value, "content": m.content}
                for m in request.messages
            ],
        }

        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.top_p is not None:
            payload["top_p"] = request.top_p
        if request.stop is not None:
            payload["stop"] = request.stop
        if request.stream:
            payload["stream"] = request.stream
        if request.frequency_penalty is not None:
            payload["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            payload["presence_penalty"] = request.presence_penalty
        if request.top_k is not None:
            payload["top_k"] = request.top_k

        return payload

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Translate an OpenRouter response (OpenAI-shaped) into our unified type."""
        # An "error" key marks the response as failed.
        success = "error" not in response_data
        error_message = None
        if not success:
            error_message = response_data["error"].get("message", "Unknown error")

        choices = []
        for entry in response_data.get("choices", []):
            msg_payload = entry.get("message", {})
            choices.append(
                Choice(
                    index=entry.get("index", 0),
                    message=Message(
                        role=MessageRole(msg_payload.get("role", "assistant")),
                        content=msg_payload.get("content", "")
                    ),
                    finish_reason=entry.get("finish_reason")
                )
            )

        usage = None
        if "usage" in response_data:
            counts = response_data["usage"]
            usage = Usage(
                prompt_tokens=counts.get("prompt_tokens", 0),
                completion_tokens=counts.get("completion_tokens", 0),
                total_tokens=counts.get("total_tokens", 0)
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            success=success,
            error_message=error_message,
            provider="openrouter",
            raw_response=response_data
        )

    def get_endpoint(self) -> str:
        """Return the OpenRouter chat completions endpoint path."""
        return "/chat/completions"

365 

366 

# Provider adapter registry: one shared, stateless adapter instance per
# supported provider, looked up by the Provider enum.
PROVIDER_ADAPTERS = {
    Provider.OPENAI: OpenAIAdapter(),
    Provider.ANTHROPIC: AnthropicAdapter(),
    Provider.OPENROUTER: OpenRouterAdapter(),
}

373 

374 

def get_adapter(provider: Provider) -> ProviderAdapter:
    """Get the appropriate adapter for a provider.

    Raises:
        KeyError: if no adapter is registered for *provider*.
    """
    return PROVIDER_ADAPTERS[provider]