Coverage for src/chat_limiter/limiter.py: 82%
395 statements
coverage.py v7.9.2, created at 2025-07-11 20:26 +0100
1"""
2Core rate limiter implementation using PyrateLimiter.
3"""
5import asyncio
6import logging
7import time
8from collections.abc import AsyncIterator, Iterator
9from contextlib import asynccontextmanager, contextmanager
10from dataclasses import dataclass, field
11from typing import Any
13import httpx
14from pyrate_limiter import Duration, Limiter, Rate
15from tenacity import (
16 retry,
17 retry_if_exception_type,
18 stop_after_attempt,
19 wait_exponential,
20)
22from .adapters import get_adapter
23from .providers import (
24 Provider,
25 ProviderConfig,
26 RateLimitInfo,
27 detect_provider_from_url,
28 extract_rate_limit_info,
29 get_provider_config,
30)
31from .types import (
32 ChatCompletionRequest,
33 ChatCompletionResponse,
34 Message,
35 MessageRole,
36 detect_provider_from_model,
37)
39logger = logging.getLogger(__name__)
42@dataclass
43class LimiterState:
44 """Current state of the rate limiter."""
46 # Current limits (None if not yet discovered)
47 request_limit: int | None = None
48 token_limit: int | None = None
50 # Usage tracking
51 requests_used: int = 0
52 tokens_used: int = 0
54 # Timing
55 last_request_time: float = field(default_factory=time.time)
56 last_limit_update: float = field(default_factory=time.time)
58 # Rate limit info from last response
59 last_rate_limit_info: RateLimitInfo | None = None
61 # Adaptive behavior
62 consecutive_rate_limit_errors: int = 0
63 adaptive_backoff_factor: float = 1.0
66class ChatLimiter:
67 """
68 A Pythonic rate limiter for API calls supporting OpenAI, Anthropic, and OpenRouter.
70 Features:
71 - Automatic rate limit discovery and adaptation
72 - Sync and async support with context managers
73 - Intelligent retry logic with exponential backoff
74 - Token and request rate limiting
75 - Provider-specific optimizations
77 Example:
78 # High-level interface (recommended)
79 async with ChatLimiter.for_model("gpt-4o", api_key="sk-...") as limiter:
80 response = await limiter.chat_completion(
81 model="gpt-4o",
82 messages=[Message(role=MessageRole.USER, content="Hello!")]
83 )
85 # Low-level interface (for advanced users)
86 async with ChatLimiter(provider=Provider.OPENAI, api_key="sk-...") as limiter:
87 response = await limiter.request("POST", "/chat/completions", json=data)
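        # Sync usage (illustrative sketch, mirroring the async examples above)
        with ChatLimiter(provider=Provider.OPENAI, api_key="sk-...") as limiter:
            response = limiter.chat_completion_sync(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="Hello!")]
            )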
88 """
90 def __init__(
91 self,
92 provider: Provider | None = None,
93 api_key: str | None = None,
94 base_url: str | None = None,
95 config: ProviderConfig | None = None,
96 http_client: httpx.AsyncClient | None = None,
97 sync_http_client: httpx.Client | None = None,
98 enable_adaptive_limits: bool = True,
99 enable_token_estimation: bool = True,
100 request_limit: int | None = None,
101 token_limit: int | None = None,
102 max_retries: int | None = None,
103 base_backoff: float | None = None,
104 timeout: float | None = None,
105 **kwargs: Any,
106 ):
107 """
108 Initialize the ChatLimiter.
110 Args:
111 provider: The API provider (OpenAI, Anthropic, OpenRouter)
112 api_key: API key for authentication
113 base_url: Base URL for API requests
114 config: Custom provider configuration
115 http_client: Custom async HTTP client
116 sync_http_client: Custom sync HTTP client
117 enable_adaptive_limits: Enable adaptive rate limit adjustment
118 enable_token_estimation: Enable token usage estimation
119 request_limit: Override request limit (if not provided, falls back to the provider default or is discovered from the API)
120 token_limit: Override token limit (if not provided, falls back to the provider default or is discovered from the API)
121 max_retries: Override max retries (defaults to 3 if not provided)
122 base_backoff: Override base backoff (defaults to 1.0 if not provided)
123 timeout: HTTP request timeout in seconds (defaults to 120.0 for better reliability)
124 **kwargs: Additional arguments passed to HTTP clients
125 """
126 # Determine provider and config
127 if config:
128 self.config = config
129 self.provider = config.provider
130 elif provider:
131 self.provider = provider
132 self.config = get_provider_config(provider)
133 elif base_url:
134 detected_provider = detect_provider_from_url(base_url)
135 if detected_provider:
136 self.provider = detected_provider
137 self.config = get_provider_config(detected_provider)
138 else:
139 raise ValueError(f"Could not detect provider from URL: {base_url}")
140 else:
141 raise ValueError("Must provide either provider, config, or base_url")
143 # Override base_url if provided
144 if base_url:
145 self.config.base_url = base_url
147 # Store configuration
148 self.api_key = api_key
149 self.enable_adaptive_limits = enable_adaptive_limits
150 self.enable_token_estimation = enable_token_estimation
152 # Store user-provided overrides
153 self._user_request_limit = request_limit
154 self._user_token_limit = token_limit
155 self._user_max_retries = max_retries or 3 # Default to 3 if not provided
156 self._user_base_backoff = base_backoff or 1.0 # Default to 1.0 if not provided
157 self._user_timeout = (
158 timeout or 120.0
159 ) # Default to 120 seconds for better reliability
161 # Determine initial limits (user override, config default, or None for discovery)
162 initial_request_limit = (
163 request_limit or self.config.default_request_limit or None
164 )
165 initial_token_limit = token_limit or self.config.default_token_limit or None
167 # Initialize state - will be None if no defaults and no discovery yet
168 self.state = LimiterState(
169 request_limit=initial_request_limit,
170 token_limit=initial_token_limit,
171 )
173 # Flag to track if we need to discover limits
174 self._limits_discovered = (
175 initial_request_limit is not None and initial_token_limit is not None
176 )
178 # Initialize HTTP clients
179 self._init_http_clients(http_client, sync_http_client, **kwargs)
181 # Initialize rate limiters
182 self._init_rate_limiters()
184 # Context manager state
185 self._async_context_active = False
186 self._sync_context_active = False
188 # Verbose mode (can be set by batch processor)
189 self._verbose_mode = False
191 @classmethod
192 def for_model(
193 cls,
194 model: str,
195 api_key: str | None = None,
196 provider: str | Provider | None = None,
197 use_dynamic_discovery: bool = True,
198 request_limit: int | None = None,
199 token_limit: int | None = None,
200 max_retries: int | None = None,
201 base_backoff: float | None = None,
202 timeout: float | None = None,
203 **kwargs: Any,
204 ) -> "ChatLimiter":
205 """
206 Create a ChatLimiter instance automatically detecting the provider from the model name.
208 Args:
209 model: The model name (e.g., "gpt-4o", "claude-3-sonnet-20240229")
210 api_key: API key for the provider. If None, will be read from environment variables
211 (OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY)
212 provider: Override provider detection. Can be "openai", "anthropic", "openrouter",
213 or Provider enum. If None, will be auto-detected from model name
214 use_dynamic_discovery: Whether to query live APIs for model availability (default: True).
215 Requires appropriate API keys to be available. Falls back to
216 hardcoded model lists when disabled or when API calls fail.
request_limit: Override request limit in requests/minute (passed through to ChatLimiter)
token_limit: Override token limit in tokens/minute (passed through to ChatLimiter)
max_retries: Override max retries (defaults to 3)
base_backoff: Override base backoff in seconds (defaults to 1.0)
timeout: HTTP request timeout in seconds (defaults to 120.0)
217 **kwargs: Additional arguments passed to ChatLimiter
219 Returns:
220 Configured ChatLimiter instance
222 Raises:
223 ValueError: If provider cannot be determined from model name or API key not found
225 Example:
226 # Auto-detect provider with dynamic discovery (default behavior)
227 async with ChatLimiter.for_model("gpt-4o") as limiter:
228 response = await limiter.simple_chat("gpt-4o", "Hello!")
230 # Override provider detection
231 async with ChatLimiter.for_model("custom-model", provider="openai") as limiter:
232 response = await limiter.simple_chat("custom-model", "Hello!")
234 # Disable dynamic discovery to use only hardcoded model lists
235 async with ChatLimiter.for_model("gpt-4o", use_dynamic_discovery=False) as limiter:
236 response = await limiter.simple_chat("gpt-4o", "Hello!")
237 """
238 import os
240 # Determine provider
241 if provider is not None:
242 # Use provided provider
243 if isinstance(provider, str):
244 provider_enum = Provider(provider)
245 else:
246 provider_enum = provider
247 provider_name = provider_enum.value
248 else:
249 # Auto-detect from model name
250 # If dynamic discovery is requested, we need to collect API keys first
251 api_keys_for_discovery = {}
252 if use_dynamic_discovery:
253 # Collect available API keys from environment
254 env_var_map = {
255 "openai": "OPENAI_API_KEY",
256 "anthropic": "ANTHROPIC_API_KEY",
257 "openrouter": "OPENROUTER_API_KEY",
258 }
260 for provider_key, env_var in env_var_map.items():
261 key_value = os.getenv(env_var)
262 if key_value:
263 api_keys_for_discovery[provider_key] = key_value
265 # Try dynamic discovery first to get more detailed information
266 discovery_result = None
267 if use_dynamic_discovery and api_keys_for_discovery:
268 from .models import detect_provider_from_model_sync
270 discovery_result = detect_provider_from_model_sync(
271 model, api_keys_for_discovery
272 )
273 detected_provider = discovery_result.found_provider
274 else:
275 detected_provider = detect_provider_from_model(
276 model, use_dynamic_discovery, api_keys_for_discovery
277 )
279 if not detected_provider:
280 discovery_msg = (
281 " with dynamic API discovery" if use_dynamic_discovery else ""
282 )
283 error_msg = f"Could not determine provider from model '{model}'{discovery_msg}. "
285 # Add detailed information about available models if we have discovery results
286 if discovery_result and discovery_result.get_total_models_found() > 0:
287 error_msg += f"\n\nFound {discovery_result.get_total_models_found()} models across providers:\n"
288 for (
289 provider_name,
290 models,
291 ) in discovery_result.get_all_models().items():
292 error_msg += f" {provider_name}: {len(models)} models\n"
293 for example in sorted(list(models)):
294 error_msg += f" - {example}\n"
295 error_msg += "\nPlease check the model name or specify the provider explicitly using the 'provider' parameter."
296 else:
297 error_msg += "Please specify the provider explicitly using the 'provider' parameter."
299 # Add information about discovery errors if any
300 if discovery_result and discovery_result.errors:
301 error_msg += f"\n\nDiscovery errors encountered:\n"
302 for provider_name, error in discovery_result.errors.items():
303 error_msg += f" {provider_name}: {error}\n"
305 raise ValueError(error_msg)
306 assert detected_provider is not None # Help MyPy understand type narrowing
307 provider_name = detected_provider
308 provider_enum = Provider(provider_name)
310 # Determine API key
311 if api_key is None:
312 # Try to get from environment variables
313 env_var_map = {
314 "openai": "OPENAI_API_KEY",
315 "anthropic": "ANTHROPIC_API_KEY",
316 "openrouter": "OPENROUTER_API_KEY",
317 }
319 env_var_name: str | None = env_var_map.get(provider_name)
320 if env_var_name:
321 api_key = os.getenv(env_var_name)
322 if not api_key:
323 raise ValueError(
324 f"API key not provided and {env_var_name} environment variable not set. "
325 f"Please provide api_key parameter or set {env_var_name} environment variable."
326 )
327 else:
328 raise ValueError(
329 f"Unknown provider '{provider_name}'. Cannot determine environment variable for API key."
330 )
332 return cls(
333 provider=provider_enum,
334 api_key=api_key,
335 request_limit=request_limit,
336 token_limit=token_limit,
337 max_retries=max_retries,
338 base_backoff=base_backoff,
339 timeout=timeout,
340 **kwargs,
341 )
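# Illustrative sketch (not part of the original file): for_model() also accepts the
# rate-limit and retry overrides listed above; the numbers below are placeholders.
#
#     async with ChatLimiter.for_model(
#         "gpt-4o",
#         request_limit=500,       # requests/minute, initializes limiters immediately
#         token_limit=30_000,      # tokens/minute
#         max_retries=5,
#         base_backoff=2.0,
#         timeout=180.0,
#     ) as limiter:
#         reply = await limiter.simple_chat("gpt-4o", "Hello!")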
343 def _init_http_clients(
344 self,
345 http_client: httpx.AsyncClient | None,
346 sync_http_client: httpx.Client | None,
347 **kwargs: Any,
348 ) -> None:
349 """Initialize HTTP clients with proper headers."""
350 # Prepare headers
351 headers = {
352 "User-Agent": f"chat-limiter/0.1.0 ({self.provider.value})",
353 }
355 # Add provider-specific headers
356 if self.api_key:
357 if self.provider == Provider.OPENAI:
358 headers["Authorization"] = f"Bearer {self.api_key}"
359 elif self.provider == Provider.ANTHROPIC:
360 headers["x-api-key"] = self.api_key
361 headers["anthropic-version"] = "2023-06-01"
362 elif self.provider == Provider.OPENROUTER:
363 headers["Authorization"] = f"Bearer {self.api_key}"
364 headers["HTTP-Referer"] = "https://github.com/your-repo/chat-limiter"
366 # Merge with user-provided headers
367 if "headers" in kwargs:
368 headers.update(kwargs["headers"])
369 kwargs["headers"] = headers
371 # Initialize clients
372 if http_client:
373 self.async_client = http_client
374 else:
375 self.async_client = httpx.AsyncClient(
376 base_url=self.config.base_url,
377 timeout=httpx.Timeout(self._user_timeout), # Configurable timeout
378 **kwargs,
379 )
381 if sync_http_client:
382 self.sync_client = sync_http_client
383 else:
384 self.sync_client = httpx.Client(
385 base_url=self.config.base_url,
386 timeout=httpx.Timeout(self._user_timeout), # Configurable timeout
387 **kwargs,
388 )
390 def _init_rate_limiters(self) -> None:
391 """Initialize PyrateLimiter instances."""
392 # Only initialize if we have limits
393 if self.state.request_limit is None or self.state.token_limit is None:
394 # Cannot initialize rate limiters without limits
395 # This will be called again after limits are discovered
396 self.request_limiter = None
397 self.token_limiter = None
398 self._effective_request_limit = None
399 self._effective_token_limit = None
400 return
402 # Calculate effective limits with buffer
403 effective_request_limit = int(
404 self.state.request_limit * self.config.request_buffer_ratio
405 )
406 effective_token_limit = int(
407 self.state.token_limit * self.config.token_buffer_ratio
408 )
410 # Request rate limiter
411 self.request_limiter = Limiter(
412 Rate(
413 effective_request_limit,
414 Duration.MINUTE,
415 )
416 )
418 # Token rate limiter
419 self.token_limiter = Limiter(
420 Rate(
421 effective_token_limit,
422 Duration.MINUTE,
423 )
424 )
426 # Store effective limits for logging
427 self._effective_request_limit = effective_request_limit
428 self._effective_token_limit = effective_token_limit
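# Worked example (illustrative; the buffer ratios below are assumed values, not
# the actual provider defaults): with request_limit=500, token_limit=30_000 and
# both buffer ratios at 0.9, the limiters above are built as
#     effective_request_limit = int(500 * 0.9)      # 450 requests/min
#     effective_token_limit   = int(30_000 * 0.9)   # 27_000 tokens/min
# i.e. Rate(450, Duration.MINUTE) and Rate(27_000, Duration.MINUTE), keeping a
# safety margin below the provider's advertised limits.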
430 async def __aenter__(self) -> "ChatLimiter":
431 """Async context manager entry."""
432 if self._async_context_active:
433 raise RuntimeError(
434 "ChatLimiter is already active as an async context manager"
435 )
437 self._async_context_active = True
439 # Discover rate limits if supported
440 if self.config.supports_dynamic_limits:
441 await self._discover_rate_limits()
443 # Print rate limit information if verbose mode is enabled
444 if self._verbose_mode:
445 self._print_rate_limit_info()
447 return self
449 async def __aexit__(
450 self,
451 exc_type: type[BaseException] | None,
452 exc_val: BaseException | None,
453 exc_tb: object,
454 ) -> None:
455 """Async context manager exit."""
456 self._async_context_active = False
457 await self.async_client.aclose()
459 def __enter__(self) -> "ChatLimiter":
460 """Sync context manager entry."""
461 if self._sync_context_active:
462 raise RuntimeError(
463 "ChatLimiter is already active as a sync context manager"
464 )
466 self._sync_context_active = True
468 # Discover rate limits if supported
469 if self.config.supports_dynamic_limits:
470 self._discover_rate_limits_sync()
472 # Print rate limit information if verbose mode is enabled
473 if self._verbose_mode:
474 self._print_rate_limit_info()
476 return self
478 def __exit__(
479 self,
480 exc_type: type[BaseException] | None,
481 exc_val: BaseException | None,
482 exc_tb: object,
483 ) -> None:
484 """Sync context manager exit."""
485 self._sync_context_active = False
486 self.sync_client.close()
488 async def _discover_rate_limits(self) -> None:
489 """Discover current rate limits from the API."""
490 try:
491 if self.provider == Provider.OPENROUTER and self.config.auth_endpoint:
492 # OpenRouter uses a special auth endpoint
493 response = await self.async_client.get(self.config.auth_endpoint)
494 response.raise_for_status()
496 data = response.json()
497 # Update limits based on response
498 # This is a simplified version - actual implementation would parse the response
499 logger.info(f"Discovered OpenRouter limits: {data}")
501 else:
502 # For other providers, we'll discover limits on first request
503 if self._verbose_mode:
504 print(
505 f"Rate limit discovery will happen on first request for {self.provider.value}"
506 )
507 logger.info(
508 f"Rate limit discovery will happen on first request for {self.provider.value}"
509 )
511 except Exception as e:
512 logger.warning(f"Failed to discover rate limits: {e}")
514 def _discover_rate_limits_sync(self) -> None:
515 """Sync version of rate limit discovery."""
516 try:
517 if self.provider == Provider.OPENROUTER and self.config.auth_endpoint:
518 response = self.sync_client.get(self.config.auth_endpoint)
519 response.raise_for_status()
521 data = response.json()
522 logger.info(f"Discovered OpenRouter limits: {data}")
523 else:
524 logger.info(
525 f"Rate limit discovery will happen on first request for {self.provider.value}"
526 )
528 except Exception as e:
529 logger.warning(f"Failed to discover rate limits: {e}")
531 def _update_rate_limits(self, rate_limit_info: RateLimitInfo) -> None:
532 """Update rate limits based on response headers."""
533 updated = False
534 was_uninitialized = (
535 self.state.request_limit is None or self.state.token_limit is None
536 )
538 # Update request limits
539 if (
540 rate_limit_info.requests_limit
541 and rate_limit_info.requests_limit != self.state.request_limit
542 ):
543 old_limit = self.state.request_limit
544 self.state.request_limit = rate_limit_info.requests_limit
545 updated = True
546 if was_uninitialized:
547 message = (
548 f"Discovered request limit: {self.state.request_limit} req/min"
549 )
550 if self._verbose_mode:
551 print(message)
552 logger.info(message)
553 else:
554 message = f"Updated request limit: {old_limit} -> {self.state.request_limit} req/min"
555 if self._verbose_mode:
556 print(message)
557 logger.info(message)
559 # Update token limits
560 if (
561 rate_limit_info.tokens_limit
562 and rate_limit_info.tokens_limit != self.state.token_limit
563 ):
564 old_limit = self.state.token_limit
565 self.state.token_limit = rate_limit_info.tokens_limit
566 updated = True
567 if was_uninitialized:
568 message = f"Discovered token limit: {self.state.token_limit} tokens/min"
569 if self._verbose_mode:
570 print(message)
571 logger.info(message)
572 else:
573 message = f"Updated token limit: {old_limit} -> {self.state.token_limit} tokens/min"
574 if self._verbose_mode:
575 print(message)
576 logger.info(message)
578 if updated:
579 # Reinitialize rate limiters with new limits
580 self._init_rate_limiters()
582 # Update limits_discovered flag if both limits are now available
583 if (
584 self.state.request_limit is not None
585 and self.state.token_limit is not None
586 ):
587 self._limits_discovered = True
589 if was_uninitialized:
590 message = "Rate limiters initialized after discovery"
591 if self._verbose_mode:
592 print(message)
593 # Print updated rate limit info after discovery
594 self._print_rate_limit_info()
595 logger.info(message)
597 # Store the rate limit info
598 self.state.last_rate_limit_info = rate_limit_info
599 self.state.last_limit_update = time.time()
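# Flow sketch (illustrative; assumes RateLimitInfo exposes the fields read above
# as constructor arguments): after a response whose headers advertise 500
# requests/min and 30_000 tokens/min, extract_rate_limit_info() yields roughly
#     info = RateLimitInfo(requests_limit=500, tokens_limit=30_000)
# and _update_rate_limits(info) re-runs _init_rate_limiters() with those values
# and, once both limits are known, sets _limits_discovered = True.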
601 def _estimate_tokens(self, request_data: dict[str, Any]) -> int:
602 """Estimate token usage from request data."""
603 if not self.enable_token_estimation:
604 return 0
606 # Simple token estimation
607 # This is a placeholder - real implementation would use tiktoken or similar
608 if "messages" in request_data:
609 text = ""
610 for message in request_data["messages"]:
611 if isinstance(message, dict) and "content" in message:
612 text += str(message["content"])
614 # Rough estimation: 1 token ≈ 4 characters
615 return len(text) // 4
617 return 0
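# Worked example of the heuristic above (illustrative only): a request body of
#     {"messages": [{"role": "user", "content": "Hello, world!"}]}
# has 13 characters of content, so _estimate_tokens returns 13 // 4 = 3.
# A real tokenizer (e.g. tiktoken) would usually report a different count.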
619 @asynccontextmanager
620 async def _acquire_rate_limits(
621 self, estimated_tokens: int = 0
622 ) -> AsyncIterator[None]:
623 """Acquire rate limits before making a request."""
624 # Check if rate limiters are initialized
625 if self.request_limiter is None or self.token_limiter is None:
626 # Limits not yet discovered - this request will help discover them
627 logger.info(
628 "Rate limits not yet discovered, proceeding without rate limiting for discovery"
629 )
630 else:
631 # Wait for request rate limit
632 await asyncio.to_thread(self.request_limiter.try_acquire, "request")
634 # Wait for token rate limit if we have token estimation and limiters are initialized
635 if (
636 estimated_tokens > 0
637 and self.token_limiter is not None
638 and self._effective_token_limit is not None
639 ):
640 # Check if request is too large for bucket capacity
641 if estimated_tokens > self._effective_token_limit:
642 # Log warning for large requests
643 logger.warning(
644 f"Request estimated at {estimated_tokens} tokens exceeds bucket capacity "
645 f"of {self._effective_token_limit} tokens. This may cause delays."
646 )
647 # For very large requests, we'll split the acquisition
648 # Acquire tokens in chunks to avoid bucket overflow
649 remaining_tokens = estimated_tokens
650 while remaining_tokens > 0:
651 chunk_size = min(
652 remaining_tokens, self._effective_token_limit // 2
653 )
654 await asyncio.to_thread(
655 self.token_limiter.try_acquire, "token", chunk_size
656 )
657 remaining_tokens -= chunk_size
658 if remaining_tokens > 0:
659 # Brief pause to let bucket refill
660 await asyncio.sleep(0.1)
661 else:
662 # Normal acquisition for smaller requests
663 await asyncio.to_thread(
664 self.token_limiter.try_acquire, "token", estimated_tokens
665 )
667 try:
668 yield
669 finally:
670 # Update usage tracking
671 self.state.requests_used += 1
672 self.state.tokens_used += estimated_tokens
673 self.state.last_request_time = time.time()
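# Worked example of the chunked path above (illustrative numbers): with
# _effective_token_limit = 27_000 and a request estimated at 40_000 tokens,
# chunk_size = min(remaining, 27_000 // 2) = 13_500, so the loop acquires
# 13_500 + 13_500 + 13_000 tokens over three iterations, sleeping 0.1 s after
# the first two chunks so the bucket can refill instead of rejecting the
# oversized request outright.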
675 @contextmanager
676 def _acquire_rate_limits_sync(self, estimated_tokens: int = 0) -> Iterator[None]:
677 """Sync version of rate limit acquisition."""
678 # Check if rate limiters are initialized
679 if self.request_limiter is None or self.token_limiter is None:
680 # Limits not yet discovered - this request will help discover them
681 logger.info(
682 "Rate limits not yet discovered, proceeding without rate limiting for discovery"
683 )
684 else:
685 # Wait for request rate limit
686 self.request_limiter.try_acquire("request")
688 # Wait for token rate limit if we have token estimation and limiters are initialized
689 if (
690 estimated_tokens > 0
691 and self.token_limiter is not None
692 and self._effective_token_limit is not None
693 ):
694 # Check if request is too large for bucket capacity
695 if estimated_tokens > self._effective_token_limit:
696 # Log warning for large requests
697 logger.warning(
698 f"Request estimated at {estimated_tokens} tokens exceeds bucket capacity "
699 f"of {self._effective_token_limit} tokens. This may cause delays."
700 )
701 # For very large requests, we'll split the acquisition
702 # Acquire tokens in chunks to avoid bucket overflow
703 remaining_tokens = estimated_tokens
704 while remaining_tokens > 0:
705 chunk_size = min(
706 remaining_tokens, self._effective_token_limit // 2
707 )
708 self.token_limiter.try_acquire("token", chunk_size)
709 remaining_tokens -= chunk_size
710 if remaining_tokens > 0:
711 # Brief pause to let bucket refill
712 time.sleep(0.1)
713 else:
714 # Normal acquisition for smaller requests
715 self.token_limiter.try_acquire("token", estimated_tokens)
717 try:
718 yield
719 finally:
720 # Update usage tracking
721 self.state.requests_used += 1
722 self.state.tokens_used += estimated_tokens
723 self.state.last_request_time = time.time()
725 def _get_retry_decorator(self):
726 """Get retry decorator with user-configured parameters."""
727 return retry(
728 stop=stop_after_attempt(self._user_max_retries),
729 wait=wait_exponential(multiplier=self._user_base_backoff, min=1, max=60),
730 retry=retry_if_exception_type(
731 (
732 httpx.HTTPStatusError,
733 httpx.RequestError,
734 httpx.ReadTimeout,
735 httpx.ConnectTimeout,
736 )
737 ),
738 )
740 async def request(
741 self,
742 method: str,
743 url: str,
744 *,
745 json: dict[str, Any] | None = None,
746 **kwargs: Any,
747 ) -> httpx.Response:
748 """Wrapper that applies retry decorator dynamically."""
749 try:
750 return await self._get_retry_decorator()(self._request_impl)(
751 method, url, json=json, **kwargs
752 )
753 except Exception as e:
754 # Check if this is a retry error wrapping a timeout
755 if (
756 hasattr(e, "last_attempt")
757 and e.last_attempt
758 and e.last_attempt.exception()
759 ):
760 original_exception = e.last_attempt.exception()
761 if isinstance(
762 original_exception, (httpx.ReadTimeout, httpx.ConnectTimeout)
763 ):
764 # Enhance timeout error with helpful information
765 timeout_info = (
766 f"\n💡 Timeout Error Help:\n"
767 f" Current timeout: {self._user_timeout}s\n"
768 f" To increase timeout, use: ChatLimiter.for_model('{self.provider.value}', timeout={int(self._user_timeout + 60)})\n"
769 f" Or reduce batch concurrency if processing multiple requests\n"
770 f" Retries attempted: {self._user_max_retries}\n"
771 )
772 raise type(original_exception)(
773 str(original_exception) + timeout_info
774 ) from e
776 # For direct timeout errors (shouldn't happen due to retry decorator but just in case)
777 if isinstance(e, (httpx.ReadTimeout, httpx.ConnectTimeout)):
778 timeout_info = (
779 f"\n💡 Timeout Error Help:\n"
780 f" Current timeout: {self._user_timeout}s\n"
781 f" To increase timeout, use: ChatLimiter.for_model('{self.provider.value}', timeout={int(self._user_timeout + 60)})\n"
782 f" Or reduce batch concurrency if processing multiple requests\n"
783 )
784 raise type(e)(str(e) + timeout_info) from e
786 # Re-raise any other exceptions unchanged
787 raise
789 async def _request_impl(
790 self,
791 method: str,
792 url: str,
793 *,
794 json: dict[str, Any] | None = None,
795 **kwargs: Any,
796 ) -> httpx.Response:
797 """
798 Make an async HTTP request with rate limiting.
800 Args:
801 method: HTTP method (GET, POST, etc.)
802 url: URL or path for the request
803 json: JSON data to send
804 **kwargs: Additional arguments passed to httpx
806 Returns:
807 HTTP response
809 Raises:
810 httpx.HTTPStatusError: For HTTP error responses
811 httpx.RequestError: For request errors
812 """
813 if not self._async_context_active:
814 raise RuntimeError("ChatLimiter must be used as an async context manager")
816 # Estimate tokens if we have JSON data
817 estimated_tokens = self._estimate_tokens(json or {})
819 # Acquire rate limits
820 async with self._acquire_rate_limits(estimated_tokens):
821 # Make the request
822 response = await self.async_client.request(method, url, json=json, **kwargs)
824 # Extract rate limit info
825 rate_limit_info = extract_rate_limit_info(
826 dict(response.headers), self.config
827 )
829 # Update our rate limits
830 if self.enable_adaptive_limits:
831 self._update_rate_limits(rate_limit_info)
833 # Handle rate limit errors
834 if response.status_code == 429:
835 self.state.consecutive_rate_limit_errors += 1
836 if rate_limit_info.retry_after:
837 await asyncio.sleep(rate_limit_info.retry_after)
838 else:
839 # Exponential backoff
840 backoff = self.config.base_backoff * (
841 2**self.state.consecutive_rate_limit_errors
842 )
843 await asyncio.sleep(min(backoff, self.config.max_backoff))
845 response.raise_for_status()
846 else:
847 # Reset consecutive errors on success
848 self.state.consecutive_rate_limit_errors = 0
850 return response
852 def request_sync(
853 self,
854 method: str,
855 url: str,
856 *,
857 json: dict[str, Any] | None = None,
858 **kwargs: Any,
859 ) -> httpx.Response:
860 """Wrapper that applies retry decorator dynamically."""
861 # For sync, we need to use the sync version of retry
862 retry_decorator = retry(
863 stop=stop_after_attempt(self._user_max_retries),
864 wait=wait_exponential(multiplier=self._user_base_backoff, min=1, max=60),
865 retry=retry_if_exception_type(
866 (
867 httpx.HTTPStatusError,
868 httpx.RequestError,
869 httpx.ReadTimeout,
870 httpx.ConnectTimeout,
871 )
872 ),
873 )
874 try:
875 return retry_decorator(self._request_sync_impl)(
876 method, url, json=json, **kwargs
877 )
878 except (httpx.ReadTimeout, httpx.ConnectTimeout) as e:
879 # Enhance timeout error with helpful information
880 timeout_info = (
881 f"\n💡 Timeout Error Help:\n"
882 f" Current timeout: {self._user_timeout}s\n"
883 f" To increase timeout, use: ChatLimiter.for_model('{self.provider.value}', timeout={int(self._user_timeout + 60)})\n"
884 f" Or reduce batch concurrency if processing multiple requests\n"
885 )
886 raise type(e)(str(e) + timeout_info) from e
888 def _request_sync_impl(
889 self,
890 method: str,
891 url: str,
892 *,
893 json: dict[str, Any] | None = None,
894 **kwargs: Any,
895 ) -> httpx.Response:
896 """
897 Make a sync HTTP request with rate limiting.
899 Args:
900 method: HTTP method (GET, POST, etc.)
901 url: URL or path for the request
902 json: JSON data to send
903 **kwargs: Additional arguments passed to httpx
905 Returns:
906 HTTP response
908 Raises:
909 httpx.HTTPStatusError: For HTTP error responses
910 httpx.RequestError: For request errors
911 """
912 if not self._sync_context_active:
913 raise RuntimeError("ChatLimiter must be used as a sync context manager")
915 # Estimate tokens if we have JSON data
916 estimated_tokens = self._estimate_tokens(json or {})
918 # Acquire rate limits
919 with self._acquire_rate_limits_sync(estimated_tokens):
920 # Make the request
921 response = self.sync_client.request(method, url, json=json, **kwargs)
923 # Extract rate limit info
924 rate_limit_info = extract_rate_limit_info(
925 dict(response.headers), self.config
926 )
928 # Update our rate limits
929 if self.enable_adaptive_limits:
930 self._update_rate_limits(rate_limit_info)
932 # Handle rate limit errors
933 if response.status_code == 429:
934 self.state.consecutive_rate_limit_errors += 1
935 if rate_limit_info.retry_after:
936 time.sleep(rate_limit_info.retry_after)
937 else:
938 # Exponential backoff
939 backoff = self.config.base_backoff * (
940 2**self.state.consecutive_rate_limit_errors
941 )
942 time.sleep(min(backoff, self.config.max_backoff))
944 response.raise_for_status()
945 else:
946 # Reset consecutive errors on success
947 self.state.consecutive_rate_limit_errors = 0
949 return response
951 def get_current_limits(self) -> dict[str, Any]:
952 """Get current rate limit information."""
953 return {
954 "provider": self.provider.value,
955 "request_limit": self.state.request_limit,
956 "token_limit": self.state.token_limit,
957 "requests_used": self.state.requests_used,
958 "tokens_used": self.state.tokens_used,
959 "last_request_time": self.state.last_request_time,
960 "last_limit_update": self.state.last_limit_update,
961 "consecutive_rate_limit_errors": self.state.consecutive_rate_limit_errors,
962 }
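# Usage sketch (illustrative): inspecting limiter state mid-run, e.g. from a
# monitoring or logging hook:
#     limits = limiter.get_current_limits()
#     print(f"{limits['requests_used']} requests / {limits['tokens_used']} tokens used "
#           f"against {limits['request_limit']} req/min and {limits['token_limit']} tok/min")
#     limiter.reset_usage_tracking()  # start a fresh accounting window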
964 def reset_usage_tracking(self) -> None:
965 """Reset usage tracking counters."""
966 self.state.requests_used = 0
967 self.state.tokens_used = 0
968 self.state.consecutive_rate_limit_errors = 0
970 # High-level chat completion methods
972 async def chat_completion(
973 self,
974 model: str,
975 messages: list[Message],
976 max_tokens: int | None = None,
977 temperature: float | None = None,
978 top_p: float | None = None,
979 stop: str | list[str] | None = None,
980 stream: bool = False,
981 **kwargs: Any,
982 ) -> ChatCompletionResponse:
983 """
984 Make a high-level chat completion request.
986 Args:
987 model: The model to use for completion
988 messages: List of messages in the conversation
989 max_tokens: Maximum tokens to generate
990 temperature: Sampling temperature
991 top_p: Top-p sampling parameter
992 stop: Stop sequences
993 stream: Whether to stream the response
994 **kwargs: Additional provider-specific parameters
996 Returns:
997 ChatCompletionResponse with the completion result
999 Raises:
1000 ValueError: If provider cannot be determined from model
1001 httpx.HTTPStatusError: For HTTP error responses
1002 httpx.RequestError: For request errors
1003 """
1004 if not self._async_context_active:
1005 raise RuntimeError("ChatLimiter must be used as an async context manager")
1007 # Create request object
1008 request = ChatCompletionRequest(
1009 model=model,
1010 messages=messages,
1011 max_tokens=max_tokens,
1012 temperature=temperature,
1013 top_p=top_p,
1014 stop=stop,
1015 stream=stream,
1016 **kwargs,
1017 )
1019 # Get the appropriate adapter
1020 adapter = get_adapter(self.provider)
1022 # Format the request for the provider
1023 formatted_request = adapter.format_request(request)
1025 # Make the HTTP request
1026 response = await self.request(
1027 "POST", adapter.get_endpoint(), json=formatted_request
1028 )
1030 # Parse the response
1031 response_data = response.json()
1032 return adapter.parse_response(response_data, request)
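# Flow sketch for chat_completion() above: the provider adapter does the
# translation work, roughly
#     adapter = get_adapter(self.provider)
#     body = adapter.format_request(request)        # provider-specific JSON payload
#     response = await self.request("POST", adapter.get_endpoint(), json=body)
#     result = adapter.parse_response(response.json(), request)   # -> ChatCompletionResponse
# so the rate limiting, retries, and header-based limit updates in request()
# apply to every high-level call as well.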
1034 def chat_completion_sync(
1035 self,
1036 model: str,
1037 messages: list[Message],
1038 max_tokens: int | None = None,
1039 temperature: float | None = None,
1040 top_p: float | None = None,
1041 stop: str | list[str] | None = None,
1042 stream: bool = False,
1043 **kwargs: Any,
1044 ) -> ChatCompletionResponse:
1045 """
1046 Make a synchronous high-level chat completion request.
1048 Args:
1049 model: The model to use for completion
1050 messages: List of messages in the conversation
1051 max_tokens: Maximum tokens to generate
1052 temperature: Sampling temperature
1053 top_p: Top-p sampling parameter
1054 stop: Stop sequences
1055 stream: Whether to stream the response
1056 **kwargs: Additional provider-specific parameters
1058 Returns:
1059 ChatCompletionResponse with the completion result
1061 Raises:
1062 ValueError: If provider cannot be determined from model
1063 httpx.HTTPStatusError: For HTTP error responses
1064 httpx.RequestError: For request errors
1065 """
1066 if not self._sync_context_active:
1067 raise RuntimeError("ChatLimiter must be used as a sync context manager")
1069 # Create request object
1070 request = ChatCompletionRequest(
1071 model=model,
1072 messages=messages,
1073 max_tokens=max_tokens,
1074 temperature=temperature,
1075 top_p=top_p,
1076 stop=stop,
1077 stream=stream,
1078 **kwargs,
1079 )
1081 # Get the appropriate adapter
1082 adapter = get_adapter(self.provider)
1084 # Format the request for the provider
1085 formatted_request = adapter.format_request(request)
1087 # Make the HTTP request
1088 response = self.request_sync(
1089 "POST", adapter.get_endpoint(), json=formatted_request
1090 )
1092 # Parse the response
1093 response_data = response.json()
1094 return adapter.parse_response(response_data, request)
1096 # Convenience methods for different message types
1098 async def simple_chat(
1099 self,
1100 model: str,
1101 prompt: str,
1102 max_tokens: int | None = None,
1103 temperature: float | None = None,
1104 **kwargs: Any,
1105 ) -> str:
1106 """
1107 Simple chat completion that returns just the text response.
1109 Args:
1110 model: The model to use
1111 prompt: The user prompt
1112 max_tokens: Maximum tokens to generate
1113 temperature: Sampling temperature
1114 **kwargs: Additional parameters
1116 Returns:
1117 The text response from the model
1118 """
1119 messages = [Message(role=MessageRole.USER, content=prompt)]
1120 response = await self.chat_completion(
1121 model=model,
1122 messages=messages,
1123 max_tokens=max_tokens,
1124 temperature=temperature,
1125 **kwargs,
1126 )
1128 if response.choices:
1129 return response.choices[0].message.content
1130 return ""
1132 def simple_chat_sync(
1133 self,
1134 model: str,
1135 prompt: str,
1136 max_tokens: int | None = None,
1137 temperature: float | None = None,
1138 **kwargs: Any,
1139 ) -> str:
1140 """
1141 Simple synchronous chat completion that returns just the text response.
1143 Args:
1144 model: The model to use
1145 prompt: The user prompt
1146 max_tokens: Maximum tokens to generate
1147 temperature: Sampling temperature
1148 **kwargs: Additional parameters
1150 Returns:
1151 The text response from the model
1152 """
1153 messages = [Message(role=MessageRole.USER, content=prompt)]
1154 response = self.chat_completion_sync(
1155 model=model,
1156 messages=messages,
1157 max_tokens=max_tokens,
1158 temperature=temperature,
1159 **kwargs,
1160 )
1162 if response.choices:
1163 return response.choices[0].message.content
1164 return ""
1166 def set_verbose_mode(self, verbose: bool) -> None:
1167 """Set verbose mode for detailed logging."""
1168 self._verbose_mode = verbose
1170 def _print_rate_limit_info(self) -> None:
1171 """Print current rate limit configuration."""
1172 print(f"\n=== Rate Limit Configuration for {self.provider.value.title()} ===")
1173 print(f"Provider: {self.provider.value}")
1174 print(f"Base URL: {self.config.base_url}")
1176 # Handle None values for limits
1177 if self.state.request_limit is not None:
1178 effective_req = self._effective_request_limit or "not calculated"
1179 print(
1180 f"Request Limit: {self.state.request_limit}/minute (effective: {effective_req}/minute)"
1181 )
1182 else:
1183 print("Request Limit: Not yet discovered (will be fetched from API)")
1185 if self.state.token_limit is not None:
1186 effective_tok = self._effective_token_limit or "not calculated"
1187 print(
1188 f"Token Limit: {self.state.token_limit}/minute (effective: {effective_tok}/minute)"
1189 )
1190 else:
1191 print("Token Limit: Not yet discovered (will be fetched from API)")
1193 print(f"Request Buffer Ratio: {self.config.request_buffer_ratio}")
1194 print(f"Token Buffer Ratio: {self.config.token_buffer_ratio}")
1195 print(f"Adaptive Limits: {self.enable_adaptive_limits}")
1196 print(f"Token Estimation: {self.enable_token_estimation}")
1197 print(f"Dynamic Discovery: {self.config.supports_dynamic_limits}")
1198 print(f"Limits Discovered: {self._limits_discovered}")
1199 print("=" * 50)