Coverage for mcpgateway/services/gateway_service.py: 46%

389 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 15:47 +0100

1# -*- coding: utf-8 -*- 

2"""Gateway Service Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements gateway federation according to the MCP specification. 

9It handles: 

10- Gateway discovery and registration 

11- Request forwarding 

12- Capability aggregation 

13- Health monitoring 

14- Active/inactive gateway management 

15""" 

16 

17import asyncio 

18import logging 

19import uuid 

20from datetime import datetime, timezone 

21from typing import Any, AsyncGenerator, Dict, List, Optional, Set 

22 

23import httpx 

24from filelock import FileLock, Timeout 

25from mcp import ClientSession 

26from mcp.client.sse import sse_client 

27from mcp.client.streamable_http import streamablehttp_client 

28from sqlalchemy import select 

29from sqlalchemy.orm import Session 

30 

31from mcpgateway.config import settings 

32from mcpgateway.db import Gateway as DbGateway 

33from mcpgateway.db import SessionLocal 

34from mcpgateway.db import Tool as DbTool 

35from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, ToolCreate 

36from mcpgateway.services.tool_service import ToolService 

37from mcpgateway.utils.services_auth import decode_auth 

38 

# logging.getLogger("httpx").setLevel(logging.WARNING)  # Disables httpx logs for regular health checks
logger = logging.getLogger(__name__)

# Redis is optional: it is only used to elect a health-check leader across workers.
try:
    import redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    # Log through the module logger: calling logging.info() on the root logger at
    # import time would implicitly run logging.basicConfig() as a side effect.
    logger.info("Redis is not utilized in this environment.")


# Health-monitoring knobs, read once at import time from settings.
# GW_FAILURE_THRESHOLD == -1 disables automatic deactivation of failing gateways.
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval

53 

54 

class GatewayError(Exception):
    """Root of the gateway service's exception hierarchy.

    Catching this handles every gateway-related failure raised by this module.
    """

57 

58 

class GatewayNotFoundError(GatewayError):
    """Signals that no usable gateway matches the requested identifier."""

61 

62 

class GatewayNameConflictError(GatewayError):
    """Signals that the requested gateway name is already taken.

    The existing gateway may be active or inactive. When it is inactive, the
    message also reports that gateway's ID.
    """

    def __init__(self, name: str, is_active: bool = True, gateway_id: Optional[int] = None):
        """Store details of the clashing gateway and build the error message.

        Args:
            name: The gateway name that is already in use
            is_active: Whether the gateway currently holding the name is active
            gateway_id: ID of the gateway currently holding the name, if known
        """
        self.name, self.is_active, self.gateway_id = name, is_active, gateway_id
        inactive_note = "" if is_active else f" (currently inactive, ID: {gateway_id})"
        super().__init__(f"Gateway already exists with name: {name}{inactive_note}")

81 

82 

class GatewayConnectionError(GatewayError):
    """Signals that a gateway could not be reached, initialized or talked to."""

85 

86 

87class GatewayService: 

88 """Service for managing federated gateways. 

89 

90 Handles: 

91 - Gateway registration and health checks 

92 - Request forwarding 

93 - Capability negotiation 

94 - Federation events 

95 - Active/inactive status management 

96 """ 

97 

98 def __init__(self): 

99 """Initialize the gateway service.""" 

100 self._event_subscribers: List[asyncio.Queue] = [] 

101 self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) 

102 self._health_check_interval = GW_HEALTH_CHECK_INTERVAL 

103 self._health_check_task: Optional[asyncio.Task] = None 

104 self._active_gateways: Set[str] = set() # Track active gateway URLs 

105 self._stream_response = None 

106 self._pending_responses = {} 

107 self.tool_service = ToolService() 

108 self._gateway_failure_counts: dict[str, int] = {} 

109 

110 # For health checks, we determine the leader instance. 

111 self.redis_url = settings.redis_url if settings.cache_type == "redis" else None 

112 

113 if self.redis_url and REDIS_AVAILABLE: 113 ↛ 114line 113 didn't jump to line 114 because the condition on line 113 was never true

114 self._redis_client = redis.from_url(self.redis_url) 

115 self._instance_id = str(uuid.uuid4()) # Unique ID for this process 

116 self._leader_key = "gateway_service_leader" 

117 self._leader_ttl = 40 # seconds 

118 elif settings.cache_type != "none": 118 ↛ 124line 118 didn't jump to line 124 because the condition on line 118 was always true

119 # Fallback: File-based lock 

120 self._redis_client = None 

121 self._lock_path = settings.filelock_path 

122 self._file_lock = FileLock(self._lock_path) 

123 else: 

124 self._redis_client = None 

125 

126 async def initialize(self) -> None: 

127 """Initialize the service and start health check if this instance is the leader. 

128 

129 Raises: 

130 ConnectionError: When redis ping fails 

131 """ 

132 logger.info("Initializing gateway service") 

133 

134 if self._redis_client: 

135 # Check if Redis is available 

136 pong = self._redis_client.ping() 

137 if not pong: 

138 raise ConnectionError("Redis ping failed.") 

139 

140 is_leader = self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True) 

141 if is_leader: 

142 logger.info("Acquired Redis leadership. Starting health check task.") 

143 self._health_check_task = asyncio.create_task(self._run_health_checks()) 

144 else: 

145 # Always create the health check task in filelock mode; leader check is handled inside. 

146 self._health_check_task = asyncio.create_task(self._run_health_checks()) 

147 

148 async def shutdown(self) -> None: 

149 """Shutdown the service.""" 

150 if self._health_check_task: 

151 self._health_check_task.cancel() 

152 try: 

153 await self._health_check_task 

154 except asyncio.CancelledError: 

155 pass 

156 

157 await self._http_client.aclose() 

158 self._event_subscribers.clear() 

159 self._active_gateways.clear() 

160 logger.info("Gateway service shutdown complete") 

161 

162 async def register_gateway(self, db: Session, gateway: GatewayCreate) -> GatewayRead: 

163 """Register a new gateway. 

164 

165 Args: 

166 db: Database session 

167 gateway: Gateway creation schema 

168 

169 Returns: 

170 Created gateway information 

171 

172 Raises: 

173 GatewayNameConflictError: If gateway name already exists 

174 []: When ExceptionGroup found 

175 """ 

176 try: 

177 # Check for name conflicts (both active and inactive) 

178 existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none() 

179 

180 if existing_gateway: 

181 raise GatewayNameConflictError( 

182 gateway.name, 

183 is_active=existing_gateway.is_active, 

184 gateway_id=existing_gateway.id, 

185 ) 

186 

187 auth_type = getattr(gateway, "auth_type", None) 

188 auth_value = getattr(gateway, "auth_value", {}) 

189 

190 capabilities, tools = await self._initialize_gateway(str(gateway.url), auth_value, gateway.transport) 

191 

192 all_names = [td.name for td in tools] 

193 

194 existing_tools = db.execute(select(DbTool).where(DbTool.name.in_(all_names))).scalars().all() 

195 existing_tool_names = [tool.name for tool in existing_tools] 

196 

197 tools = [ 

198 DbTool( 

199 name=tool.name, 

200 url=str(gateway.url), 

201 description=tool.description, 

202 integration_type=tool.integration_type, 

203 request_type=tool.request_type, 

204 headers=tool.headers, 

205 input_schema=tool.input_schema, 

206 jsonpath_filter=tool.jsonpath_filter, 

207 auth_type=auth_type, 

208 auth_value=auth_value, 

209 ) 

210 for tool in tools 

211 ] 

212 

213 existing_tools = [tool for tool in tools if tool.name in existing_tool_names] 

214 new_tools = [tool for tool in tools if tool.name not in existing_tool_names] 

215 

216 # Create DB model 

217 db_gateway = DbGateway( 

218 name=gateway.name, 

219 url=str(gateway.url), 

220 description=gateway.description, 

221 transport=gateway.transport, 

222 capabilities=capabilities, 

223 last_seen=datetime.now(timezone.utc), 

224 auth_type=auth_type, 

225 auth_value=auth_value, 

226 tools=new_tools, 

227 # federated_tools=existing_tools + new_tools 

228 ) 

229 

230 # Add to DB 

231 db.add(db_gateway) 

232 db.commit() 

233 db.refresh(db_gateway) 

234 

235 # Update tracking 

236 self._active_gateways.add(db_gateway.url) 

237 

238 # Notify subscribers 

239 await self._notify_gateway_added(db_gateway) 

240 

241 return GatewayRead.model_validate(gateway) 

242 except* GatewayConnectionError as ge: 

243 logger.error("GatewayConnectionError in group: %s", ge.exceptions) 

244 raise ge.exceptions[0] 

245 except* ValueError as ve: 

246 logger.error("ValueErrors in group: %s", ve.exceptions) 

247 raise ve.exceptions[0] 

248 except* RuntimeError as re: 

249 logger.error("RuntimeErrors in group: %s", re.exceptions) 

250 raise re.exceptions[0] 

251 except* BaseException as other: # catches every other sub-exception 

252 logger.error("Other grouped errors: %s", other.exceptions) 

253 raise other.exceptions[0] 

254 

255 async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]: 

256 """List all registered gateways. 

257 

258 Args: 

259 db: Database session 

260 include_inactive: Whether to include inactive gateways 

261 

262 Returns: 

263 List of registered gateways 

264 """ 

265 query = select(DbGateway) 

266 

267 if not include_inactive: 267 ↛ 270line 267 didn't jump to line 270 because the condition on line 267 was always true

268 query = query.where(DbGateway.is_active) 

269 

270 gateways = db.execute(query).scalars().all() 

271 return [GatewayRead.model_validate(g) for g in gateways] 

272 

273 async def update_gateway(self, db: Session, gateway_id: int, gateway_update: GatewayUpdate) -> GatewayRead: 

274 """Update a gateway. 

275 

276 Args: 

277 db: Database session 

278 gateway_id: Gateway ID to update 

279 gateway_update: Updated gateway data 

280 

281 Returns: 

282 Updated gateway information 

283 

284 Raises: 

285 GatewayNotFoundError: If gateway not found 

286 GatewayError: For other update errors 

287 GatewayNameConflictError: If gateway name conflict occurs 

288 """ 

289 try: 

290 # Find gateway 

291 gateway = db.get(DbGateway, gateway_id) 

292 if not gateway: 

293 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

294 

295 if not gateway.is_active: 295 ↛ 296line 295 didn't jump to line 296 because the condition on line 295 was never true

296 raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive") 

297 

298 # Check for name conflicts if name is being changed 

299 if gateway_update.name is not None and gateway_update.name != gateway.name: 299 ↛ 310line 299 didn't jump to line 310 because the condition on line 299 was always true

300 existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none() 

301 

302 if existing_gateway: 

303 raise GatewayNameConflictError( 

304 gateway_update.name, 

305 is_active=existing_gateway.is_active, 

306 gateway_id=existing_gateway.id, 

307 ) 

308 

309 # Update fields if provided 

310 if gateway_update.name is not None: 310 ↛ 312line 310 didn't jump to line 312 because the condition on line 310 was always true

311 gateway.name = gateway_update.name 

312 if gateway_update.url is not None: 312 ↛ 314line 312 didn't jump to line 314 because the condition on line 312 was always true

313 gateway.url = str(gateway_update.url) 

314 if gateway_update.description is not None: 314 ↛ 316line 314 didn't jump to line 316 because the condition on line 314 was always true

315 gateway.description = gateway_update.description 

316 if gateway_update.transport is not None: 316 ↛ 319line 316 didn't jump to line 319 because the condition on line 316 was always true

317 gateway.transport = gateway_update.transport 

318 

319 if getattr(gateway, "auth_type", None) is not None: 319 ↛ 327line 319 didn't jump to line 327 because the condition on line 319 was always true

320 gateway.auth_type = gateway_update.auth_type 

321 

322 # if auth_type is not None and only then check auth_value 

323 if getattr(gateway, "auth_value", {}) != {}: 323 ↛ 324line 323 didn't jump to line 324 because the condition on line 323 was never true

324 gateway.auth_value = gateway_update.auth_value 

325 

326 # Try to reinitialize connection if URL changed 

327 if gateway_update.url is not None: 327 ↛ 339line 327 didn't jump to line 339 because the condition on line 327 was always true

328 try: 

329 capabilities, _ = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport) 

330 gateway.capabilities = capabilities 

331 gateway.last_seen = datetime.utcnow() 

332 

333 # Update tracking with new URL 

334 self._active_gateways.discard(gateway.url) 

335 self._active_gateways.add(gateway.url) 

336 except Exception as e: 

337 logger.warning(f"Failed to initialize updated gateway: {e}") 

338 

339 gateway.updated_at = datetime.utcnow() 

340 db.commit() 

341 db.refresh(gateway) 

342 

343 # Notify subscribers 

344 await self._notify_gateway_updated(gateway) 

345 

346 logger.info(f"Updated gateway: {gateway.name}") 

347 return GatewayRead.model_validate(gateway) 

348 

349 except Exception as e: 

350 db.rollback() 

351 raise GatewayError(f"Failed to update gateway: {str(e)}") 

352 

353 async def get_gateway(self, db: Session, gateway_id: int, include_inactive: bool = False) -> GatewayRead: 

354 """Get a specific gateway by ID. 

355 

356 Args: 

357 db: Database session 

358 gateway_id: Gateway ID 

359 include_inactive: Whether to include inactive gateways 

360 

361 Returns: 

362 Gateway information 

363 

364 Raises: 

365 GatewayNotFoundError: If gateway not found 

366 """ 

367 gateway = db.get(DbGateway, gateway_id) 

368 if not gateway: 

369 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

370 

371 if not gateway.is_active and not include_inactive: 

372 raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive") 

373 

374 return GatewayRead.model_validate(gateway) 

375 

376 async def toggle_gateway_status(self, db: Session, gateway_id: int, activate: bool) -> GatewayRead: 

377 """Toggle gateway active status. 

378 

379 Args: 

380 db: Database session 

381 gateway_id: Gateway ID to toggle 

382 activate: True to activate, False to deactivate 

383 

384 Returns: 

385 Updated gateway information 

386 

387 Raises: 

388 GatewayNotFoundError: If gateway not found 

389 GatewayError: For other errors 

390 """ 

391 try: 

392 gateway = db.get(DbGateway, gateway_id) 

393 if not gateway: 393 ↛ 394line 393 didn't jump to line 394 because the condition on line 393 was never true

394 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

395 

396 # Update status if it's different 

397 if gateway.is_active != activate: 397 ↛ 429line 397 didn't jump to line 429 because the condition on line 397 was always true

398 gateway.is_active = activate 

399 gateway.updated_at = datetime.utcnow() 

400 

401 # Update tracking 

402 if activate: 402 ↛ 403line 402 didn't jump to line 403 because the condition on line 402 was never true

403 self._active_gateways.add(gateway.url) 

404 # Try to initialize if activating 

405 try: 

406 capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport) 

407 gateway.capabilities = capabilities.dict() 

408 gateway.last_seen = datetime.utcnow() 

409 except Exception as e: 

410 logger.warning(f"Failed to initialize reactivated gateway: {e}") 

411 else: 

412 self._active_gateways.discard(gateway.url) 

413 

414 db.commit() 

415 db.refresh(gateway) 

416 

417 tools = db.query(DbTool).filter(DbTool.gateway_id == gateway_id).all() 

418 for tool in tools: 

419 await self.tool_service.toggle_tool_status(db, tool.id, activate) 

420 

421 # Notify subscribers 

422 if activate: 422 ↛ 423line 422 didn't jump to line 423 because the condition on line 422 was never true

423 await self._notify_gateway_activated(gateway) 

424 else: 

425 await self._notify_gateway_deactivated(gateway) 

426 

427 logger.info(f"Gateway {gateway.name} {'activated' if activate else 'deactivated'}") 

428 

429 return GatewayRead.model_validate(gateway) 

430 

431 except Exception as e: 

432 db.rollback() 

433 raise GatewayError(f"Failed to toggle gateway status: {str(e)}") 

434 

435 async def _notify_gateway_updated(self, gateway: DbGateway) -> None: 

436 """ 

437 Notify subscribers of gateway update. 

438 

439 Args: 

440 gateway: Gateway to update 

441 """ 

442 event = { 

443 "type": "gateway_updated", 

444 "data": { 

445 "id": gateway.id, 

446 "name": gateway.name, 

447 "url": gateway.url, 

448 "description": gateway.description, 

449 "is_active": gateway.is_active, 

450 }, 

451 "timestamp": datetime.utcnow().isoformat(), 

452 } 

453 await self._publish_event(event) 

454 

455 async def delete_gateway(self, db: Session, gateway_id: int) -> None: 

456 """Permanently delete a gateway. 

457 

458 Args: 

459 db: Database session 

460 gateway_id: Gateway ID to delete 

461 

462 Raises: 

463 GatewayNotFoundError: If gateway not found 

464 GatewayError: For other deletion errors 

465 """ 

466 try: 

467 # Find gateway 

468 gateway = db.get(DbGateway, gateway_id) 

469 if not gateway: 

470 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

471 

472 # Store gateway info for notification before deletion 

473 gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url} 

474 

475 # Hard delete gateway 

476 db.delete(gateway) 

477 db.commit() 

478 

479 # Update tracking 

480 self._active_gateways.discard(gateway.url) 

481 

482 # Notify subscribers 

483 await self._notify_gateway_deleted(gateway_info) 

484 

485 logger.info(f"Permanently deleted gateway: {gateway.name}") 

486 

487 except Exception as e: 

488 db.rollback() 

489 raise GatewayError(f"Failed to delete gateway: {str(e)}") 

490 

491 async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any: 

492 """Forward a request to a gateway. 

493 

494 Args: 

495 gateway: Gateway to forward to 

496 method: RPC method name 

497 params: Optional method parameters 

498 

499 Returns: 

500 Gateway response 

501 

502 Raises: 

503 GatewayConnectionError: If forwarding fails 

504 GatewayError: If gateway gave an error 

505 """ 

506 if not gateway.is_active: 506 ↛ 507line 506 didn't jump to line 507 because the condition on line 506 was never true

507 raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}") 

508 

509 try: 

510 # Build RPC request 

511 request = {"jsonrpc": "2.0", "id": 1, "method": method} 

512 if params: 

513 request["params"] = params 

514 

515 # Directly use the persistent HTTP client (no async with) 

516 response = await self._http_client.post(f"{gateway.url}/rpc", json=request, headers=self._get_auth_headers()) 

517 response.raise_for_status() 

518 result = response.json() 

519 

520 # Update last seen timestamp 

521 gateway.last_seen = datetime.utcnow() 

522 

523 if "error" in result: 

524 raise GatewayError(f"Gateway error: {result['error'].get('message')}") 

525 return result.get("result") 

526 

527 except Exception as e: 

528 raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}") 

529 

530 async def _handle_gateway_failure(self, gateway: str) -> None: 

531 """ 

532 Tracks and handles gateway failures during health checks. 

533 If the failure count exceeds the threshold, the gateway is deactivated. 

534 

535 Args: 

536 gateway (str): The gateway object that failed its health check. 

537 

538 Returns: 

539 None 

540 """ 

541 if GW_FAILURE_THRESHOLD == -1: 

542 return # Gateway failure action disabled 

543 count = self._gateway_failure_counts.get(gateway.id, 0) + 1 

544 self._gateway_failure_counts[gateway.id] = count 

545 

546 logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).") 

547 

548 if count >= GW_FAILURE_THRESHOLD: 

549 logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...") 

550 with SessionLocal() as db: 

551 await self.toggle_gateway_status(db, gateway.id, False) 

552 self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation 

553 

554 async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool: 

555 """Health check for a list of gateways. 

556 

557 Deactivates gateway if gateway is not healthy. 

558 

559 Args: 

560 gateways (List[DbGateway]): List of gateways to check if healthy 

561 

562 Returns: 

563 bool: True if all active gateways are healthy 

564 """ 

565 # Reuse a single HTTP client for all requests 

566 async with httpx.AsyncClient() as client: 

567 for gateway in gateways: 

568 # Inactive gateways are unhealthy 

569 if not gateway.is_active: 

570 continue 

571 

572 try: 

573 # Ensure auth_value is a dict 

574 auth_data = gateway.auth_value or {} 

575 headers = decode_auth(auth_data) 

576 

577 # Perform the GET and raise on 4xx/5xx 

578 if (gateway.transport).lower() == "sse": 

579 timeout = httpx.Timeout(settings.health_check_timeout) 

580 async with client.stream("GET", gateway.url, headers=headers, timeout=timeout) as response: 

581 # This will raise immediately if status is 4xx/5xx 

582 response.raise_for_status() 

583 elif (gateway.transport).lower() == "streamablehttp": 

584 async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, get_session_id): 

585 async with ClientSession(read_stream, write_stream) as session: 

586 # Initialize the session 

587 response = await session.initialize() 

588 

589 # Mark successful check 

590 gateway.last_seen = datetime.utcnow() 

591 

592 except Exception as e: 

593 await self._handle_gateway_failure(gateway) 

594 

595 # All gateways passed 

596 return True 

597 

598 async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]: 

599 """Aggregate capabilities from all gateways. 

600 

601 Args: 

602 db: Database session 

603 

604 Returns: 

605 Combined capabilities 

606 """ 

607 capabilities = { 

608 "prompts": {"listChanged": True}, 

609 "resources": {"subscribe": True, "listChanged": True}, 

610 "tools": {"listChanged": True}, 

611 "logging": {}, 

612 } 

613 

614 # Get all active gateways 

615 gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all() 

616 

617 # Combine capabilities 

618 for gateway in gateways: 

619 if gateway.capabilities: 

620 for key, value in gateway.capabilities.items(): 

621 if key not in capabilities: 

622 capabilities[key] = value 

623 elif isinstance(value, dict): 

624 capabilities[key].update(value) 

625 

626 return capabilities 

627 

628 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]: 

629 """Subscribe to gateway events. 

630 

631 Yields: 

632 Gateway event messages 

633 """ 

634 queue: asyncio.Queue = asyncio.Queue() 

635 self._event_subscribers.append(queue) 

636 try: 

637 while True: 

638 event = await queue.get() 

639 yield event 

640 finally: 

641 self._event_subscribers.remove(queue) 

642 

643 async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE") -> Any: 

644 """Initialize connection to a gateway and retrieve its capabilities. 

645 

646 Args: 

647 url: Gateway URL 

648 authentication: Optional authentication headers 

649 

650 Returns: 

651 Capabilities dictionary as provided by the gateway. 

652 

653 Raises: 

654 GatewayConnectionError: If initialization fails. 

655 """ 

656 try: 

657 if authentication is None: 

658 authentication = {} 

659 

660 async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[str, str]] = None): 

661 """ 

662 Connect to an MCP server running with SSE transport 

663 

664 Args: 

665 server_url: URL to connect to the server 

666 authentication: Authentication headers for connection to URL 

667 

668 Returns: 

669 list, list: List of capabilities and tools 

670 """ 

671 if authentication is None: 

672 authentication = {} 

673 # Store the context managers so they stay alive 

674 decoded_auth = decode_auth(authentication) 

675 

676 # Use async with for both sse_client and ClientSession 

677 async with sse_client(url=server_url, headers=decoded_auth) as streams: 

678 async with ClientSession(*streams) as session: 

679 # Initialize the session 

680 response = await session.initialize() 

681 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

682 

683 response = await session.list_tools() 

684 tools = response.tools 

685 tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] 

686 tools = [ToolCreate.model_validate(tool) for tool in tools] 

687 

688 return capabilities, tools 

689 

690 async def connect_to_streamablehttp_server(server_url: str, authentication: Optional[Dict[str, str]] = None): 

691 """ 

692 Connect to an MCP server running with Streamable HTTP transport 

693 

694 Args: 

695 server_url: URL to connect to the server 

696 authentication: Authentication headers for connection to URL 

697 

698 Returns: 

699 list, list: List of capabilities and tools 

700 """ 

701 if authentication is None: 

702 authentication = {} 

703 # Store the context managers so they stay alive 

704 decoded_auth = decode_auth(authentication) 

705 

706 # Use async with for both streamablehttp_client and ClientSession 

707 async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, get_session_id): 

708 async with ClientSession(read_stream, write_stream) as session: 

709 # Initialize the session 

710 response = await session.initialize() 

711 # if get_session_id: 

712 # session_id = get_session_id() 

713 # if session_id: 

714 # print(f"Session ID: {session_id}") 

715 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

716 response = await session.list_tools() 

717 tools = response.tools 

718 tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] 

719 tools = [ToolCreate.model_validate(tool) for tool in tools] 

720 for tool in tools: 

721 tool.request_type = "STREAMABLEHTTP" 

722 

723 return capabilities, tools 

724 

725 if transport.lower() == "sse": 

726 capabilities, tools = await connect_to_sse_server(url, authentication) 

727 elif transport.lower() == "streamablehttp": 

728 capabilities, tools = await connect_to_streamablehttp_server(url, authentication) 

729 

730 return capabilities, tools 

731 except Exception as e: 

732 raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}") 

733 

734 def _get_active_gateways(self) -> list[DbGateway]: 

735 """Sync function for database operations (runs in thread). 

736 

737 Returns: 

738 List[DbGateway]: List of active gateways 

739 """ 

740 with SessionLocal() as db: 

741 return db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all() 

742 

743 async def _run_health_checks(self) -> None: 

744 """Run health checks periodically, 

745 Uses Redis or FileLock - for multiple workers. 

746 Uses simple health check for single worker mode.""" 

747 

748 while True: 

749 try: 

750 if self._redis_client and settings.cache_type == "redis": 

751 # Redis-based leader check 

752 current_leader = self._redis_client.get(self._leader_key) 

753 if current_leader != self._instance_id.encode(): 

754 return 

755 self._redis_client.expire(self._leader_key, self._leader_ttl) 

756 

757 # Run health checks 

758 gateways = await asyncio.to_thread(self._get_active_gateways) 

759 if gateways: 

760 await self.check_health_of_gateways(gateways) 

761 

762 await asyncio.sleep(self._health_check_interval) 

763 

764 elif settings.cache_type == "none": 

765 try: 

766 # For single worker mode, run health checks directly 

767 gateways = await asyncio.to_thread(self._get_active_gateways) 

768 

769 if gateways: 

770 await self.check_health_of_gateways(gateways) 

771 except Exception as e: 

772 logger.error(f"Health check run failed: {str(e)}") 

773 

774 await asyncio.sleep(self._health_check_interval) 

775 

776 else: 

777 # FileLock-based leader fallback 

778 try: 

779 self._file_lock.acquire(timeout=0) 

780 logger.info("File lock acquired. Running health checks.") 

781 

782 while True: 

783 gateways = await asyncio.to_thread(self._get_active_gateways) 

784 if gateways: 

785 await self.check_health_of_gateways(gateways) 

786 await asyncio.sleep(self._health_check_interval) 

787 

788 except Timeout: 

789 logger.debug("File lock already held. Retrying later.") 

790 await asyncio.sleep(self._health_check_interval) 

791 

792 except Exception as e: 

793 logger.error(f"FileLock health check failed: {str(e)}") 

794 

795 finally: 

796 if self._file_lock.is_locked: 

797 try: 

798 self._file_lock.release() 

799 logger.info("Released file lock.") 

800 except Exception as e: 

801 logger.warning(f"Failed to release file lock: {str(e)}") 

802 

803 except Exception as e: 

804 logger.error(f"Unexpected error in health check loop: {str(e)}") 

805 await asyncio.sleep(self._health_check_interval) 

806 

807 def _get_auth_headers(self) -> Dict[str, str]: 

808 """ 

809 Get headers for gateway authentication. 

810 

811 Returns: 

812 dict: Authorization header dict 

813 """ 

814 api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}" 

815 return {"Authorization": f"Basic {api_key}", "X-API-Key": api_key, "Content-Type": "application/json"} 

816 

817 async def _notify_gateway_added(self, gateway: DbGateway) -> None: 

818 """ 

819 Notify subscribers of gateway addition. 

820 

821 Args: 

822 gateway: Gateway to add 

823 """ 

824 event = { 

825 "type": "gateway_added", 

826 "data": { 

827 "id": gateway.id, 

828 "name": gateway.name, 

829 "url": gateway.url, 

830 "description": gateway.description, 

831 "is_active": gateway.is_active, 

832 }, 

833 "timestamp": datetime.utcnow().isoformat(), 

834 } 

835 await self._publish_event(event) 

836 

837 async def _notify_gateway_activated(self, gateway: DbGateway) -> None: 

838 """ 

839 Notify subscribers of gateway activation. 

840 

841 Args: 

842 gateway: Gateway to activate 

843 """ 

844 event = { 

845 "type": "gateway_activated", 

846 "data": { 

847 "id": gateway.id, 

848 "name": gateway.name, 

849 "url": gateway.url, 

850 "is_active": True, 

851 }, 

852 "timestamp": datetime.utcnow().isoformat(), 

853 } 

854 await self._publish_event(event) 

855 

856 async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None: 

857 """ 

858 Notify subscribers of gateway deactivation. 

859 

860 Args: 

861 gateway: Gateway database object 

862 """ 

863 event = { 

864 "type": "gateway_deactivated", 

865 "data": { 

866 "id": gateway.id, 

867 "name": gateway.name, 

868 "url": gateway.url, 

869 "is_active": False, 

870 }, 

871 "timestamp": datetime.utcnow().isoformat(), 

872 } 

873 await self._publish_event(event) 

874 

875 async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None: 

876 """ 

877 Notify subscribers of gateway deletion. 

878 

879 Args: 

880 gateway_info: Dict containing information about gateway to delete 

881 """ 

882 event = { 

883 "type": "gateway_deleted", 

884 "data": gateway_info, 

885 "timestamp": datetime.utcnow().isoformat(), 

886 } 

887 await self._publish_event(event) 

888 

889 async def _notify_gateway_removed(self, gateway: DbGateway) -> None: 

890 """ 

891 Notify subscribers of gateway removal (deactivation). 

892 

893 Args: 

894 gateway: Gateway to remove 

895 """ 

896 event = { 

897 "type": "gateway_removed", 

898 "data": {"id": gateway.id, "name": gateway.name, "is_active": False}, 

899 "timestamp": datetime.utcnow().isoformat(), 

900 } 

901 await self._publish_event(event) 

902 

903 async def _publish_event(self, event: Dict[str, Any]) -> None: 

904 """ 

905 Publish event to all subscribers. 

906 

907 Args: 

908 event: event dictionary 

909 """ 

910 for queue in self._event_subscribers: 

911 await queue.put(event)