Coverage for mcpgateway/federation/discovery.py: 44%
176 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:23 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:23 +0100
1# -*- coding: utf-8 -*-
2"""Federation Discovery Service.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements automatic peer discovery for MCP Gateways.
9It supports multiple discovery mechanisms:
10- DNS-SD service discovery
11- Static peer lists
12- Peer exchange protocol
13- Manual registration
14"""
16import asyncio
17import logging
18import os
19import socket
20from dataclasses import dataclass
21from datetime import datetime, timedelta
22from typing import Dict, List, Optional
23from urllib.parse import urlparse
25import httpx
26from zeroconf import ServiceInfo, ServiceStateChange
27from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf
29from mcpgateway.config import settings
30from mcpgateway.types import ServerCapabilities
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# MCP protocol version advertised and required by this gateway; can be
# overridden through the PROTOCOL_VERSION environment variable.
PROTOCOL_VERSION = os.environ.get("PROTOCOL_VERSION", "2025-03-26")
@dataclass()
class DiscoveredPeer:
    """Record describing one peer gateway found by any discovery mechanism."""

    url: str  # base URL of the peer gateway; used as its unique key
    name: Optional[str]  # display name reported for the peer, if any
    protocol_version: Optional[str]  # MCP protocol version the peer speaks
    capabilities: Optional[ServerCapabilities]  # capabilities from the peer's initialize response
    discovered_at: datetime  # when the peer was first recorded (set via datetime.utcnow())
    last_seen: datetime  # most recent time the peer was seen or refreshed (datetime.utcnow())
    source: str  # how it was found, e.g. "static", "dns-sd", "exchange"
class LocalDiscoveryService:
    """Base class that builds this gateway's DNS-SD (mDNS) advertisement.

    Prepares a ``ServiceInfo`` record for service type ``_mcp._tcp.local.``
    carrying the app name, port and protocol version. Subclasses register it.
    """

    def __init__(self):
        """Initialize local discovery service.

        Builds the ``ServiceInfo`` record advertising this gateway on every
        local IPv4 address returned by ``_get_local_addresses``.
        """
        # Service info for local discovery
        self._service_type = "_mcp._tcp.local."
        self._service_info = ServiceInfo(
            self._service_type,
            f"{settings.app_name}.{self._service_type}",
            # inet_aton accepts only IPv4; _get_local_addresses returns IPv4 only.
            addresses=[socket.inet_aton(addr) for addr in self._get_local_addresses()],
            port=settings.port,
            properties={
                "name": settings.app_name,
                "version": "1.0.0",
                "protocol": PROTOCOL_VERSION,
            },
        )

    def _get_local_addresses(self) -> List[str]:
        """Get list of local IPv4 network addresses.

        The host name is resolved for IPv4 only, because the caller packs each
        address with ``socket.inet_aton``, which rejects IPv6 strings. Without
        this restriction, an IPv6 result would make the constructor raise.
        Loopback (``127.*``) addresses are skipped. Repeated entries are
        dropped while keeping first-seen order; ``getaddrinfo`` may report the
        same address several times, e.g. once per socket type.

        Returns:
            List of IP addresses. Falls back to ``["127.0.0.1"]`` when nothing
            usable is found or resolution fails.
        """
        addresses: List[str] = []
        try:
            # Get all IPv4 addresses bound to this host name
            for *_, sockaddr in socket.getaddrinfo(socket.gethostname(), None, family=socket.AF_INET):
                addr = sockaddr[0]
                # Skip localhost and duplicates
                if addr.startswith("127.") or addr in addresses:
                    continue
                addresses.append(addr)
        except Exception as e:
            logger.warning(f"Failed to get local addresses: {e}")
            # Fall back to localhost
            if "127.0.0.1" not in addresses:
                addresses.append("127.0.0.1")

        return addresses or ["127.0.0.1"]
class DiscoveryService(LocalDiscoveryService):
    """Service for automatic gateway discovery.

    Supports multiple discovery mechanisms:
    - DNS-SD for local network discovery
    - Static peer lists from configuration
    - Peer exchange with known gateways
    - Manual registration via API

    Discovered peers are kept in memory, keyed by URL. While the service is
    started, two background loops run:
    - one removes peers not seen for 10 minutes, checking every 60 seconds;
    - one refreshes every peer and exchanges peer lists every 5 minutes.
    """

    def __init__(self):
        """Initialize discovery service.

        Builds the DNS-SD record (via the base class) and the shared HTTP
        client. Nothing runs until ``start()`` is awaited.
        """
        super().__init__()

        self._zeroconf: Optional[AsyncZeroconf] = None
        self._browser: Optional[AsyncServiceBrowser] = None
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # Track discovered peers, keyed by gateway URL
        self._discovered_peers: Dict[str, DiscoveredPeer] = {}

        # Background tasks, created by start() and cancelled by stop()
        self._cleanup_task: Optional[asyncio.Task] = None
        self._refresh_task: Optional[asyncio.Task] = None

        # Tasks spawned from DNS-SD browser callbacks. Holding strong references
        # keeps them from being garbage-collected mid-flight and lets stop()
        # cancel them.
        self._pending_tasks: set = set()

    async def start(self) -> None:
        """
        Start discovery service.

        If ``settings.federation_discovery`` is enabled, registers this gateway
        via DNS-SD and starts browsing for other gateways. Then starts the
        cleanup and refresh loops and adds every URL in
        ``settings.federation_peers`` as a ``"static"`` peer.

        Raises:
            Exception: If unable to start discovery service. ``stop()`` is
                awaited to release partially created resources before the
                original exception is re-raised.
        """
        try:
            # Initialize DNS-SD
            if settings.federation_discovery:
                self._zeroconf = AsyncZeroconf()
                await self._zeroconf.async_register_service(self._service_info)
                # zeroconf invokes browser handlers as plain synchronous calls, so
                # an ``async def`` handler would never be awaited. Register a sync
                # wrapper that schedules the async handler instead.
                self._browser = AsyncServiceBrowser(
                    self._zeroconf.zeroconf,
                    self._service_type,
                    handlers=[self._schedule_service_state_change],
                )

            # Start background tasks
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            self._refresh_task = asyncio.create_task(self._refresh_loop())

            # Load static peers
            for peer_url in settings.federation_peers:
                await self.add_peer(peer_url, source="static")

            logger.info("Discovery service started")

        except Exception as e:
            logger.error(f"Failed to start discovery service: {e}")
            await self.stop()
            raise

    async def stop(self) -> None:
        """Stop discovery service.

        Cancels the background loops and any in-flight DNS-SD handler tasks,
        stops browsing, unregisters and closes zeroconf, and closes the HTTP
        client. Also used by ``start()`` to clean up after a failed start.
        """
        # Cancel background tasks and pending DNS-SD handlers, then wait for them.
        tasks = [task for task in (self._cleanup_task, self._refresh_task, *self._pending_tasks) if task is not None]
        for task in tasks:
            task.cancel()
        # return_exceptions=True: a handler task that already failed must not
        # abort the rest of the shutdown.
        await asyncio.gather(*tasks, return_exceptions=True)
        self._pending_tasks.clear()
        self._cleanup_task = None
        self._refresh_task = None

        # Stop DNS-SD
        if self._browser:
            await self._browser.async_cancel()
            self._browser = None

        if self._zeroconf:
            try:
                await self._zeroconf.async_unregister_service(self._service_info)
            except Exception as e:
                # Registration may not have completed (e.g. start() failed
                # part-way); the instance is still closed below.
                logger.warning(f"Failed to unregister DNS-SD service: {e}")
            await self._zeroconf.async_close()
            self._zeroconf = None

        # Close HTTP client
        await self._http_client.aclose()

        logger.info("Discovery service stopped")

    async def add_peer(self, url: str, source: str, name: Optional[str] = None) -> bool:
        """Add a new peer gateway.

        The URL must have a scheme and a network location. New peers are
        probed with ``_get_gateway_info`` and recorded only if that succeeds.

        Args:
            url: Gateway URL
            source: Discovery source (e.g. "static", "dns-sd", "exchange")
            name: Optional gateway name

        Returns:
            True if peer was added; False if the URL is invalid, the peer is
            already known, or the probe failed.
        """
        # Validate URL
        try:
            parsed = urlparse(url)
            if not parsed.scheme or not parsed.netloc:
                logger.warning(f"Invalid peer URL: {url}")
                return False
        except Exception:
            logger.warning(f"Failed to parse peer URL: {url}")
            return False

        # Skip if already known. last_seen is still bumped, including when the
        # sighting comes from another peer via exchange.
        if url in self._discovered_peers:
            peer = self._discovered_peers[url]
            peer.last_seen = datetime.utcnow()
            return False

        try:
            # Try to get gateway info (raises if the peer is unreachable or
            # speaks a different protocol version)
            capabilities = await self._get_gateway_info(url)

            # Add to discovered peers. _get_gateway_info has already checked
            # that the peer reports our PROTOCOL_VERSION.
            self._discovered_peers[url] = DiscoveredPeer(
                url=url,
                name=name,
                protocol_version=PROTOCOL_VERSION,
                capabilities=capabilities,
                discovered_at=datetime.utcnow(),
                last_seen=datetime.utcnow(),
                source=source,
            )

            logger.info(f"Added peer gateway: {url} (via {source})")
            return True

        except Exception as e:
            logger.warning(f"Failed to add peer {url}: {e}")
            return False

    def get_discovered_peers(self) -> List[DiscoveredPeer]:
        """Get list of discovered peers.

        Returns:
            List of discovered peer information (a new list; the peer objects
            themselves are the live records)
        """
        return list(self._discovered_peers.values())

    async def refresh_peer(self, url: str) -> bool:
        """Refresh peer gateway information.

        Args:
            url: Gateway URL to refresh

        Returns:
            True if refresh succeeded; False if the peer is unknown or the
            probe failed (in which case last_seen is left unchanged)
        """
        if url not in self._discovered_peers:
            return False

        try:
            capabilities = await self._get_gateway_info(url)
            self._discovered_peers[url].capabilities = capabilities
            self._discovered_peers[url].last_seen = datetime.utcnow()
            return True
        except Exception as e:
            logger.warning(f"Failed to refresh peer {url}: {e}")
            return False

    async def remove_peer(self, url: str) -> None:
        """Remove a peer gateway.

        Unknown URLs are ignored.

        Args:
            url: Gateway URL to remove
        """
        self._discovered_peers.pop(url, None)

    async def _on_service_state_change(
        self,
        zeroconf: AsyncZeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Handle DNS-SD service changes.

        Only ``ServiceStateChange.Added`` is acted on. The service info is
        resolved, and the first advertised address and port are added as an
        ``http://`` peer with source ``"dns-sd"``. This gateway's own
        advertisement is ignored so it never federates with itself.

        Args:
            zeroconf: Zeroconf instance. In practice this is the one passed to
                the browser in ``start()``, i.e. ``AsyncZeroconf.zeroconf``.
            service_type: Service type
            name: Fully qualified service name
            state_change: Type of state change
        """
        if state_change is not ServiceStateChange.Added:
            return

        # The browser also reports the service we registered ourselves; DNS
        # names are case-insensitive.
        if name.lower() == self._service_info.name.lower():
            return

        try:
            info = await zeroconf.async_get_service_info(service_type, name)
            if not info:
                return

            # Extract gateway info (inet_ntoa accepts only 4-byte IPv4 values)
            addresses = [socket.inet_ntoa(addr) for addr in info.addresses]
            if not addresses:
                return
            url = f"http://{addresses[0]}:{info.port}"

            # TXT record values may be None; use a separate variable so the
            # service name stays available for the error log below.
            raw_peer_name = info.properties.get(b"name") or b""
            peer_name = raw_peer_name.decode() or None

            # Add peer
            await self.add_peer(url, source="dns-sd", name=peer_name)

        except Exception as e:
            logger.warning(f"Failed to process discovered service {name}: {e}")

    def _schedule_service_state_change(
        self,
        zeroconf: AsyncZeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Synchronous zeroconf browser handler.

        zeroconf calls handlers with the keyword arguments ``zeroconf``,
        ``service_type``, ``name`` and ``state_change``, so those names must not
        change. The async ``_on_service_state_change`` is scheduled as a task on
        the running loop. The task is tracked until it completes so it is not
        garbage-collected early and can be cancelled by ``stop()``.

        Args:
            zeroconf: Zeroconf instance the browser was created with
            service_type: Service type
            name: Service name
            state_change: Type of state change
        """
        task = asyncio.ensure_future(self._on_service_state_change(zeroconf, service_type, name, state_change))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    async def _cleanup_loop(self) -> None:
        """Periodically clean up stale peers.

        Every 60 seconds, removes peers whose last_seen is more than 10 minutes
        old. Errors are logged and the loop continues; it only ends when the
        task is cancelled.
        """
        while True:
            try:
                now = datetime.utcnow()
                stale_urls = [url for url, peer in self._discovered_peers.items() if now - peer.last_seen > timedelta(minutes=10)]
                for url in stale_urls:
                    await self.remove_peer(url)
                    logger.info(f"Removed stale peer: {url}")

            except Exception as e:
                logger.error(f"Peer cleanup error: {e}")

            await asyncio.sleep(60)

    async def _refresh_loop(self) -> None:
        """Periodically refresh peer information.

        Every 5 minutes, refreshes each known peer and then exchanges peer
        lists. Errors are logged and the loop continues.
        """
        while True:
            try:
                # Refresh all peers (snapshot keys: the dict may change meanwhile)
                for url in list(self._discovered_peers.keys()):
                    await self.refresh_peer(url)

                # Exchange peers
                await self._exchange_peers()

            except Exception as e:
                logger.error(f"Peer refresh error: {e}")

            await asyncio.sleep(300)  # 5 minutes

    async def _get_gateway_info(self, url: str) -> ServerCapabilities:
        """Get gateway capabilities.

        Sends a JSON-RPC ``initialize`` request to ``{url}/initialize``.

        Args:
            url: Gateway URL

        Returns:
            Gateway capabilities

        Raises:
            ValueError: If protocol version is unsupported
            httpx.HTTPStatusError: If the peer answers with an error status
        """
        # Build initialize request
        # NOTE(review): the MCP initialize schema spells these keys in camelCase
        # (protocolVersion, clientInfo); confirm peers accept snake_case.
        request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocol_version": PROTOCOL_VERSION,
                "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
                "client_info": {"name": settings.app_name, "version": "1.0.0"},
            },
        }

        # Send request using the persistent HTTP client directly
        response = await self._http_client.post(f"{url}/initialize", json=request, headers=self._get_auth_headers())
        response.raise_for_status()
        result = response.json()

        # Validate response
        if result.get("protocol_version") != PROTOCOL_VERSION:
            raise ValueError(f"Unsupported protocol version: {result.get('protocol_version')}")

        return ServerCapabilities.parse_obj(result["capabilities"])

    async def _exchange_peers(self) -> None:
        """Exchange peer lists with known gateways.

        Fetches ``{url}/peers`` from every known peer and tries to add each
        entry that is a dict with a ``"url"`` key. Failures are logged per peer.
        """
        for url in list(self._discovered_peers.keys()):
            try:
                # Get peer's peer list using the persistent HTTP client directly
                response = await self._http_client.get(f"{url}/peers", headers=self._get_auth_headers())
                response.raise_for_status()
                peers = response.json()

                # Add new peers from the response
                for peer in peers:
                    if isinstance(peer, dict) and "url" in peer:
                        await self.add_peer(peer["url"], source="exchange", name=peer.get("name"))

            except Exception as e:
                logger.warning(f"Failed to exchange peers with {url}: {e}")

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        HTTP Basic credentials must be sent base64-encoded (RFC 7617). The
        ``X-API-Key`` header carries the same ``user:password`` pair unencoded,
        as before.

        Returns:
            dict: Authorization header dict
        """
        import base64  # local import keeps this change self-contained

        api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        basic_token = base64.b64encode(api_key.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {basic_token}", "X-API-Key": api_key}