Coverage for mcpgateway/federation/manager.py: 44%

187 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 15:31 +0100

1# -*- coding: utf-8 -*- 

2"""Federation Manager. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module provides the core federation management system for the MCP Gateway. 

9It coordinates: 

10- Gateway discovery and registration 

11- Capability synchronization 

12- Request forwarding 

13- Health monitoring 

14 

15The federation manager serves as the central point for all federation-related 

16operations, coordinating with discovery, sync and forwarding components. 

17""" 

18 

import asyncio
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set

import httpx
from sqlalchemy import delete, select
from sqlalchemy.orm import Session

from mcpgateway.config import settings
from mcpgateway.db import Gateway as DbGateway
from mcpgateway.db import Tool as DbTool
from mcpgateway.federation.discovery import DiscoveryService
from mcpgateway.types import (
    ClientCapabilities,
    Implementation,
    InitializeRequest,
    InitializeResult,
    Prompt,
    Resource,
    ServerCapabilities,
    Tool,
)

43 

# Logger named after this module.
logger = logging.getLogger(__name__)

# Protocol version sent in initialize requests and required of peers in their
# replies; can be overridden through the PROTOCOL_VERSION environment variable.
PROTOCOL_VERSION = os.environ.get("PROTOCOL_VERSION", "2025-03-26")

47 

48 

class FederationError(Exception):
    """Root exception type raised for any federation failure."""

51 

52 

class FederationManager:
    """Manages federation across MCP gateways.

    Coordinates:
    - Peer discovery and registration
    - Capability synchronization
    - Request forwarding
    - Health monitoring

    The manager owns one persistent ``httpx.AsyncClient`` and two background
    tasks (sync loop and health loop). The tasks are created by :meth:`start`
    and cancelled by :meth:`stop`, which also closes the HTTP client.
    """

    # A gateway that fails a health check is deactivated once it has not been
    # seen for longer than this.
    _STALE_AFTER = timedelta(minutes=5)

    def __init__(self):
        """Initialize federation manager."""
        self._discovery = DiscoveryService()
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # URLs of gateways currently considered active
        self._active_gateways: Set[str] = set()

        # Background tasks (None until start() creates them)
        self._sync_task: Optional[asyncio.Task] = None
        self._health_task: Optional[asyncio.Task] = None

    async def start(self, db: Session) -> None:
        """Start federation system.

        Does nothing (beyond logging) when federation is disabled in settings.
        Otherwise starts discovery, records the URLs of all active gateways
        and launches the sync and health background loops.

        Args:
            db: Database session. The same session object is handed to both
                background loops, which keep using it until :meth:`stop`.

        Raises:
            Exception: If unable to start federation manager
        """
        if not settings.federation_enabled:
            logger.info("Federation disabled by configuration")
            return

        try:
            # Start discovery
            await self._discovery.start()

            # Load existing gateways
            gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
            self._active_gateways.update(gateway.url for gateway in gateways)

            # Start background tasks
            self._sync_task = asyncio.create_task(self._run_sync_loop(db))
            self._health_task = asyncio.create_task(self._run_health_loop(db))

            logger.info("Federation manager started")

        except Exception as e:
            logger.error("Failed to start federation manager: %s", e)
            # NOTE(review): stop() also closes the HTTP client, so the manager
            # cannot be started again after a failed start.
            await self.stop()
            raise

    async def stop(self) -> None:
        """Stop federation system.

        Cancels the background tasks, stops discovery and closes the HTTP
        client. The client is closed even if an earlier teardown step raises.
        """
        # Detach the task handles first so a repeated stop() is a no-op for them.
        sync_task, self._sync_task = self._sync_task, None
        health_task, self._health_task = self._health_task, None

        try:
            # Stop background tasks
            for task in (sync_task, health_task):
                if task is None:
                    continue
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    # Expected outcome of cancelling the loop
                    pass

            # Stop discovery
            await self._discovery.stop()
        finally:
            # Close HTTP client regardless of how the steps above went
            await self._http_client.aclose()

        logger.info("Federation manager stopped")

    async def register_gateway(self, db: Session, url: str, name: Optional[str] = None) -> DbGateway:
        """Register a new gateway.

        Initializes a connection to the gateway, stores a record with its
        capabilities, tracks the URL as active and adds it to discovery.

        Args:
            db: Database session
            url: Gateway URL
            name: Optional gateway name; defaults to ``Gateway-<n>`` where
                ``n`` is the number of tracked gateways plus one

        Returns:
            Registered gateway record

        Raises:
            FederationError: If registration fails
        """
        try:
            # Initialize connection
            capabilities = await self._initialize_gateway(url)
            gateway_name = name or f"Gateway-{len(self._active_gateways) + 1}"

            # Create gateway record
            gateway = DbGateway(
                name=gateway_name,
                url=url,
                capabilities=capabilities.dict(),
                last_seen=datetime.utcnow(),
            )
            db.add(gateway)
            db.commit()
            db.refresh(gateway)

            # Update tracking
            self._active_gateways.add(url)

            # Add to discovery
            await self._discovery.add_peer(url, source="manual", name=gateway_name)

            logger.info("Registered gateway: %s (%s)", gateway_name, url)
            return gateway

        except Exception as e:
            db.rollback()
            raise FederationError(f"Failed to register gateway: {str(e)}") from e

    async def unregister_gateway(self, db: Session, gateway_id: int) -> None:
        """Unregister a gateway.

        Marks the gateway inactive, deletes the tools associated with it,
        stops tracking its URL and removes it from discovery.

        Args:
            db: Database session
            gateway_id: Gateway ID to unregister

        Raises:
            FederationError: If unregistration fails
        """
        try:
            # Find gateway
            gateway = db.get(DbGateway, gateway_id)
            if not gateway:
                raise FederationError(f"Gateway not found: {gateway_id}")

            # Deactivate gateway
            gateway.is_active = False
            gateway.updated_at = datetime.utcnow()

            # Remove associated tools with a DELETE statement; a SELECT result
            # has no delete() method.
            db.execute(delete(DbTool).where(DbTool.gateway_id == gateway_id))

            db.commit()

            # Update tracking
            self._active_gateways.discard(gateway.url)

            # Remove from discovery
            await self._discovery.remove_peer(gateway.url)

            logger.info("Unregistered gateway: %s", gateway.name)

        except Exception as e:
            db.rollback()
            raise FederationError(f"Failed to unregister gateway: {str(e)}") from e

    async def _list_gateway_items(self, db: Session, gateway_id: int, method: str, model: Any, kind: str) -> List[Any]:
        """Fetch a list from an active gateway and parse each entry.

        Args:
            db: Database session
            gateway_id: Gateway ID
            method: RPC method to forward (e.g. ``tools/list``)
            model: Class exposing ``parse_obj`` used to parse each entry
            kind: Plural noun used in the error message (e.g. ``tools``)

        Returns:
            List of parsed entries

        Raises:
            FederationError: If the gateway is missing or inactive, or the
                list cannot be retrieved or parsed
        """
        gateway = db.get(DbGateway, gateway_id)
        if not gateway or not gateway.is_active:
            raise FederationError(f"Gateway not found: {gateway_id}")

        try:
            items = await self.forward_request(gateway, method)
            return [model.parse_obj(item) for item in items]
        except Exception as e:
            raise FederationError(f"Failed to get {kind} from {gateway.name}: {str(e)}") from e

    async def get_gateway_tools(self, db: Session, gateway_id: int) -> List[Tool]:
        """Get tools provided by a gateway.

        Args:
            db: Database session
            gateway_id: Gateway ID

        Returns:
            List of gateway tools

        Raises:
            FederationError: If tool list cannot be retrieved
        """
        return await self._list_gateway_items(db, gateway_id, "tools/list", Tool, "tools")

    async def get_gateway_resources(self, db: Session, gateway_id: int) -> List[Resource]:
        """Get resources provided by a gateway.

        Args:
            db: Database session
            gateway_id: Gateway ID

        Returns:
            List of gateway resources

        Raises:
            FederationError: If resource list cannot be retrieved
        """
        return await self._list_gateway_items(db, gateway_id, "resources/list", Resource, "resources")

    async def get_gateway_prompts(self, db: Session, gateway_id: int) -> List[Prompt]:
        """Get prompts provided by a gateway.

        Args:
            db: Database session
            gateway_id: Gateway ID

        Returns:
            List of gateway prompts

        Raises:
            FederationError: If prompt list cannot be retrieved
        """
        return await self._list_gateway_items(db, gateway_id, "prompts/list", Prompt, "prompts")

    async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Forward a request to a gateway.

        Posts a JSON-RPC 2.0 request to ``<gateway.url>/rpc`` and returns the
        ``result`` member of the reply. ``gateway.last_seen`` is refreshed once
        a reply has been received and decoded.

        Args:
            gateway: Gateway to forward to
            method: RPC method name
            params: Optional method parameters; omitted from the request when
                empty or None

        Returns:
            Gateway response

        Raises:
            FederationError: If request forwarding fails
        """
        try:
            # Build request
            request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
            if params:
                request["params"] = params

            # Send request using the persistent client directly
            response = await self._http_client.post(f"{gateway.url}/rpc", json=request, headers=self._get_auth_headers())
            response.raise_for_status()
            result = response.json()

            # Update last seen
            gateway.last_seen = datetime.utcnow()

            # Handle response
            if "error" in result:
                raise FederationError(f"Gateway error: {result['error'].get('message')}")
            return result.get("result")

        except Exception as e:
            raise FederationError(f"Failed to forward request to {gateway.name}: {str(e)}") from e

    async def _run_sync_loop(self, db: Session) -> None:
        """
        Run periodic gateway synchronization.

        Each pass registers discovered peers that are not yet tracked, then
        re-initializes every active gateway to refresh its capabilities and
        ``last_seen``. Runs until the task is cancelled.

        Args:
            db: Session object
        """
        while True:
            try:
                # Process discovered peers
                for peer in self._discovery.get_discovered_peers():
                    if peer.url in self._active_gateways:
                        continue
                    try:
                        await self.register_gateway(db, peer.url, peer.name)
                    except Exception as e:
                        logger.warning("Failed to register discovered peer %s: %s", peer.url, e)

                # Sync active gateways
                gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

                for gateway in gateways:
                    try:
                        # Update capabilities
                        capabilities = await self._initialize_gateway(gateway.url)
                        gateway.capabilities = capabilities.dict()
                        gateway.last_seen = datetime.utcnow()
                        gateway.is_active = True

                    except Exception as e:
                        logger.warning("Failed to sync gateway %s: %s", gateway.name, e)

                db.commit()

            except Exception as e:
                logger.error("Sync loop error: %s", e)
                db.rollback()

            await asyncio.sleep(settings.federation_sync_interval)

    async def _run_health_loop(self, db: Session) -> None:
        """
        Run periodic gateway health checks.

        Each pass checks every active gateway. A success refreshes
        ``last_seen``; a failure deactivates the gateway if it has no
        ``last_seen`` or was last seen longer ago than ``_STALE_AFTER``.
        Runs until the task is cancelled.

        Args:
            db: Session object
        """
        while True:
            try:
                gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

                for gateway in gateways:
                    try:
                        # Check gateway health
                        await self._check_gateway_health(gateway)
                        # A successful check proves liveness, so record it;
                        # otherwise one later failure could deactivate a
                        # gateway that has been healthy all along.
                        gateway.last_seen = datetime.utcnow()
                    except Exception as e:
                        logger.warning("Health check failed for %s: %s", gateway.name, e)
                        # Mark inactive if not seen recently. A missing
                        # last_seen counts as stale instead of raising and
                        # aborting the whole pass.
                        last_seen = gateway.last_seen
                        if last_seen is None or datetime.utcnow() - last_seen > self._STALE_AFTER:
                            gateway.is_active = False
                            self._active_gateways.discard(gateway.url)

                db.commit()

            except Exception as e:
                logger.error("Health check error: %s", e)
                db.rollback()

            await asyncio.sleep(settings.health_check_interval)

    async def _initialize_gateway(self, url: str) -> ServerCapabilities:
        """Initialize connection to a gateway.

        Posts an initialize request to ``<url>/initialize`` and checks that the
        reply reports the same protocol version as ``PROTOCOL_VERSION``.

        Args:
            url: Gateway URL

        Returns:
            Gateway capabilities

        Raises:
            FederationError: If initialization fails
        """
        try:
            # Build initialize request
            request = InitializeRequest(
                protocol_version=PROTOCOL_VERSION,
                capabilities=ClientCapabilities(roots={"listChanged": True}, sampling={}),
                client_info=Implementation(name=settings.app_name, version="1.0.0"),
            )

            # Send request using the persistent client directly
            response = await self._http_client.post(
                f"{url}/initialize",
                json=request.dict(),
                headers=self._get_auth_headers(),
            )
            response.raise_for_status()
            result = InitializeResult.parse_obj(response.json())

            # Verify protocol version
            if result.protocol_version != PROTOCOL_VERSION:
                raise FederationError(f"Unsupported protocol version: {result.protocol_version}")

            return result.capabilities

        except Exception as e:
            raise FederationError(f"Failed to initialize gateway: {str(e)}") from e

    async def _check_gateway_health(self, gateway: DbGateway) -> bool:
        """Check if a gateway is healthy.

        A gateway counts as healthy when an initialize round-trip succeeds.

        Args:
            gateway: Gateway to check

        Returns:
            True if gateway is healthy

        Raises:
            FederationError: If health check fails
        """
        try:
            await self._initialize_gateway(gateway.url)
            return True
        except Exception as e:
            raise FederationError(f"Gateway health check failed: {str(e)}") from e

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        Returns:
            dict: Headers to be used in request
        """
        # NOTE(review): the Basic scheme (RFC 7617) expects base64("user:password");
        # the credentials are sent here unencoded. Left unchanged because
        # altering the wire format could break existing peers - confirm what
        # the receiving side expects.
        api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        return {"Authorization": f"Basic {api_key}", "X-API-Key": api_key}