Coverage for src/chat_limiter/models.py: 92%

171 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 14:11 +0100

1""" 

2Dynamic model discovery from provider APIs. 

3 

4This module provides functionality to query provider APIs for available models 

5instead of relying on hardcoded lists. 

6""" 

7 

8import asyncio 

9import logging 

10from dataclasses import dataclass 

11from datetime import datetime, timedelta 

12from typing import Any 

13 

14import httpx 

15 

16logger = logging.getLogger(__name__) 

17 

18# Cache for model lists to avoid hitting APIs too frequently. 

19# Each entry maps cache_key -> {"models": set[str], "timestamp": datetime}. 

20_model_cache: dict[str, dict[str, Any]] = {} 

21_cache_duration = timedelta(hours=1) # Cache models for 1 hour 

21 

22 

23@dataclass 

24class ModelDiscoveryResult: 

25 """Result of model discovery process.""" 

26 

27 # Discovery result 

28 found_provider: str | None = None 

29 model_found: bool = False 

30 

31 # All models found for each provider 

32 openai_models: set[str] | None = None 

33 anthropic_models: set[str] | None = None 

34 openrouter_models: set[str] | None = None 

35 

36 # Errors encountered during discovery 

37 errors: dict[str, str] | None = None 

38 

39 def get_all_models(self) -> dict[str, set[str]]: 

40 """Get all models organized by provider.""" 

41 result = {} 

42 if self.openai_models is not None: 

43 result["openai"] = self.openai_models 

44 if self.anthropic_models is not None: 

45 result["anthropic"] = self.anthropic_models 

46 if self.openrouter_models is not None: 

47 result["openrouter"] = self.openrouter_models 

48 return result 

49 

50 def get_total_models_found(self) -> int: 

51 """Get total number of models found across all providers.""" 

52 total = 0 

53 if self.openai_models: 

54 total += len(self.openai_models) 

55 if self.anthropic_models: 

56 total += len(self.anthropic_models) 

57 if self.openrouter_models: 

58 total += len(self.openrouter_models) 

59 return total 

60 

61 

class ModelDiscovery:
    """Dynamic model discovery from provider APIs.

    Each ``get_*_models`` coroutine queries the provider's model-listing
    endpoint and returns the set of available model ids. Results are
    cached in the module-level ``_model_cache`` for ``_cache_duration``
    to avoid hammering the APIs. Fetch errors are logged and re-raised
    so callers can decide how to degrade.
    """

    @staticmethod
    async def _fetch_models(
        cache_key: str,
        url: str,
        headers: dict[str, str],
        provider_label: str,
    ) -> set[str]:
        """Shared fetch-and-cache logic for all providers.

        Args:
            cache_key: Key into ``_model_cache`` for this provider/key pair.
            url: Model-listing endpoint to GET.
            headers: HTTP headers (auth etc.) to send with the request.
            provider_label: Human-readable provider name used in log lines.

        Returns:
            Set of model ids reported by the endpoint (empty ids skipped).

        Raises:
            Exception: Propagates any HTTP/network error after logging it.
        """
        # Serve a cached entry while it is still fresh.
        cache_entry = _model_cache.get(cache_key)
        if cache_entry and datetime.now() - cache_entry["timestamp"] < _cache_duration:
            return cache_entry["models"]  # type: ignore[no-any-return]

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers, timeout=10.0)
                response.raise_for_status()

                data = response.json()
                # Skip entries with a missing/empty "id" so "" never
                # lands in the result set.
                models = {
                    model_id
                    for model in data.get("data", [])
                    if (model_id := model.get("id", ""))
                }

                # Cache the result with its retrieval time.
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now(),
                }

                logger.info(f"Retrieved {len(models)} {provider_label} models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch {provider_label} models: {e}")
            raise

    @staticmethod
    async def get_openai_models(api_key: str) -> set[str]:
        """Get available OpenAI models from the API."""
        return await ModelDiscovery._fetch_models(
            cache_key=f"openai_models_{hash(api_key)}",
            url="https://api.openai.com/v1/models",
            headers={"Authorization": f"Bearer {api_key}"},
            provider_label="OpenAI",
        )

    @staticmethod
    async def get_anthropic_models(api_key: str) -> set[str]:
        """Get available Anthropic models from the API."""
        return await ModelDiscovery._fetch_models(
            cache_key=f"anthropic_models_{hash(api_key)}",
            url="https://api.anthropic.com/v1/models",
            headers={
                "x-api-key": api_key,
                # Anthropic requires an explicit API version header.
                "anthropic-version": "2023-06-01",
            },
            provider_label="Anthropic",
        )

    @staticmethod
    async def get_openrouter_models(api_key: str | None = None) -> set[str]:
        """Get available OpenRouter models from the API.

        The listing endpoint works without authentication, so the key is
        optional and the cache entry is shared across keys.
        """
        headers: dict[str, str] = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return await ModelDiscovery._fetch_models(
            cache_key="openrouter_models",
            url="https://openrouter.ai/api/v1/models",
            headers=headers,
            provider_label="OpenRouter",
        )

    @staticmethod
    def get_openai_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_openai_models."""
        return asyncio.run(ModelDiscovery.get_openai_models(api_key))

    @staticmethod
    def get_anthropic_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_anthropic_models."""
        return asyncio.run(ModelDiscovery.get_anthropic_models(api_key))

    @staticmethod
    def get_openrouter_models_sync(api_key: str | None = None) -> set[str]:
        """Synchronous version of get_openrouter_models."""
        return asyncio.run(ModelDiscovery.get_openrouter_models(api_key))

207 

208 

async def detect_provider_from_model_async(
    model: str,
    api_keys: dict[str, str] | None = None
) -> ModelDiscoveryResult:
    """
    Detect provider from model name using live API queries.

    Args:
        model: The model name to check
        api_keys: Dictionary of API keys {"openai": "sk-...", "anthropic": "sk-ant-..."}

    Returns:
        ModelDiscoveryResult with discovery information
    """
    api_keys = api_keys or {}
    result = ModelDiscoveryResult(errors={})

    # Fast path: a slash in the name is the OpenRouter "vendor/model" format.
    if "/" in model:
        result.found_provider = "openrouter"
        result.model_found = True
        # Best-effort: still populate the OpenRouter model list.
        try:
            result.openrouter_models = await ModelDiscovery.get_openrouter_models(
                api_keys.get("openrouter")
            )
        except Exception as e:
            result.errors["openrouter"] = str(e)
        return result

    # Build one discovery coroutine per provider we can query.
    pending: dict[str, Any] = {}
    if api_keys.get("openai"):
        pending["openai"] = ModelDiscovery.get_openai_models(api_keys["openai"])
    if api_keys.get("anthropic"):
        pending["anthropic"] = ModelDiscovery.get_anthropic_models(api_keys["anthropic"])
    # OpenRouter's listing endpoint needs no API key, so it is always queried.
    pending["openrouter"] = ModelDiscovery.get_openrouter_models(api_keys.get("openrouter"))

    try:
        # Run all provider queries concurrently; per-provider failures come
        # back as exception objects rather than aborting the gather.
        outcomes = await asyncio.gather(*pending.values(), return_exceptions=True)

        for provider_name, models_result in zip(pending.keys(), outcomes):
            if isinstance(models_result, Exception):
                logger.debug(f"Failed to check {provider_name} for model {model}: {models_result}")
                result.errors[provider_name] = str(models_result)
                continue

            # Record the provider's model set on the matching attribute.
            setattr(result, f"{provider_name}_models", models_result)

            # First provider listing the target model wins.
            if not result.model_found and model in models_result:
                result.found_provider = provider_name
                result.model_found = True

    except Exception as e:
        logger.debug(f"Failed to run dynamic discovery for model {model}: {e}")
        result.errors["general"] = str(e)

    return result

288 

289 

def detect_provider_from_model_sync(
    model: str,
    api_keys: dict[str, str] | None = None
) -> ModelDiscoveryResult:
    """Synchronous version of detect_provider_from_model_async.

    Args:
        model: The model name to check.
        api_keys: Optional mapping of provider name to API key.

    Returns:
        ModelDiscoveryResult with discovery information.
    """
    try:
        # Probe for a running loop; only the RuntimeError ("no running
        # event loop") matters for choosing the execution path, so the
        # try body is kept to this single call — previously the whole
        # threaded path sat inside it, so a RuntimeError re-raised from
        # future.result() would wrongly fall into the asyncio.run branch.
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop, safe to use asyncio.run directly.
        return asyncio.run(detect_provider_from_model_async(model, api_keys))

    # We're inside an async context: asyncio.run would raise here, so run
    # the coroutine on a fresh event loop in a worker thread instead.
    import concurrent.futures

    def run_in_thread() -> ModelDiscoveryResult:
        return asyncio.run(detect_provider_from_model_async(model, api_keys))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(run_in_thread)
        return future.result(timeout=30)  # 30 second timeout

312 

313 

def clear_model_cache() -> None:
    """Clear the model cache to force fresh API queries."""
    # .clear() mutates the dict in place, so no `global` declaration is
    # needed (that would only matter if the name were rebound), and any
    # existing references to _model_cache stay valid.
    _model_cache.clear()
    logger.info("Model cache cleared")

319 

320 

321