Coverage for src/chat_limiter/types.py: 94%

62 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 20:26 +0100

1""" 

2Type definitions for chat completion requests and responses. 

3""" 

4 

5from dataclasses import dataclass 

6from enum import Enum 

7from typing import Any 

8 

9from pydantic import BaseModel 

10 

11 

class MessageRole(str, Enum):
    """Author role attached to a chat message, shared by every provider.

    Each member is also a ``str``, so it compares equal to its plain
    string value (for example ``MessageRole.USER == "user"``).
    """

    # Declaration order is the iteration order of the enum.
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

18 

19 

@dataclass
class Message:
    """Provider-neutral representation of a single chat message.

    Attributes:
        role: Which ``MessageRole`` authored the message.
        content: The message text.
    """

    role: MessageRole
    content: str

26 

27 

class ChatCompletionRequest(BaseModel):
    """Provider-agnostic description of a chat completion call.

    Only ``model`` and ``messages`` are required. Every other field has a
    default: ``stream`` defaults to ``False`` and the rest default to ``None``.
    """

    # Required fields.
    model: str
    messages: list[Message]

    # Optional generation settings.
    max_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    stop: str | list[str] | None = None  # a single stop string or several
    stream: bool = False

    # Settings tied to one provider; expected to be filtered per provider.
    frequency_penalty: float | None = None  # OpenAI
    presence_penalty: float | None = None  # OpenAI
    top_k: int | None = None  # Anthropic

43 

44 

@dataclass
class Usage:
    """Token counts reported for a completion.

    Attributes:
        prompt_tokens: Number of tokens in the prompt.
        completion_tokens: Number of tokens in the completion.
        total_tokens: Total token count, stored as given (not computed here).
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

52 

53 

@dataclass
class Choice:
    """One candidate completion within a response.

    Attributes:
        index: Position of this choice.
        message: The ``Message`` carried by this choice.
        finish_reason: Optional reason string; ``None`` when not provided.
    """

    index: int
    message: Message
    finish_reason: str | None = None

61 

62 

@dataclass
class ChatCompletionResponse:
    """Provider-agnostic result of a chat completion call.

    ``id``, ``model`` and ``choices`` are required; every other field is
    optional and has a default.
    """

    # Required fields.
    id: str
    model: str
    choices: list[Choice]

    # Optional details.
    usage: Usage | None = None
    created: int | None = None  # presumably a creation timestamp; confirm with callers

    # Outcome reporting: success defaults to True with no error message.
    success: bool = True
    error_message: str | None = None

    # Details about the originating provider.
    provider: str | None = None
    raw_response: dict[str, Any] | None = None

80 

81 

# Statically known model names, one set per provider.
OPENAI_MODELS = {
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
    "gpt-4",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
}

ANTHROPIC_MODELS = {
    "claude-3-5-sonnet-20241022",
    "claude-3-5-haiku-20241022",
    "claude-3-opus-20240229",
    "claude-3-sonnet-20240229",
    "claude-3-haiku-20240307",
}

# Every OpenRouter entry uses the "<vendor>/<model>" naming form.
OPENROUTER_MODELS = {
    # OpenAI models routed through OpenRouter
    "openai/gpt-4o",
    "openai/gpt-4o-mini",
    "openai/gpt-4-turbo",
    "openai/gpt-3.5-turbo",
    # Anthropic models routed through OpenRouter
    "anthropic/claude-3-5-sonnet",
    "anthropic/claude-3-opus",
    "anthropic/claude-3-sonnet",
    "anthropic/claude-3-haiku",
    # Models from other vendors routed through OpenRouter
    "meta-llama/llama-3.1-405b-instruct",
    "meta-llama/llama-3.1-70b-instruct",
    "google/gemini-pro",
    "cohere/command-r-plus",
}

# Union of the three provider sets above.
ALL_MODELS = OPENAI_MODELS.union(ANTHROPIC_MODELS, OPENROUTER_MODELS)

121 

122 

123def detect_provider_from_model(model: str, use_dynamic_discovery: bool = False, api_keys: dict[str, str] | None = None) -> str | None: 

124 """ 

125 Detect provider from model name. 

126 

127 Args: 

128 model: The model name to check 

129 use_dynamic_discovery: Whether to use live API queries for model discovery 

130 api_keys: Dictionary of API keys for dynamic discovery 

131 

132 Returns: 

133 Provider name or None if not found 

134 """ 

135 # First try pattern-based detection for common cases 

136 if "/" in model: # OpenRouter format 

137 return "openrouter" 

138 

139 # Check hardcoded lists for fast lookup 

140 if model in OPENAI_MODELS: 

141 return "openai" 

142 elif model in ANTHROPIC_MODELS: 

143 return "anthropic" 

144 elif model in OPENROUTER_MODELS: 

145 return "openrouter" 

146 

147 # If dynamic discovery is enabled and we have API keys, try that 

148 if use_dynamic_discovery and api_keys: 

149 from .models import detect_provider_from_model_sync 

150 result = detect_provider_from_model_sync(model, api_keys) 

151 return result.found_provider 

152 

153 return None