Coverage for src/chat_limiter/models.py: 92% (171 statements)
1"""
2Dynamic model discovery from provider APIs.
4This module provides functionality to query provider APIs for available models
5instead of relying on hardcoded lists.
6"""
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any

import httpx
logger = logging.getLogger(__name__)

# Cache for model lists to avoid hitting APIs too frequently
_model_cache: dict[str, dict[str, Any]] = {}
_cache_duration = timedelta(hours=1)  # Cache models for 1 hour
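# Each cache entry is a plain dict of the shape
# {"models": set[str], "timestamp": datetime}. A read older than
# _cache_duration falls through to a fresh API call; the cache is per-process
# and unbounded, which is fine for the handful of keys used here.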
@dataclass
class ModelDiscoveryResult:
    """Result of model discovery process."""

    # Discovery result
    found_provider: str | None = None
    model_found: bool = False

    # All models found for each provider
    openai_models: set[str] | None = None
    anthropic_models: set[str] | None = None
    openrouter_models: set[str] | None = None

    # Errors encountered during discovery
    errors: dict[str, str] | None = None
    def get_all_models(self) -> dict[str, set[str]]:
        """Get all models organized by provider."""
        result = {}
        if self.openai_models is not None:
            result["openai"] = self.openai_models
        if self.anthropic_models is not None:
            result["anthropic"] = self.anthropic_models
        if self.openrouter_models is not None:
            result["openrouter"] = self.openrouter_models
        return result
    def get_total_models_found(self) -> int:
        """Get total number of models found across all providers."""
        total = 0
        if self.openai_models:
            total += len(self.openai_models)
        if self.anthropic_models:
            total += len(self.anthropic_models)
        if self.openrouter_models:
            total += len(self.openrouter_models)
        return total
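# Illustrative sketch of the two accessors above:
#
#     r = ModelDiscoveryResult(
#         found_provider="openai",
#         model_found=True,
#         openai_models={"gpt-4o", "gpt-4o-mini"},
#     )
#     r.get_all_models()          # {"openai": {"gpt-4o", "gpt-4o-mini"}}
#     r.get_total_models_found()  # 2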
class ModelDiscovery:
    """Dynamic model discovery from provider APIs."""

    @staticmethod
    async def get_openai_models(api_key: str) -> set[str]:
        """Get available OpenAI models from the API."""
        cache_key = f"openai_models_{hash(api_key)}"

        # Check cache first
        if _model_cache.get(cache_key):
            cache_entry = _model_cache[cache_key]
            if datetime.now() - cache_entry["timestamp"] < _cache_duration:
                return cache_entry["models"]  # type: ignore[no-any-return]
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models",
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=10.0,
                )
                response.raise_for_status()

                data = response.json()
                models = set()

                for model in data.get("data", []):
                    model_id = model.get("id", "")
                    if model_id:  # skip malformed entries instead of adding ""
                        models.add(model_id)

                # Cache the result
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now(),
                }

                logger.info(f"Retrieved {len(models)} OpenAI models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch OpenAI models: {e}")
            raise
    @staticmethod
    async def get_anthropic_models(api_key: str) -> set[str]:
        """Get available Anthropic models from the API."""
        cache_key = f"anthropic_models_{hash(api_key)}"

        # Check cache first
        if _model_cache.get(cache_key):
            cache_entry = _model_cache[cache_key]
            if datetime.now() - cache_entry["timestamp"] < _cache_duration:
                return cache_entry["models"]  # type: ignore[no-any-return]
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.anthropic.com/v1/models",
                    headers={
                        "x-api-key": api_key,
                        "anthropic-version": "2023-06-01",
                    },
                    timeout=10.0,
                )
                response.raise_for_status()

                data = response.json()
                models = set()

                for model in data.get("data", []):
                    model_id = model.get("id", "")
                    if model_id:  # skip malformed entries instead of adding ""
                        models.add(model_id)

                # Cache the result
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now(),
                }

                logger.info(f"Retrieved {len(models)} Anthropic models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch Anthropic models: {e}")
            raise
    @staticmethod
    async def get_openrouter_models(api_key: str | None = None) -> set[str]:
        """Get available OpenRouter models from the API."""
        cache_key = "openrouter_models"

        # Check cache first
        if _model_cache.get(cache_key):
            cache_entry = _model_cache[cache_key]
            if datetime.now() - cache_entry["timestamp"] < _cache_duration:
                return cache_entry["models"]  # type: ignore[no-any-return]

        try:
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://openrouter.ai/api/v1/models",
                    headers=headers,
                    timeout=10.0,
                )
                response.raise_for_status()

                data = response.json()
                models = set()

                for model in data.get("data", []):
                    model_id = model.get("id", "")
                    if model_id:
                        models.add(model_id)

                # Cache the result
                _model_cache[cache_key] = {
                    "models": models,
                    "timestamp": datetime.now(),
                }

                logger.info(f"Retrieved {len(models)} OpenRouter models from API")
                return models

        except Exception as e:
            logger.warning(f"Failed to fetch OpenRouter models: {e}")
            raise
    @staticmethod
    def get_openai_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_openai_models."""
        return asyncio.run(ModelDiscovery.get_openai_models(api_key))

    @staticmethod
    def get_anthropic_models_sync(api_key: str) -> set[str]:
        """Synchronous version of get_anthropic_models."""
        return asyncio.run(ModelDiscovery.get_anthropic_models(api_key))

    @staticmethod
    def get_openrouter_models_sync(api_key: str | None = None) -> set[str]:
        """Synchronous version of get_openrouter_models."""
        return asyncio.run(ModelDiscovery.get_openrouter_models(api_key))
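# Example (sketch): OpenRouter's listing endpoint works without an API key,
# so this is the cheapest smoke test of the discovery path.
#
#     models = ModelDiscovery.get_openrouter_models_sync()
#     print(len(models))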
async def detect_provider_from_model_async(
    model: str,
    api_keys: dict[str, str] | None = None,
) -> ModelDiscoveryResult:
    """
    Detect provider from model name using live API queries.

    Args:
        model: The model name to check
        api_keys: Dictionary of API keys, e.g. {"openai": "sk-...", "anthropic": "sk-ant-..."}

    Returns:
        ModelDiscoveryResult with discovery information
    """
    if not api_keys:
        api_keys = {}
    result = ModelDiscoveryResult(errors={})

    # First try simple pattern matching for known formats
    if "/" in model:  # OpenRouter format ("vendor/model")
        result.found_provider = "openrouter"
        result.model_found = True
        # Still try to get OpenRouter models to populate the result
        try:
            result.openrouter_models = await ModelDiscovery.get_openrouter_models(
                api_keys.get("openrouter")
            )
        except Exception as e:
            result.errors["openrouter"] = str(e)
        return result
    # Create all tasks
    tasks = []

    if api_keys.get("openai"):
        tasks.append(("openai", ModelDiscovery.get_openai_models(api_keys["openai"])))

    if api_keys.get("anthropic"):
        tasks.append(
            ("anthropic", ModelDiscovery.get_anthropic_models(api_keys["anthropic"]))
        )

    if api_keys.get("openrouter"):
        tasks.append(
            ("openrouter", ModelDiscovery.get_openrouter_models(api_keys["openrouter"]))
        )
    else:
        # OpenRouter doesn't require an API key for model listing
        tasks.append(("openrouter", ModelDiscovery.get_openrouter_models()))
    # Run all queries concurrently; return_exceptions=True keeps one provider's
    # failure from cancelling the others
    try:
        coroutines = [task[1] for task in tasks]
        provider_names = [task[0] for task in tasks]

        # Wait for all results
        results = await asyncio.gather(*coroutines, return_exceptions=True)

        # Process results and store all model information
        for provider_name, models_result in zip(provider_names, results):
            if isinstance(models_result, Exception):
                logger.debug(
                    f"Failed to check {provider_name} for model {model}: {models_result}"
                )
                result.errors[provider_name] = str(models_result)
                continue

            # Store models in result
            if provider_name == "openai":
                result.openai_models = models_result
            elif provider_name == "anthropic":
                result.anthropic_models = models_result
            elif provider_name == "openrouter":
                result.openrouter_models = models_result

            # Record the first provider whose list contains the target model
            if model in models_result and not result.model_found:
                result.found_provider = provider_name
                result.model_found = True

    except Exception as e:
        logger.debug(f"Failed to run dynamic discovery for model {model}: {e}")
        result.errors["general"] = str(e)

    return result
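# Example (sketch; the key value is a placeholder):
#
#     result = await detect_provider_from_model_async(
#         "claude-3-5-haiku-latest",
#         api_keys={"anthropic": "sk-ant-..."},
#     )
#     print(result.found_provider)  # "anthropic" if the API lists the model
#     print(result.errors)          # per-provider failure messages, if any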
def detect_provider_from_model_sync(
    model: str,
    api_keys: dict[str, str] | None = None,
) -> ModelDiscoveryResult:
    """Synchronous version of detect_provider_from_model_async."""
    # Check whether we're already inside a running event loop; if not,
    # asyncio.run is safe to use directly. Keeping the except clause this
    # narrow avoids swallowing RuntimeErrors raised by the discovery itself.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(detect_provider_from_model_async(model, api_keys))

    # We're in an async context, where asyncio.run() would raise. Run the
    # coroutine on a fresh event loop in a worker thread instead.
    import concurrent.futures

    def run_in_thread() -> ModelDiscoveryResult:
        return asyncio.run(detect_provider_from_model_async(model, api_keys))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(run_in_thread)
        return future.result(timeout=30)  # 30 second timeout
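# The worker-thread fallback above makes this helper safe to call even from
# code that already runs inside an event loop (e.g. a Jupyter cell), at the
# cost of spinning up a second loop in another thread:
#
#     async def handler() -> None:
#         result = detect_provider_from_model_sync("gpt-4o")  # no RuntimeError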
def clear_model_cache() -> None:
    """Clear the model cache to force fresh API queries."""
    # .clear() mutates the module-level dict in place, so no global is needed
    _model_cache.clear()
    logger.info("Model cache cleared")
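# Minimal smoke-test sketch. Assumptions: network access, and optionally
# OPENAI_API_KEY / ANTHROPIC_API_KEY set in the environment (OpenRouter
# listing needs no key). The model name below is a placeholder.
if __name__ == "__main__":
    import os

    keys = {
        provider: key
        for provider, key in {
            "openai": os.environ.get("OPENAI_API_KEY", ""),
            "anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
        }.items()
        if key
    }
    clear_model_cache()  # start from a cold cache
    outcome = detect_provider_from_model_sync("gpt-4o", keys)
    print(outcome.found_provider, outcome.get_total_models_found())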