Coverage for mcpgateway/federation/discovery.py: 44%

176 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 15:23 +0100

1# -*- coding: utf-8 -*- 

2"""Federation Discovery Service. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements automatic peer discovery for MCP Gateways. 

9It supports multiple discovery mechanisms: 

10- DNS-SD service discovery 

11- Static peer lists 

12- Peer exchange protocol 

13- Manual registration 

14""" 

15 

16import asyncio 

17import logging 

18import os 

19import socket 

20from dataclasses import dataclass 

21from datetime import datetime, timedelta 

22from typing import Dict, List, Optional 

23from urllib.parse import urlparse 

24 

25import httpx 

26from zeroconf import ServiceInfo, ServiceStateChange 

27from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf 

28 

29from mcpgateway.config import settings 

30from mcpgateway.types import ServerCapabilities 

31 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# MCP protocol version sent in the "initialize" request and advertised in the
# DNS-SD record; can be overridden through the PROTOCOL_VERSION environment
# variable.
PROTOCOL_VERSION = os.getenv("PROTOCOL_VERSION", "2025-03-26")

35 

36 

@dataclass
class DiscoveredPeer:
    """Record describing one peer gateway known to the discovery service.

    Attributes:
        url: Base URL of the peer gateway.
        name: Human-readable gateway name, when one is known.
        protocol_version: Protocol version recorded for the peer.
        capabilities: Capabilities reported by the peer, if retrieved.
        discovered_at: Timestamp at which the peer was first recorded.
        last_seen: Timestamp of the most recent successful contact.
        source: How the peer was found (the discovery service in this module
            uses "static", "dns-sd" and "exchange").
    """

    # Identity
    url: str
    name: Optional[str]

    # Protocol details
    protocol_version: Optional[str]
    capabilities: Optional[ServerCapabilities]

    # Bookkeeping
    discovered_at: datetime
    last_seen: datetime
    source: str

48 

49 

class LocalDiscoveryService:
    """Base class that describes this gateway as a DNS-SD service.

    Builds the ``ServiceInfo`` record (service type ``_mcp._tcp.local.``)
    from the application settings and the host's local IPv4 addresses.
    """

    def __init__(self):
        """Initialize local discovery service."""
        # Service info for local discovery
        self._service_type = "_mcp._tcp.local."
        self._service_info = ServiceInfo(
            self._service_type,
            f"{settings.app_name}.{self._service_type}",
            # inet_aton only understands IPv4 dotted-quad strings, which is
            # why _get_local_addresses() filters its result accordingly.
            addresses=[socket.inet_aton(addr) for addr in self._get_local_addresses()],
            port=settings.port,
            properties={
                "name": settings.app_name,
                "version": "1.0.0",
                "protocol": PROTOCOL_VERSION,
            },
        )

    @staticmethod
    def _is_ipv4(addr: str) -> bool:
        """Return True when *addr* can be packed by ``socket.inet_aton``."""
        try:
            socket.inet_aton(addr)
        except OSError:
            return False
        return True

    def _get_local_addresses(self) -> List[str]:
        """Get list of local network addresses.

        Resolves the host name and keeps every unique, non-loopback IPv4
        address. IPv6 results are dropped because the caller packs the
        addresses with ``socket.inet_aton``, which raises on them.

        Returns:
            List of IPv4 addresses; ``["127.0.0.1"]`` when none is found or
            the lookup fails.
        """
        addresses: List[str] = []
        try:
            # getaddrinfo yields one entry per (family, socktype) pair, so the
            # same address usually shows up several times.
            for entry in socket.getaddrinfo(socket.gethostname(), None):
                addr = entry[4][0]
                # Skip localhost
                if addr.startswith("127."):
                    continue
                if addr in addresses or not self._is_ipv4(addr):
                    continue
                addresses.append(addr)
        except Exception as e:
            logger.warning(f"Failed to get local addresses: {e}")

        # Fall back to localhost
        return addresses or ["127.0.0.1"]

89 

90 

class DiscoveryService(LocalDiscoveryService):
    """Service for automatic gateway discovery.

    Supports multiple discovery mechanisms:
    - DNS-SD for local network discovery
    - Static peer lists from configuration
    - Peer exchange with known gateways
    - Manual registration via API
    """

    # Timing of the background loops.
    CLEANUP_INTERVAL_SECONDS = 60
    REFRESH_INTERVAL_SECONDS = 300  # 5 minutes
    # A peer not seen for longer than this is dropped by the cleanup loop.
    STALE_PEER_TIMEOUT = timedelta(minutes=10)

    def __init__(self):
        """Initialize discovery service."""
        super().__init__()

        self._zeroconf: Optional[AsyncZeroconf] = None
        self._browser: Optional[AsyncServiceBrowser] = None
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # Track discovered peers, keyed by URL
        self._discovered_peers: Dict[str, DiscoveredPeer] = {}

        # Background tasks (created in start())
        self._cleanup_task: Optional[asyncio.Task] = None
        self._refresh_task: Optional[asyncio.Task] = None

        # In-flight DNS-SD handler tasks. A strong reference is kept so the
        # tasks are not garbage-collected before they finish.
        self._pending_tasks: set = set()

    async def start(self) -> None:
        """
        Start discovery service.

        Registers this gateway via DNS-SD and starts browsing for peers when
        ``settings.federation_discovery`` is enabled, launches the cleanup
        and refresh loops, then adds the statically configured peers.

        Raises:
            Exception: If unable to start discovery service
        """
        try:
            # Initialize DNS-SD
            if settings.federation_discovery:
                self._zeroconf = AsyncZeroconf()
                await self._zeroconf.async_register_service(self._service_info)
                self._browser = AsyncServiceBrowser(
                    self._zeroconf.zeroconf,
                    self._service_type,
                    # Browser handlers are invoked synchronously, so the
                    # coroutine handler has to go through a sync dispatcher.
                    handlers=[self._dispatch_service_state_change],
                )

            # Start background tasks
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            self._refresh_task = asyncio.create_task(self._refresh_loop())

            # Load static peers
            for peer_url in settings.federation_peers:
                await self.add_peer(peer_url, source="static")

            logger.info("Discovery service started")

        except Exception as e:
            logger.error(f"Failed to start discovery service: {e}")
            await self.stop()
            raise

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* (if any) and wait until the cancellation completes."""
        if not task:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def stop(self) -> None:
        """Stop discovery service.

        Cancels the background tasks, shuts down DNS-SD and closes the HTTP
        client. Each teardown step runs even if an earlier one raises, so a
        failure while unregistering does not leak the zeroconf instance or
        the HTTP client.
        """
        # Cancel background tasks
        await self._cancel_task(self._cleanup_task)
        self._cleanup_task = None
        await self._cancel_task(self._refresh_task)
        self._refresh_task = None

        # Cancel DNS-SD handler tasks that are still running; their outcome
        # is irrelevant at shutdown, so exceptions are collected, not raised.
        pending = list(self._pending_tasks)
        self._pending_tasks.clear()
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

        try:
            # Stop DNS-SD
            if self._browser:
                browser, self._browser = self._browser, None
                await browser.async_cancel()
        finally:
            try:
                if self._zeroconf:
                    zc, self._zeroconf = self._zeroconf, None
                    try:
                        await zc.async_unregister_service(self._service_info)
                    finally:
                        await zc.async_close()
            finally:
                # Close HTTP client
                await self._http_client.aclose()

        logger.info("Discovery service stopped")

    async def add_peer(self, url: str, source: str, name: Optional[str] = None) -> bool:
        """Add a new peer gateway.

        The peer is contacted first; it is only recorded when it answers the
        initialize request with a matching protocol version.

        Args:
            url: Gateway URL
            source: Discovery source
            name: Optional gateway name

        Returns:
            True if peer was added; False when the URL is invalid, the peer
            is already known (its ``last_seen`` is refreshed) or it could not
            be contacted.
        """
        # Validate URL
        try:
            parsed = urlparse(url)
            if not parsed.scheme or not parsed.netloc:
                logger.warning(f"Invalid peer URL: {url}")
                return False
        except Exception:
            logger.warning(f"Failed to parse peer URL: {url}")
            return False

        # Skip if already known
        known = self._discovered_peers.get(url)
        if known is not None:
            known.last_seen = datetime.utcnow()
            return False

        try:
            # Try to get gateway info
            capabilities = await self._get_gateway_info(url)

            # Add to discovered peers; one timestamp keeps both fields equal
            now = datetime.utcnow()
            self._discovered_peers[url] = DiscoveredPeer(
                url=url,
                name=name,
                protocol_version=PROTOCOL_VERSION,
                capabilities=capabilities,
                discovered_at=now,
                last_seen=now,
                source=source,
            )

            logger.info(f"Added peer gateway: {url} (via {source})")
            return True

        except Exception as e:
            logger.warning(f"Failed to add peer {url}: {e}")
            return False

    def get_discovered_peers(self) -> List[DiscoveredPeer]:
        """Get list of discovered peers.

        Returns:
            List of discovered peer information
        """
        return list(self._discovered_peers.values())

    async def refresh_peer(self, url: str) -> bool:
        """Refresh peer gateway information.

        Args:
            url: Gateway URL to refresh

        Returns:
            True if refresh succeeded; False when the peer is unknown or
            could not be contacted.
        """
        peer = self._discovered_peers.get(url)
        if peer is None:
            return False

        try:
            capabilities = await self._get_gateway_info(url)
        except Exception as e:
            logger.warning(f"Failed to refresh peer {url}: {e}")
            return False

        # The peer may have been removed while the request was in flight; in
        # that case only the detached record is updated, without re-adding it.
        peer.capabilities = capabilities
        peer.last_seen = datetime.utcnow()
        return True

    async def remove_peer(self, url: str) -> None:
        """Remove a peer gateway.

        Unknown URLs are ignored.

        Args:
            url: Gateway URL to remove
        """
        self._discovered_peers.pop(url, None)

    def _dispatch_service_state_change(self, zeroconf, service_type, name, state_change) -> None:
        """Synchronous browser callback that schedules the async handler.

        The zeroconf service browser calls its handlers as plain functions
        with these keyword arguments; registering the coroutine function
        directly would only create a coroutine object that is never awaited.

        Args:
            zeroconf: Zeroconf instance
            service_type: Service type
            name: Service name
            state_change: Type of state change
        """
        task = asyncio.ensure_future(self._on_service_state_change(zeroconf, service_type, name, state_change))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    async def _on_service_state_change(
        self,
        zeroconf: AsyncZeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Handle DNS-SD service changes.

        Only newly added services are processed: the first advertised
        address and the port form the peer URL, which is passed to
        ``add_peer`` with source ``"dns-sd"``.

        Args:
            zeroconf: Zeroconf instance
            service_type: Service type
            name: Service name
            state_change: Type of state change
        """
        if state_change is not ServiceStateChange.Added:
            return

        info = await zeroconf.async_get_service_info(service_type, name)
        if not info:
            return

        try:
            # Extract gateway info
            addresses = [socket.inet_ntoa(addr) for addr in info.addresses]
            if not addresses:
                return

            url = f"http://{addresses[0]}:{info.port}"
            # The property may be missing or present without a value (None)
            peer_name = (info.properties.get(b"name") or b"").decode()

            # Add peer
            await self.add_peer(url, source="dns-sd", name=peer_name)

        except Exception as e:
            logger.warning(f"Failed to process discovered service {name}: {e}")

    async def _cleanup_loop(self) -> None:
        """Periodically clean up stale peers.

        Runs until cancelled; errors in a pass are logged and the loop
        continues.
        """
        while True:
            try:
                cutoff = datetime.utcnow() - self.STALE_PEER_TIMEOUT
                stale_urls = [url for url, peer in self._discovered_peers.items() if peer.last_seen < cutoff]
                for url in stale_urls:
                    await self.remove_peer(url)
                    logger.info(f"Removed stale peer: {url}")

            except Exception as e:
                logger.error(f"Peer cleanup error: {e}")

            await asyncio.sleep(self.CLEANUP_INTERVAL_SECONDS)

    async def _refresh_loop(self) -> None:
        """Periodically refresh peer information.

        Runs until cancelled; errors in a pass are logged and the loop
        continues.
        """
        while True:
            try:
                # Refresh all peers (snapshot: the dict may change meanwhile)
                for url in list(self._discovered_peers):
                    await self.refresh_peer(url)

                # Exchange peers
                await self._exchange_peers()

            except Exception as e:
                logger.error(f"Peer refresh error: {e}")

            await asyncio.sleep(self.REFRESH_INTERVAL_SECONDS)

    async def _get_gateway_info(self, url: str) -> ServerCapabilities:
        """Get gateway capabilities.

        Posts a JSON-RPC ``initialize`` request to ``<url>/initialize``.
        HTTP errors (via ``raise_for_status``) and a missing
        ``capabilities`` key also propagate to the caller.

        Args:
            url: Gateway URL

        Returns:
            Gateway capabilities

        Raises:
            ValueError: If protocol version is unsupported
        """
        # Build initialize request
        request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocol_version": PROTOCOL_VERSION,
                "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
                "client_info": {"name": settings.app_name, "version": "1.0.0"},
            },
        }

        # Send request using the persistent HTTP client directly. A trailing
        # slash on the peer URL is stripped to avoid a "//initialize" path.
        endpoint = f"{url.rstrip('/')}/initialize"
        response = await self._http_client.post(endpoint, json=request, headers=self._get_auth_headers())
        response.raise_for_status()
        result = response.json()

        # Validate response
        peer_version = result.get("protocol_version")
        if peer_version != PROTOCOL_VERSION:
            raise ValueError(f"Unsupported protocol version: {peer_version}")

        return ServerCapabilities.parse_obj(result["capabilities"])

    async def _exchange_peers(self) -> None:
        """Exchange peer lists with known gateways.

        Fetches ``<url>/peers`` from every known peer and tries to add each
        returned entry that is a dict with a ``"url"`` key. A failure for one
        peer is logged and does not stop the exchange with the others.
        """
        for url in list(self._discovered_peers):
            try:
                # Get peer's peer list using the persistent HTTP client directly
                response = await self._http_client.get(f"{url.rstrip('/')}/peers", headers=self._get_auth_headers())
                response.raise_for_status()
                peers = response.json()

                # Add new peers from the response
                for peer in peers:
                    if isinstance(peer, dict) and "url" in peer:
                        await self.add_peer(peer["url"], source="exchange", name=peer.get("name"))

            except Exception as e:
                logger.warning(f"Failed to exchange peers with {url}: {e}")

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        The ``Authorization`` header carries the base64-encoded
        ``user:password`` pair as required by the HTTP Basic scheme;
        ``X-API-Key`` keeps the plain ``user:password`` value.

        Returns:
            dict: Authorization header dict
        """
        # Function-scope import so this block does not depend on the module's import list
        import base64

        api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        token = base64.b64encode(api_key.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {token}", "X-API-Key": api_key}
390 return {"Authorization": f"Basic {api_key}", "X-API-Key": api_key}