Coverage for mcpgateway/federation/forward.py: 53%
122 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:23 +0100
1# -*- coding: utf-8 -*-
2"""Federation Request Forwarding.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements request forwarding for federated MCP Gateways.
9It handles:
10- Request routing to appropriate gateways
11- Response aggregation
12- Error handling and retry logic
13- Request/response transformation
14"""
import asyncio
import base64
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session

from mcpgateway.config import settings
from mcpgateway.db import Gateway as DbGateway
from mcpgateway.db import Tool as DbTool
from mcpgateway.types import ToolResult
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class ForwardingError(Exception):
    """Raised when a request cannot be forwarded to a federated gateway.

    Serves as the base exception for every forwarding failure raised by
    this module.
    """
class ForwardingService:
    """Service for handling request forwarding across gateways.

    Handles:
    - Request routing
    - Response aggregation
    - Error handling
    - Request transformation
    """

    def __init__(self):
        """Initialize forwarding service."""
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # Track active requests (request id -> task); cancelled by stop().
        # NOTE(review): nothing in this class adds entries; presumably populated by callers.
        self._active_requests: Dict[str, asyncio.Task] = {}

        # Request history for rate limiting (gateway URL -> request timestamps)
        self._request_history: Dict[str, List[datetime]] = {}

        # Cache gateway information (not read anywhere in this class)
        self._gateway_tools: Dict[int, Set[str]] = {}

    async def start(self) -> None:
        """Start forwarding service."""
        logger.info("Request forwarding service started")

    async def stop(self) -> None:
        """Stop forwarding service.

        Cancels every tracked request task, waits for each one to finish,
        then closes the shared HTTP client.
        """
        # Iterate over a snapshot: awaiting a task yields to the event loop, and
        # anything that mutates the dict meanwhile would break live iteration.
        pending = list(self._active_requests.items())
        self._active_requests.clear()

        for request_id, task in pending:
            logger.info(f"Cancelling request {request_id}")
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # A task that already failed on its own must not prevent shutdown
                # (and leak the HTTP client below).
                logger.warning(f"Request {request_id} failed during shutdown: {str(e)}")

        await self._http_client.aclose()
        logger.info("Request forwarding service stopped")

    async def forward_request(
        self,
        db: Session,
        method: str,
        params: Optional[Dict[str, Any]] = None,
        target_gateway_id: Optional[int] = None,
    ) -> Any:
        """Forward a request to gateway(s).

        Args:
            db: Database session
            method: RPC method name
            params: Optional method parameters
            target_gateway_id: Optional specific gateway ID; when ``None`` the
                request is sent to every active gateway

        Returns:
            The targeted gateway's result when ``target_gateway_id`` is given,
            otherwise the list of results from all gateways that answered.

        Raises:
            ForwardingError: If forwarding fails
        """
        try:
            # Compare against None so that a (falsy) id of 0 is still targeted.
            if target_gateway_id is not None:
                return await self._forward_to_gateway(db, target_gateway_id, method, params)

            return await self._forward_to_all(db, method, params)

        except Exception as e:
            raise ForwardingError(f"Forward request failed: {str(e)}") from e

    async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
        """Forward a tool invocation request.

        Args:
            db: Database session
            tool_name: Tool to invoke
            arguments: Tool arguments

        Returns:
            Tool result

        Raises:
            ForwardingError: If the tool is unknown, not federated, or forwarding fails
        """
        try:
            # Find the active tool record
            tool = db.execute(select(DbTool).where(DbTool.name == tool_name).where(DbTool.is_active)).scalar_one_or_none()

            if not tool:
                raise ForwardingError(f"Tool not found: {tool_name}")

            if not tool.gateway_id:
                raise ForwardingError(f"Tool {tool_name} is not federated")

            # Forward to the gateway that owns the tool
            result = await self._forward_to_gateway(
                db,
                tool.gateway_id,
                "tools/invoke",
                {"name": tool_name, "arguments": arguments},
            )

            # Guard against a missing/non-object result before calling .get()
            if not isinstance(result, dict):
                raise ForwardingError("Invalid tool response format")

            return ToolResult(
                content=result.get("content", []),
                is_error=result.get("is_error", False),
            )

        except Exception as e:
            raise ForwardingError(f"Failed to forward tool request: {str(e)}") from e

    async def forward_resource_request(self, db: Session, uri: str) -> Tuple[Union[str, bytes], str]:
        """Forward a resource read request.

        Args:
            db: Database session
            uri: Resource URI

        Returns:
            Tuple of (content, mime_type)

        Raises:
            ForwardingError: If no gateway hosts the resource, the response is
                malformed, or forwarding fails
        """
        try:
            # Find gateway for resource
            gateway = await self._find_resource_gateway(db, uri)
            if not gateway:
                raise ForwardingError(f"No gateway found for resource: {uri}")

            # Forward request
            result = await self._forward_to_gateway(db, gateway.id, "resources/read", {"uri": uri})

            # Only a dict can carry "text"/"blob"; on a str, `in` would be a substring test.
            if isinstance(result, dict):
                if "text" in result:
                    return result["text"], result.get("mime_type", "text/plain")
                if "blob" in result:
                    return result["blob"], result.get("mime_type", "application/octet-stream")

            raise ForwardingError("Invalid resource response format")

        except Exception as e:
            raise ForwardingError(f"Failed to forward resource request: {str(e)}") from e

    async def _forward_to_gateway(
        self,
        db: Session,
        gateway_id: int,
        method: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Forward request to a specific gateway.

        Timeouts are retried with a linear backoff; all other failures
        (HTTP status errors, JSON-RPC errors, malformed bodies) fail immediately.

        Args:
            db: Database session
            gateway_id: Gateway to forward to
            method: RPC method name
            params: Optional method parameters

        Returns:
            The ``result`` member of the gateway's JSON-RPC response

        Raises:
            ForwardingError: If the gateway is missing/inactive, the rate limit is
                exceeded, or the request fails (including timing out on every attempt)
        """
        # Get gateway
        gateway = db.get(DbGateway, gateway_id)
        if not gateway or not gateway.is_active:
            raise ForwardingError(f"Gateway not found: {gateway_id}")

        # Check rate limits
        if not self._check_rate_limit(gateway.url):
            raise ForwardingError("Rate limit exceeded")

        # Always make at least one attempt; with a non-positive retry setting the
        # loop would otherwise never run and the method would return None.
        attempts = max(1, settings.max_tool_retries)

        try:
            # Build request
            request = {"jsonrpc": "2.0", "id": 1, "method": method}
            if params:
                request["params"] = params

            # Send request with retries using the persistent client directly
            for attempt in range(attempts):
                try:
                    response = await self._http_client.post(
                        f"{gateway.url}/rpc",
                        json=request,
                        headers=self._get_auth_headers(),
                    )
                    response.raise_for_status()
                    result = response.json()
                except httpx.TimeoutException:
                    if attempt == attempts - 1:
                        raise
                    # Linear backoff: wait 1s, 2s, ... between attempts
                    await asyncio.sleep(1 * (attempt + 1))
                    continue

                # Update last seen
                gateway.last_seen = datetime.utcnow()

                if not isinstance(result, dict):
                    raise ForwardingError("Invalid JSON-RPC response format")

                # Handle response; tolerate an "error" member that is not an object
                if "error" in result:
                    error = result["error"]
                    message = error.get("message") if isinstance(error, dict) else error
                    raise ForwardingError(f"Gateway error: {message}")
                return result.get("result")

        except Exception as e:
            raise ForwardingError(f"Failed to forward to {gateway.name}: {str(e)}") from e

    async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None) -> List[Any]:
        """Forward request to all active gateways.

        Gateways are contacted one after another; individual failures are
        collected and only raised when no gateway succeeded.

        Args:
            db: Database session
            method: RPC method name
            params: Optional method parameters

        Returns:
            List of responses from the gateways that succeeded (empty when there
            are no active gateways)

        Raises:
            ForwardingError: If all forwards fail
        """
        # Get active gateways
        gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

        results: List[Any] = []
        errors: List[str] = []

        for gateway in gateways:
            try:
                results.append(await self._forward_to_gateway(db, gateway.id, method, params))
            except Exception as e:
                errors.append(str(e))

        if errors:
            if not results:
                raise ForwardingError(f"All forwards failed: {'; '.join(errors)}")
            # Partial failure: still return what succeeded, but don't hide the errors.
            logger.warning(f"{len(errors)} of {len(gateways)} gateway forwards failed: {'; '.join(errors)}")

        return results

    async def _find_resource_gateway(self, db: Session, uri: str) -> Optional[DbGateway]:
        """Find gateway hosting a resource.

        Queries each active gateway's ``resources/list`` in turn and returns the
        first one that lists a resource with a matching ``uri``. Gateways that
        fail to answer are logged and skipped.

        Args:
            db: Database session
            uri: Resource URI

        Returns:
            Gateway record or None
        """
        # Get active gateways
        gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

        for gateway in gateways:
            try:
                resources = await self._forward_to_gateway(db, gateway.id, "resources/list")
                # Accept both a bare list and an MCP-style {"resources": [...]} envelope
                if isinstance(resources, dict):
                    resources = resources.get("resources", [])
                for resource in resources or []:
                    if isinstance(resource, dict) and resource.get("uri") == uri:
                        return gateway
            except Exception as e:
                logger.error(f"Failed to check gateway {gateway.name} for resource {uri}: {str(e)}")

        return None

    def _check_rate_limit(self, gateway_url: str) -> bool:
        """Check if gateway request is within rate limits.

        Uses a sliding 60-second window per gateway URL; an allowed request is
        recorded in the window as a side effect.

        Args:
            gateway_url: Gateway URL

        Returns:
            True if request allowed
        """
        now = datetime.utcnow()

        # Drop timestamps that have left the 60-second window
        recent = [t for t in self._request_history.get(gateway_url, []) if (now - t).total_seconds() < 60]
        self._request_history[gateway_url] = recent

        if len(recent) >= settings.tool_rate_limit:
            return False

        # Record request
        recent.append(now)
        return True

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        The ``Authorization`` header uses the HTTP Basic scheme, whose
        ``user:password`` credentials must be base64-encoded (RFC 7617).
        ``X-API-Key`` keeps the raw ``user:password`` value it always carried.

        Returns:
            dict: Authorization header dict
        """
        credentials = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        encoded = base64.b64encode(credentials.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {encoded}", "X-API-Key": credentials}