Coverage for mcpgateway/federation/manager.py: 44%
187 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:31 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:31 +0100
1# -*- coding: utf-8 -*-
2"""Federation Manager.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module provides the core federation management system for the MCP Gateway.
9It coordinates:
10- Gateway discovery and registration
11- Capability synchronization
12- Request forwarding
13- Health monitoring
15The federation manager serves as the central point for all federation-related
16operations, coordinating with discovery, sync and forwarding components.
17"""
import asyncio
import base64
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set

import httpx
from sqlalchemy import delete, select
from sqlalchemy.orm import Session

from mcpgateway.config import settings
from mcpgateway.db import Gateway as DbGateway
from mcpgateway.db import Tool as DbTool
from mcpgateway.federation.discovery import DiscoveryService
from mcpgateway.types import (
    ClientCapabilities,
    Implementation,
    InitializeRequest,
    InitializeResult,
    Prompt,
    Resource,
    ServerCapabilities,
    Tool,
)
# Module logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)

# MCP protocol version sent in initialize requests and required in peer replies.
# Defaults to "2025-03-26"; override with the PROTOCOL_VERSION environment variable.
PROTOCOL_VERSION = os.environ.get("PROTOCOL_VERSION", "2025-03-26")
class FederationError(Exception):
    """Base class for federation-related errors.

    Raised by the federation manager when registering, unregistering,
    querying, initializing, health-checking or forwarding to a peer gateway
    fails; the message describes which operation failed.
    """
class FederationManager:
    """Manages federation across MCP gateways.

    Coordinates:
    - Peer discovery and registration
    - Capability synchronization
    - Request forwarding
    - Health monitoring
    """

    # A gateway whose health check fails is deactivated once it has not been
    # seen for longer than this.
    _STALE_AFTER = timedelta(minutes=5)

    def __init__(self):
        """Initialize federation manager.

        Creates the discovery service and one shared ``httpx.AsyncClient``
        (timeout and TLS verification taken from settings). The client is
        reused for every outgoing request until :meth:`stop` closes it.
        """
        self._discovery = DiscoveryService()
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # URLs of gateways currently considered active.
        self._active_gateways: Set[str] = set()

        # Background tasks: created by start(), cancelled by stop().
        self._sync_task: Optional[asyncio.Task] = None
        self._health_task: Optional[asyncio.Task] = None

    async def start(self, db: Session) -> None:
        """Start federation system.

        Does nothing when federation is disabled in settings. Otherwise it
        starts discovery, seeds tracking with the gateways marked active in the
        database, and launches the sync and health background loops.

        Args:
            db: Database session

        Raises:
            Exception: If unable to start federation manager (after calling stop())
        """
        if not settings.federation_enabled:
            logger.info("Federation disabled by configuration")
            return

        try:
            await self._discovery.start()

            # Seed tracking with gateways already active in the database.
            gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
            self._active_gateways.update(gateway.url for gateway in gateways)

            # NOTE(review): both loops share this one Session concurrently. SQLAlchemy
            # Sessions are not designed for concurrent use; consider a session per loop.
            self._sync_task = asyncio.create_task(self._run_sync_loop(db))
            self._health_task = asyncio.create_task(self._run_health_loop(db))

            logger.info("Federation manager started")

        except Exception as e:
            logger.error("Failed to start federation manager: %s", e)
            await self.stop()
            raise

    async def stop(self) -> None:
        """Stop federation system.

        Cancels and awaits the background tasks, stops discovery and closes
        the shared HTTP client.
        """
        await self._cancel_task(self._sync_task)
        self._sync_task = None
        await self._cancel_task(self._health_task)
        self._health_task = None

        await self._discovery.stop()
        await self._http_client.aclose()

        logger.info("Federation manager stopped")

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel ``task`` (if any) and wait for it to finish.

        Args:
            task: Task to cancel, or None
        """
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def register_gateway(self, db: Session, url: str, name: Optional[str] = None) -> DbGateway:
        """Register a new gateway.

        Args:
            db: Database session
            url: Gateway URL
            name: Optional gateway name (defaults to ``Gateway-<n>``)

        Returns:
            Registered gateway record

        Raises:
            FederationError: If registration fails (the session is rolled back)
        """
        try:
            # Fails fast if the peer is unreachable or speaks another protocol version.
            capabilities = await self._initialize_gateway(url)
            gateway_name = name or f"Gateway-{len(self._active_gateways) + 1}"

            gateway = DbGateway(
                name=gateway_name,
                url=url,
                capabilities=capabilities.dict(),
                last_seen=datetime.utcnow(),
            )
            db.add(gateway)
            db.commit()
            db.refresh(gateway)

            self._active_gateways.add(url)
            await self._discovery.add_peer(url, source="manual", name=gateway_name)

            logger.info("Registered gateway: %s (%s)", gateway_name, url)
            return gateway

        except Exception as e:
            db.rollback()
            raise FederationError(f"Failed to register gateway: {e}") from e

    async def unregister_gateway(self, db: Session, gateway_id: int) -> None:
        """Unregister a gateway.

        Marks the gateway inactive (the row is kept), deletes the tools it
        contributed, and removes it from tracking and discovery.

        Args:
            db: Database session
            gateway_id: Gateway ID to unregister

        Raises:
            FederationError: If unregistration fails (the session is rolled back)
        """
        try:
            gateway = db.get(DbGateway, gateway_id)
            if not gateway:
                raise FederationError(f"Gateway not found: {gateway_id}")

            gateway.is_active = False
            gateway.updated_at = datetime.utcnow()

            # A Result object has no delete(); issue a bulk DELETE statement instead.
            db.execute(delete(DbTool).where(DbTool.gateway_id == gateway_id))

            db.commit()

            self._active_gateways.discard(gateway.url)
            await self._discovery.remove_peer(gateway.url)

            logger.info("Unregistered gateway: %s", gateway.name)

        except Exception as e:
            db.rollback()
            raise FederationError(f"Failed to unregister gateway: {e}") from e

    async def get_gateway_tools(self, db: Session, gateway_id: int) -> List[Tool]:
        """Get tools provided by a gateway.

        Args:
            db: Database session
            gateway_id: Gateway ID

        Returns:
            List of gateway tools

        Raises:
            FederationError: If the gateway is unknown/inactive or the tool list cannot be retrieved
        """
        return await self._fetch_gateway_list(db, gateway_id, "tools/list", "tools", Tool)

    async def get_gateway_resources(self, db: Session, gateway_id: int) -> List[Resource]:
        """Get resources provided by a gateway.

        Args:
            db: Database session
            gateway_id: Gateway ID

        Returns:
            List of gateway resources

        Raises:
            FederationError: If the gateway is unknown/inactive or the resource list cannot be retrieved
        """
        return await self._fetch_gateway_list(db, gateway_id, "resources/list", "resources", Resource)

    async def get_gateway_prompts(self, db: Session, gateway_id: int) -> List[Prompt]:
        """Get prompts provided by a gateway.

        Args:
            db: Database session
            gateway_id: Gateway ID

        Returns:
            List of gateway prompts

        Raises:
            FederationError: If the gateway is unknown/inactive or the prompt list cannot be retrieved
        """
        return await self._fetch_gateway_list(db, gateway_id, "prompts/list", "prompts", Prompt)

    async def _fetch_gateway_list(self, db: Session, gateway_id: int, method: str, key: str, model: Any) -> List[Any]:
        """Forward a list-style RPC to an active gateway and parse every item.

        Args:
            db: Database session
            gateway_id: Gateway ID
            method: RPC method name, e.g. ``"tools/list"``
            key: Collection name. Used to unwrap a ``{key: [...]}`` result and in error messages
            model: Model class whose ``parse_obj`` builds each item

        Returns:
            Parsed items

        Raises:
            FederationError: If the gateway is unknown/inactive or the list cannot be retrieved
        """
        gateway = db.get(DbGateway, gateway_id)
        if not gateway or not gateway.is_active:
            raise FederationError(f"Gateway not found: {gateway_id}")

        try:
            result = await self.forward_request(gateway, method)
            return [model.parse_obj(item) for item in self._unwrap_list_result(result, key)]
        except Exception as e:
            raise FederationError(f"Failed to get {key} from {gateway.name}: {e}") from e

    @staticmethod
    def _unwrap_list_result(result: Any, key: str) -> List[Any]:
        """Normalize a list-method RPC result to a plain list.

        Accepts a bare list, an MCP-spec object such as ``{"tools": [...]}``,
        or ``None`` (treated as empty).

        Args:
            result: Raw ``result`` value returned by the gateway
            key: Field holding the items when ``result`` is an object

        Returns:
            The items as a list
        """
        if result is None:
            return []
        if isinstance(result, dict):
            return list(result.get(key) or [])
        return list(result)

    async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Forward a JSON-RPC request to a gateway's ``/rpc`` endpoint.

        Args:
            gateway: Gateway to forward to (its ``last_seen`` is refreshed on an HTTP success)
            method: RPC method name
            params: Optional method parameters

        Returns:
            The ``result`` member of the gateway's JSON-RPC response

        Raises:
            FederationError: If request forwarding fails or the gateway returns an error
        """
        try:
            request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
            if params:
                request["params"] = params

            response = await self._http_client.post(f"{gateway.url}/rpc", json=request, headers=self._get_auth_headers())
            response.raise_for_status()
            result = response.json()

            # Any HTTP-level success counts as "seen", even if the RPC itself errored.
            gateway.last_seen = datetime.utcnow()

            if "error" in result:
                error = result["error"]
                message = error.get("message") if isinstance(error, dict) else error
                raise FederationError(f"Gateway error: {message}")
            return result.get("result")

        except Exception as e:
            raise FederationError(f"Failed to forward request to {gateway.name}: {e}") from e

    async def _run_sync_loop(self, db: Session) -> None:
        """Run periodic gateway synchronization.

        Every ``settings.federation_sync_interval`` seconds, the loop does two things:
        - It registers newly discovered peers that are not tracked yet.
        - It refreshes the capabilities of every active gateway.

        Per-item failures are logged and skipped. A failure of the whole pass
        rolls the session back.

        Args:
            db: Session object
        """
        while True:
            try:
                for peer in self._discovery.get_discovered_peers():
                    if peer.url in self._active_gateways:
                        continue
                    try:
                        await self.register_gateway(db, peer.url, peer.name)
                    except Exception as e:
                        logger.warning("Failed to register discovered peer %s: %s", peer.url, e)

                gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

                for gateway in gateways:
                    try:
                        capabilities = await self._initialize_gateway(gateway.url)
                        gateway.capabilities = capabilities.dict()
                        gateway.last_seen = datetime.utcnow()
                        # The health loop may have deactivated it while we awaited;
                        # keep the DB flag and in-memory tracking in step.
                        gateway.is_active = True
                        self._active_gateways.add(gateway.url)
                    except Exception as e:
                        logger.warning("Failed to sync gateway %s: %s", gateway.name, e)

                db.commit()

            except Exception as e:
                logger.error("Sync loop error: %s", e)
                db.rollback()

            await asyncio.sleep(settings.federation_sync_interval)

    async def _run_health_loop(self, db: Session) -> None:
        """Run periodic gateway health checks.

        Every ``settings.health_check_interval`` seconds, the loop probes each
        active gateway. A successful probe refreshes ``last_seen``. A failed
        probe deactivates the gateway when it has not been seen within
        ``_STALE_AFTER``, or has never been seen at all.

        Args:
            db: Session object
        """
        while True:
            try:
                gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

                for gateway in gateways:
                    try:
                        await self._check_gateway_health(gateway)
                        # A successful probe means the gateway is alive right now.
                        gateway.last_seen = datetime.utcnow()
                    except Exception as e:
                        logger.warning("Health check failed for %s: %s", gateway.name, e)
                        stale = gateway.last_seen is None or datetime.utcnow() - gateway.last_seen > self._STALE_AFTER
                        if stale:
                            gateway.is_active = False
                            self._active_gateways.discard(gateway.url)

                db.commit()

            except Exception as e:
                logger.error("Health check error: %s", e)
                db.rollback()

            await asyncio.sleep(settings.health_check_interval)

    async def _initialize_gateway(self, url: str) -> ServerCapabilities:
        """Initialize connection to a gateway.

        Posts an initialize request to ``{url}/initialize`` and rejects peers
        that answer with a protocol version other than ``PROTOCOL_VERSION``.

        Args:
            url: Gateway URL

        Returns:
            Gateway capabilities

        Raises:
            FederationError: If initialization fails
        """
        try:
            request = InitializeRequest(
                protocol_version=PROTOCOL_VERSION,
                capabilities=ClientCapabilities(roots={"listChanged": True}, sampling={}),
                client_info=Implementation(name=settings.app_name, version="1.0.0"),
            )

            response = await self._http_client.post(
                f"{url}/initialize",
                json=request.dict(),
                headers=self._get_auth_headers(),
            )
            response.raise_for_status()
            result = InitializeResult.parse_obj(response.json())

            if result.protocol_version != PROTOCOL_VERSION:
                raise FederationError(f"Unsupported protocol version: {result.protocol_version}")

            return result.capabilities

        except Exception as e:
            raise FederationError(f"Failed to initialize gateway: {e}") from e

    async def _check_gateway_health(self, gateway: DbGateway) -> bool:
        """Check if a gateway is healthy by re-running the initialize handshake.

        Args:
            gateway: Gateway to check

        Returns:
            True if gateway is healthy

        Raises:
            FederationError: If health check fails
        """
        try:
            await self._initialize_gateway(gateway.url)
            return True
        except Exception as e:
            raise FederationError(f"Gateway health check failed: {e}") from e

    def _get_auth_headers(self) -> Dict[str, str]:
        """Get headers for gateway authentication.

        ``Authorization`` carries standard HTTP Basic credentials: base64 of
        ``user:password``, as RFC 7617 requires. ``X-API-Key`` carries the raw
        ``user:password`` pair.

        Returns:
            dict: Headers to be used in request
        """
        credentials = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        encoded = base64.b64encode(credentials.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {encoded}", "X-API-Key": credentials}