Coverage for mcpgateway/handlers/sampling.py: 77%
87 statements
coverage.py v7.9.1, created at 2025-06-22 16:34 +0100

# -*- coding: utf-8 -*-
"""MCP Sampling Handler Implementation.

Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

This module implements the sampling handler for MCP LLM interactions.
It handles model selection, sampling preferences, and message generation.
"""

import logging
from typing import Any, Dict, List

from sqlalchemy.orm import Session

from mcpgateway.types import CreateMessageResult, ModelPreferences, Role, TextContent

logger = logging.getLogger(__name__)


class SamplingError(Exception):
    """Base class for sampling errors."""


class SamplingHandler:
    """MCP sampling request handler.

    Handles:
    - Model selection based on preferences
    - Message sampling requests
    - Context management
    - Content validation
    """

    def __init__(self):
        """Initialize sampling handler."""
        self._supported_models = {
            # Maps model names to capability scores (cost, speed, intelligence)
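            # Higher is better on every axis; cost reads as cost-efficiency,
            # e.g. claude-3-opus below: expensive (0.2), moderate speed (0.5),
            # highest intelligence (1.0). (Interpretation inferred from the data.)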
            "claude-3-haiku": (0.8, 0.9, 0.7),
            "claude-3-sonnet": (0.5, 0.7, 0.9),
            "claude-3-opus": (0.2, 0.5, 1.0),
            "gemini-1.5-pro": (0.6, 0.8, 0.8),
        }

    async def initialize(self) -> None:
        """Initialize sampling handler."""
        logger.info("Initializing sampling handler")

    async def shutdown(self) -> None:
        """Shutdown sampling handler."""
        logger.info("Shutting down sampling handler")

    async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult:
        """Create message from sampling request.

        Args:
            db: Database session
            request: Sampling request parameters

        Returns:
            Sampled message result

        Raises:
            SamplingError: If sampling fails
        """
        try:
            # Extract request parameters
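            # (keys use the camelCase wire names from MCP's sampling/createMessage request)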
            messages = request.get("messages", [])
            max_tokens = request.get("maxTokens")
            model_prefs = ModelPreferences.parse_obj(request.get("modelPreferences", {}))
            include_context = request.get("includeContext", "none")
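            # metadata is accepted but currently unused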
            request.get("metadata", {})

            # Validate request
            if not messages:
                raise SamplingError("No messages provided")
            if not max_tokens:
                raise SamplingError("Max tokens not specified")

            # Select model
            model = self._select_model(model_prefs)
            logger.info(f"Selected model: {model}")

            # Include context if requested
            if include_context != "none":
                messages = await self._add_context(db, messages, include_context)

            # Validate messages
            for msg in messages:
                if not self._validate_message(msg):
                    raise SamplingError(f"Invalid message format: {msg}")

            # TODO: Sample from selected model
            # For now return mock response
            response = self._mock_sample(messages=messages)

            # Convert to result
            return CreateMessageResult(
                content=TextContent(type="text", text=response),
                model=model,
                role=Role.ASSISTANT,
                stop_reason="maxTokens",
            )

        except Exception as e:
            logger.error(f"Sampling error: {e}")
            raise SamplingError(str(e)) from e

    def _select_model(self, preferences: ModelPreferences) -> str:
        """Select model based on preferences.

        Args:
            preferences: Model selection preferences

        Returns:
            Selected model name

        Raises:
            SamplingError: If no suitable model found
        """
        # Check model hints first
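        # Hint names match by substring, so a hint of "sonnet" selects "claude-3-sonnet"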
        if preferences.hints:
            for hint in preferences.hints:
                for model in self._supported_models:
                    if hint.name and hint.name in model:
                        return model

        # Score models on preferences
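        # Weighted average of the three capability axes; e.g. with
        # intelligence_priority=1.0 and the other priorities 0 (hypothetical
        # values), claude-3-opus (intelligence 1.0) scores highest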
        best_score = -1
        best_model = None

        for model, caps in self._supported_models.items():
            # Every axis is "higher is better", so each is weighted by how much
            # the caller prioritizes it
            cost_score = caps[0] * preferences.cost_priority
            speed_score = caps[1] * preferences.speed_priority
            intel_score = caps[2] * preferences.intelligence_priority

            total_score = (cost_score + speed_score + intel_score) / 3

            if total_score > best_score:
                best_score = total_score
                best_model = model

        if not best_model:
            raise SamplingError("No suitable model found")

        return best_model

    async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _context_type: str) -> List[Dict[str, Any]]:
        """Add context to messages.

        Args:
            _db: Database session
            messages: Message list
            _context_type: Context inclusion type

        Returns:
            Messages with added context
        """
        # TODO: Implement context gathering based on type
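        # (the MCP spec's includeContext values besides "none" are "thisServer" and "allServers")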
        # For now return original messages
        return messages

    def _validate_message(self, message: Dict[str, Any]) -> bool:
        """Validate message format.

        Args:
            message: Message to validate

        Returns:
            True if valid
        """
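        # A passing example: {"role": "user", "content": {"type": "text", "text": "hi"}}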
        try:
            # Must have role and content
            if "role" not in message or "content" not in message or message["role"] not in ("user", "assistant"):
                return False

            # Content must be valid
            content = message["content"]
            if content.get("type") == "text":
                if not isinstance(content.get("text"), str):
                    return False
            elif content.get("type") == "image":
                if not (content.get("data") and content.get("mime_type")):
                    return False
            else:
                return False

            return True

        except Exception:
            return False

    def _mock_sample(
        self,
        messages: List[Dict[str, Any]],
    ) -> str:
        """Mock sampling response for testing.

        Args:
            messages: Input messages

        Returns:
            Sampled response text
        """
        # Extract last user message
        last_msg = None
        for msg in reversed(messages):
            if msg["role"] == "user":
                last_msg = msg
                break

        if not last_msg:
            return "I'm not sure what to respond to."

        # Get user text
        user_text = ""
        content = last_msg["content"]
        if content["type"] == "text":
            user_text = content["text"]
        elif content["type"] == "image":
            user_text = "I see the image you shared."

        # Generate simple response
        return f"You said: {user_text}\nHere is my response..."
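

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): shows the
# request shape create_message() parses above. The hint "sonnet" should
# select "claude-3-sonnet" via substring matching; db is only passed through
# to the _add_context TODO, so None stands in here.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        handler = SamplingHandler()
        await handler.initialize()
        result = await handler.create_message(
            db=None,  # placeholder: unused until context gathering is implemented
            request={
                "messages": [{"role": "user", "content": {"type": "text", "text": "Hello"}}],
                "maxTokens": 100,
                "modelPreferences": {"hints": [{"name": "sonnet"}]},
            },
        )
        print(result.model, result.content.text)
        await handler.shutdown()

    asyncio.run(_demo())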