Coverage for mcpgateway/federation/forward.py: 53%

122 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 15:23 +0100

# -*- coding: utf-8 -*-
"""Federation Request Forwarding.

Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

This module implements request forwarding for federated MCP Gateways.
It handles:
- Request routing to appropriate gateways
- Response aggregation
- Error handling and retry logic
- Request/response transformation
"""

15 

import asyncio
import base64
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session

from mcpgateway.config import settings
from mcpgateway.db import Gateway as DbGateway
from mcpgateway.db import Tool as DbTool
from mcpgateway.types import ToolResult

29 

# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)

31 

32 

class ForwardingError(Exception):
    """Base class for forwarding-related errors.

    Raised with a human-readable message describing why a request could not
    be forwarded or why the forwarded response could not be used.
    """

35 

36 

class ForwardingService:
    """Service for handling request forwarding across gateways.

    Handles:
    - Request routing
    - Response aggregation
    - Error handling
    - Request transformation

    Requests are sent as JSON-RPC 2.0 POSTs to ``<gateway.url>/rpc`` through a
    single shared ``httpx.AsyncClient``. A per-URL sliding window of request
    timestamps is kept in memory for rate limiting.
    """

    def __init__(self):
        """Initialize forwarding service."""
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # Track active requests
        self._active_requests: Dict[str, asyncio.Task] = {}

        # Request history for rate limiting
        self._request_history: Dict[str, List[datetime]] = {}

        # Cache gateway information
        self._gateway_tools: Dict[int, Set[str]] = {}

    async def start(self) -> None:
        """Start forwarding service."""
        logger.info("Request forwarding service started")

    async def stop(self) -> None:
        """Stop forwarding service.

        Cancels every tracked request, waits for each to finish, and closes
        the shared HTTP client. The client is closed even if a tracked task
        ends with an unexpected error.
        """
        try:
            # Iterate over a snapshot: a task may unregister itself from the
            # dict while we are awaiting it, which would break live iteration.
            for request_id, task in list(self._active_requests.items()):
                logger.info("Cancelling request %s", request_id)
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception as exc:  # task failed on its own; keep shutting down
                    logger.warning("Request %s ended with error during shutdown: %s", request_id, exc)
            self._active_requests.clear()
        finally:
            await self._http_client.aclose()
        logger.info("Request forwarding service stopped")

    async def forward_request(
        self,
        db: Session,
        method: str,
        params: Optional[Dict[str, Any]] = None,
        target_gateway_id: Optional[int] = None,
    ) -> Any:
        """Forward a request to gateway(s).

        Args:
            db: Database session
            method: RPC method name
            params: Optional method parameters
            target_gateway_id: Optional specific gateway ID. When ``None`` the
                request is sent to every active gateway.

        Returns:
            The single gateway result when ``target_gateway_id`` is given,
            otherwise a list with one result per gateway that answered.

        Raises:
            ForwardingError: If forwarding fails
        """
        try:
            # Compare against None explicitly so a falsy id (e.g. 0) is not
            # silently turned into a broadcast to every gateway.
            if target_gateway_id is not None:
                return await self._forward_to_gateway(db, target_gateway_id, method, params)

            return await self._forward_to_all(db, method, params)

        except Exception as e:
            raise ForwardingError(f"Forward request failed: {e}") from e

    async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
        """Forward a tool invocation request.

        Args:
            db: Database session
            tool_name: Tool to invoke
            arguments: Tool arguments

        Returns:
            Tool result

        Raises:
            ForwardingError: If the tool is unknown, is not attached to a
                gateway, or forwarding fails
        """
        try:
            tool = db.execute(select(DbTool).where(DbTool.name == tool_name).where(DbTool.is_active)).scalar_one_or_none()

            if not tool:
                raise ForwardingError(f"Tool not found: {tool_name}")

            if not tool.gateway_id:
                raise ForwardingError(f"Tool {tool_name} is not federated")

            result = await self._forward_to_gateway(
                db,
                tool.gateway_id,
                "tools/invoke",
                {"name": tool_name, "arguments": arguments},
            )

            # Prefer the snake_case key; fall back to the camelCase spelling
            # used on the wire by the MCP specification.
            if "is_error" in result:
                is_error = result["is_error"]
            else:
                is_error = result.get("isError", False)

            return ToolResult(
                content=result.get("content", []),
                is_error=is_error,
            )

        except Exception as e:
            raise ForwardingError(f"Failed to forward tool request: {e}") from e

    async def forward_resource_request(self, db: Session, uri: str) -> Tuple[Union[str, bytes], str]:
        """Forward a resource read request.

        Args:
            db: Database session
            uri: Resource URI

        Returns:
            Tuple of (content, mime_type)

        Raises:
            ForwardingError: If no gateway lists the resource, the response
                has neither ``text`` nor ``blob``, or forwarding fails
        """
        try:
            gateway = await self._find_resource_gateway(db, uri)
            if not gateway:
                raise ForwardingError(f"No gateway found for resource: {uri}")

            result = await self._forward_to_gateway(db, gateway.id, "resources/read", {"uri": uri})
            return self._parse_resource_result(result)

        except Exception as e:
            raise ForwardingError(f"Failed to forward resource request: {e}") from e

    async def _forward_to_gateway(
        self,
        db: Session,
        gateway_id: int,
        method: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Forward request to a specific gateway.

        Timeouts are retried with a linearly growing delay (1s, 2s, ...).
        At least one attempt is always made.

        Args:
            db: Database session
            gateway_id: Gateway to forward to
            method: RPC method name
            params: Optional method parameters

        Returns:
            The ``result`` member of the gateway's JSON-RPC response

        Raises:
            ForwardingError: If the gateway is missing or inactive, the rate
                limit is exceeded, the gateway returns an error, or the
                request fails (including a timeout on the final attempt,
                which is wrapped rather than propagated as-is)
        """
        gateway = db.get(DbGateway, gateway_id)
        if not gateway or not gateway.is_active:
            raise ForwardingError(f"Gateway not found: {gateway_id}")

        if not self._check_rate_limit(gateway.url):
            raise ForwardingError("Rate limit exceeded")

        try:
            request = self._build_rpc_request(method, params)

            # Guard against a zero/negative retry setting, which would skip
            # the loop entirely and silently return None.
            attempts = max(1, settings.max_tool_retries)

            for attempt in range(attempts):
                try:
                    response = await self._http_client.post(
                        f"{gateway.url}/rpc",
                        json=request,
                        headers=self._get_auth_headers(),
                    )
                    response.raise_for_status()
                    result = response.json()

                    # Update last seen
                    # NOTE(review): not committed here - presumably the caller
                    # owns the session transaction; confirm.
                    gateway.last_seen = datetime.utcnow()

                    if "error" in result:
                        raise ForwardingError(f"Gateway error: {result['error'].get('message')}")
                    return result.get("result")

                except httpx.TimeoutException:
                    if attempt == attempts - 1:
                        raise
                    await asyncio.sleep(1 * (attempt + 1))

        except Exception as e:
            raise ForwardingError(f"Failed to forward to {gateway.name}: {e}") from e

    async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None) -> List[Any]:
        """Forward request to all active gateways.

        Gateways are contacted one after another. Failures of individual
        gateways are tolerated as long as at least one gateway answers.

        Args:
            db: Database session
            method: RPC method name
            params: Optional method parameters

        Returns:
            List of responses (empty when there are no active gateways)

        Raises:
            ForwardingError: If all forwards fail
        """
        gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

        results = []
        errors = []

        for gateway in gateways:
            try:
                results.append(await self._forward_to_gateway(db, gateway.id, method, params))
            except Exception as e:
                errors.append(str(e))

        if errors and not results:
            raise ForwardingError(f"All forwards failed: {'; '.join(errors)}")

        if errors:
            # Partial failure: still return what we have, but leave a trace.
            logger.warning("Forwarding %s failed for %d gateway(s): %s", method, len(errors), "; ".join(errors))

        return results

    async def _find_resource_gateway(self, db: Session, uri: str) -> Optional[DbGateway]:
        """Find gateway hosting a resource.

        Asks each active gateway for its resource list and returns the first
        gateway that lists ``uri``. Gateways that fail to answer are logged
        and skipped.

        Args:
            db: Database session
            uri: Resource URI

        Returns:
            Gateway record or None
        """
        gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

        for gateway in gateways:
            try:
                resources = await self._forward_to_gateway(db, gateway.id, "resources/list")
                # A response without a "result" member yields None; treat it
                # as "no resources" instead of failing on iteration.
                for resource in resources or []:
                    if resource.get("uri") == uri:
                        return gateway
            except Exception as e:
                logger.error(f"Failed to check gateway {gateway.name} for resource {uri}: {str(e)}")
                continue

        return None

    def _check_rate_limit(self, gateway_url: str) -> bool:
        """Check if gateway request is within rate limits.

        Keeps only the timestamps from the last 60 seconds for the URL and,
        when the request is allowed, records the current time.

        Args:
            gateway_url: Gateway URL

        Returns:
            True if request allowed
        """
        now = datetime.utcnow()

        # Drop entries that have aged out of the 60-second window
        recent = [t for t in self._request_history.get(gateway_url, []) if (now - t).total_seconds() < 60]
        self._request_history[gateway_url] = recent

        if len(recent) >= settings.tool_rate_limit:
            return False

        recent.append(now)
        return True

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        Returns:
            dict: ``Authorization`` (HTTP Basic, base64-encoded credentials)
            and ``X-API-Key`` (plain ``user:password``) headers
        """
        user = settings.basic_auth_user
        password = settings.basic_auth_password
        return {
            "Authorization": f"Basic {self._basic_auth_token(user, password)}",
            "X-API-Key": f"{user}:{password}",
        }

    @staticmethod
    def _basic_auth_token(user: str, password: str) -> str:
        """Return the base64 credential token for an HTTP Basic header.

        RFC 7617 requires ``user:password`` to be base64-encoded; sending the
        raw pair is rejected by standards-compliant servers.

        Args:
            user: User name
            password: Password

        Returns:
            Base64 text of ``user:password`` (UTF-8 encoded)
        """
        return base64.b64encode(f"{user}:{password}".encode("utf-8")).decode("ascii")

    @staticmethod
    def _build_rpc_request(method: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build the JSON-RPC 2.0 request body.

        Args:
            method: RPC method name
            params: Optional method parameters; omitted from the body when empty

        Returns:
            Request dictionary ready to be sent as JSON
        """
        request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
        if params:
            request["params"] = params
        return request

    @staticmethod
    def _parse_resource_result(result: Dict[str, Any]) -> Tuple[Union[str, bytes], str]:
        """Extract (content, mime_type) from a resource read result.

        The mime type is read from ``mime_type`` when present, otherwise from
        the camelCase ``mimeType`` key, otherwise a default for the content kind.

        Args:
            result: Result dictionary returned by the gateway

        Returns:
            Tuple of (content, mime_type)

        Raises:
            ForwardingError: If the result has neither ``text`` nor ``blob``
        """
        for content_key, default_mime in (("text", "text/plain"), ("blob", "application/octet-stream")):
            if content_key in result:
                if "mime_type" in result:
                    mime_type = result["mime_type"]
                else:
                    mime_type = result.get("mimeType", default_mime)
                return result[content_key], mime_type

        raise ForwardingError("Invalid resource response format")