Coverage for mcpgateway/handlers/sampling.py: 77%

87 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 16:34 +0100

1# -*- coding: utf-8 -*- 

2"""MCP Sampling Handler Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements the sampling handler for MCP LLM interactions. 

9It handles model selection, sampling preferences, and message generation. 

10""" 

11 

12import logging 

13from typing import Any, Dict, List 

14 

15from sqlalchemy.orm import Session 

16 

17from mcpgateway.types import CreateMessageResult, ModelPreferences, Role, TextContent 

18 

# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)

20 

21 

class SamplingError(Exception):
    """Base exception for all sampling-related failures raised by this module."""

24 

25 

class SamplingHandler:
    """MCP sampling request handler.

    Handles:
    - Model selection based on preferences
    - Message sampling requests
    - Context management
    - Content validation
    """

    def __init__(self):
        """Initialize sampling handler."""
        # Maps model names to capability scores (cost, speed, intelligence),
        # each in [0, 1]; consumed by _select_model to rank candidates.
        self._supported_models = {
            "claude-3-haiku": (0.8, 0.9, 0.7),
            "claude-3-sonnet": (0.5, 0.7, 0.9),
            "claude-3-opus": (0.2, 0.5, 1.0),
            "gemini-1.5-pro": (0.6, 0.8, 0.8),
        }

    async def initialize(self) -> None:
        """Initialize sampling handler."""
        logger.info("Initializing sampling handler")

    async def shutdown(self) -> None:
        """Shutdown sampling handler."""
        logger.info("Shutting down sampling handler")

    async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult:
        """Create message from sampling request.

        Args:
            db: Database session
            request: Sampling request parameters

        Returns:
            Sampled message result

        Raises:
            SamplingError: If sampling fails
        """
        try:
            # Extract request parameters
            messages = request.get("messages", [])
            max_tokens = request.get("maxTokens")
            model_prefs = ModelPreferences.parse_obj(request.get("modelPreferences", {}))
            include_context = request.get("includeContext", "none")
            # NOTE(review): the original fetched request["metadata"] and discarded
            # it (dead statement); removed until metadata is actually consumed.

            # Validate request
            if not messages:
                raise SamplingError("No messages provided")
            # Rejects both missing and zero maxTokens — zero tokens cannot
            # produce a message, so both are treated as invalid.
            if not max_tokens:
                raise SamplingError("Max tokens not specified")

            # Select model
            model = self._select_model(model_prefs)
            logger.info("Selected model: %s", model)

            # Include context if requested
            if include_context != "none":
                messages = await self._add_context(db, messages, include_context)

            # Validate messages
            for msg in messages:
                if not self._validate_message(msg):
                    raise SamplingError(f"Invalid message format: {msg}")

            # TODO: Sample from selected model
            # For now return mock response
            response = self._mock_sample(messages=messages)

            # Convert to result
            return CreateMessageResult(
                content=TextContent(type="text", text=response),
                model=model,
                role=Role.ASSISTANT,
                stop_reason="maxTokens",
            )

        except SamplingError as e:
            # Deliberate validation errors: log and propagate without re-wrapping
            # so callers see the original exception instance.
            logger.error("Sampling error: %s", e)
            raise
        except Exception as e:
            logger.error("Sampling error: %s", e)
            # Chain the cause so tracebacks show the underlying failure.
            raise SamplingError(str(e)) from e

    def _select_model(self, preferences: ModelPreferences) -> str:
        """Select model based on preferences.

        Args:
            preferences: Model selection preferences

        Returns:
            Selected model name

        Raises:
            SamplingError: If no suitable model found
        """
        # Check model hints first: the first hint whose name is a substring of
        # a supported model wins.
        if preferences.hints:
            for hint in preferences.hints:
                for model in self._supported_models:
                    if hint.name and hint.name in model:
                        return model

        # Score models on preferences.
        # NOTE(review): cost uses (1 - cost_priority) while speed/intelligence
        # use the priority directly — a higher cost_priority therefore *reduces*
        # the weight of the cost capability. Looks inverted; confirm the
        # intended semantics of ModelPreferences.cost_priority.
        best_score = -1
        best_model = None

        for model, caps in self._supported_models.items():
            cost_score = caps[0] * (1 - preferences.cost_priority)
            speed_score = caps[1] * preferences.speed_priority
            intel_score = caps[2] * preferences.intelligence_priority

            total_score = (cost_score + speed_score + intel_score) / 3

            if total_score > best_score:
                best_score = total_score
                best_model = model

        # Unreachable while _supported_models is non-empty; kept as a guard.
        if not best_model:
            raise SamplingError("No suitable model found")

        return best_model

    async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _context_type: str) -> List[Dict[str, Any]]:
        """Add context to messages.

        Args:
            _db: Database session
            messages: Message list
            _context_type: Context inclusion type

        Returns:
            Messages with added context
        """
        # TODO: Implement context gathering based on type
        # For now return original messages
        return messages

    def _validate_message(self, message: Dict[str, Any]) -> bool:
        """Validate message format.

        Args:
            message: Message to validate

        Returns:
            True if valid
        """
        try:
            # Must have role and content, and role must be user/assistant.
            if "role" not in message or "content" not in message or message["role"] not in ("user", "assistant"):
                return False

            # Content must be a text or image payload with required fields.
            # NOTE(review): image requires snake_case "mime_type" — verify this
            # matches the wire format used by callers (MCP spec uses mimeType).
            content = message["content"]
            if content.get("type") == "text":
                if not isinstance(content.get("text"), str):
                    return False
            elif content.get("type") == "image":
                if not (content.get("data") and content.get("mime_type")):
                    return False
            else:
                return False

            return True

        except Exception:
            # Non-dict content (or similar malformed input) is simply invalid.
            return False

    def _mock_sample(
        self,
        messages: List[Dict[str, Any]],
    ) -> str:
        """Mock sampling response for testing.

        Args:
            messages: Input messages

        Returns:
            Sampled response text
        """
        # Extract last user message
        last_msg = None
        for msg in reversed(messages):
            if msg["role"] == "user":
                last_msg = msg
                break

        if not last_msg:
            return "I'm not sure what to respond to."

        # Get user text; unknown content types fall through as empty text.
        user_text = ""
        content = last_msg["content"]
        if content["type"] == "text":
            user_text = content["text"]
        elif content["type"] == "image":
            user_text = "I see the image you shared."

        # Generate simple response
        return f"You said: {user_text}\nHere is my response..."