Coverage for mcpgateway/services/gateway_service.py: 46%
389 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:47 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:47 +0100
1# -*- coding: utf-8 -*-
2"""Gateway Service Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements gateway federation according to the MCP specification.
9It handles:
10- Gateway discovery and registration
11- Request forwarding
12- Capability aggregation
13- Health monitoring
14- Active/inactive gateway management
15"""
17import asyncio
18import logging
19import uuid
20from datetime import datetime, timezone
21from typing import Any, AsyncGenerator, Dict, List, Optional, Set
23import httpx
24from filelock import FileLock, Timeout
25from mcp import ClientSession
26from mcp.client.sse import sse_client
27from mcp.client.streamable_http import streamablehttp_client
28from sqlalchemy import select
29from sqlalchemy.orm import Session
31from mcpgateway.config import settings
32from mcpgateway.db import Gateway as DbGateway
33from mcpgateway.db import SessionLocal
34from mcpgateway.db import Tool as DbTool
35from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, ToolCreate
36from mcpgateway.services.tool_service import ToolService
37from mcpgateway.utils.services_auth import decode_auth
# Module-level logger. It is created before the optional Redis import so the
# notice below is emitted through it: calling ``logging.info`` on the root
# logger at import time implicitly runs ``logging.basicConfig()`` when no
# handler is configured yet, which can pre-empt the application's own setup.
# logging.getLogger("httpx").setLevel(logging.WARNING)  # Disables httpx logs for regular health checks
logger = logging.getLogger(__name__)

# Redis is optional; REDIS_AVAILABLE records whether the client library could
# be imported so the service can fall back to another leader-election method.
try:
    import redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    logger.info("Redis is not utilized in this environment.")

# Health-check tuning read once from settings. A threshold of -1 disables the
# automatic deactivation performed by GatewayService._handle_gateway_failure.
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval
class GatewayError(Exception):
    """Common ancestor of every error raised by the gateway service."""
class GatewayNotFoundError(GatewayError):
    """Signals that a gateway lookup failed (unknown ID, or gateway inactive)."""
class GatewayNameConflictError(GatewayError):
    """Signals that the requested name is already taken by another gateway.

    The conflicting gateway may be active or inactive; the details are kept on
    the ``name``, ``is_active`` and ``gateway_id`` attributes.
    """

    def __init__(self, name: str, is_active: bool = True, gateway_id: Optional[int] = None):
        """Store the conflict details and build the error message.

        Args:
            name: The conflicting gateway name
            is_active: Whether the existing gateway is active
            gateway_id: ID of the existing gateway if available
        """
        self.name = name
        self.is_active = is_active
        self.gateway_id = gateway_id
        # Only inactive conflicts carry the extra hint about the existing row.
        inactive_hint = "" if is_active else f" (currently inactive, ID: {gateway_id})"
        super().__init__(f"Gateway already exists with name: {name}" + inactive_hint)
class GatewayConnectionError(GatewayError):
    """Signals that connecting to, or talking with, a gateway failed."""
87class GatewayService:
88 """Service for managing federated gateways.
90 Handles:
91 - Gateway registration and health checks
92 - Request forwarding
93 - Capability negotiation
94 - Federation events
95 - Active/inactive status management
96 """
def __init__(self):
    """Set up in-memory state and choose the leader-election mechanism.

    Redis is used when the cache type is "redis", a URL is configured and the
    client library imported; any other cache type except "none" falls back to
    a file lock; "none" uses no coordination at all.
    """
    self._event_subscribers: List[asyncio.Queue] = []
    self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)
    self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
    self._health_check_task: Optional[asyncio.Task] = None
    self._active_gateways: Set[str] = set()  # Track active gateway URLs
    self._stream_response = None
    self._pending_responses = {}
    self.tool_service = ToolService()
    self._gateway_failure_counts: dict[str, int] = {}

    # For health checks, we determine the leader instance.
    self.redis_url = settings.redis_url if settings.cache_type == "redis" else None
    redis_enabled = bool(self.redis_url) and REDIS_AVAILABLE

    self._redis_client = redis.from_url(self.redis_url) if redis_enabled else None
    if redis_enabled:
        self._instance_id = str(uuid.uuid4())  # Unique ID for this process
        self._leader_key = "gateway_service_leader"
        self._leader_ttl = 40  # seconds
    elif settings.cache_type != "none":
        # Fallback: File-based lock
        self._lock_path = settings.filelock_path
        self._file_lock = FileLock(self._lock_path)
async def initialize(self) -> None:
    """Start the background health-check task for this instance.

    With a Redis client the task is started only if this instance wins the
    leader key; without one the task is always started and decides for itself
    whether it may run checks.

    Raises:
        ConnectionError: When redis ping fails
    """
    logger.info("Initializing gateway service")

    if not self._redis_client:
        # Always create the health check task in filelock mode; leader check is handled inside.
        self._health_check_task = asyncio.create_task(self._run_health_checks())
        return

    # Check if Redis is available
    if not self._redis_client.ping():
        raise ConnectionError("Redis ping failed.")

    # SET ... NX succeeds only for the first instance that claims the key.
    won_election = self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
    if won_election:
        logger.info("Acquired Redis leadership. Starting health check task.")
        self._health_check_task = asyncio.create_task(self._run_health_checks())
async def shutdown(self) -> None:
    """Stop the health-check task, close the HTTP client and drop in-memory state."""
    health_task = self._health_check_task
    if health_task:
        health_task.cancel()
        try:
            # Wait for the task to finish unwinding after the cancel request.
            await health_task
        except asyncio.CancelledError:
            pass

    await self._http_client.aclose()
    for tracked in (self._event_subscribers, self._active_gateways):
        tracked.clear()
    logger.info("Gateway service shutdown complete")
162 async def register_gateway(self, db: Session, gateway: GatewayCreate) -> GatewayRead:
163 """Register a new gateway.
165 Args:
166 db: Database session
167 gateway: Gateway creation schema
169 Returns:
170 Created gateway information
172 Raises:
173 GatewayNameConflictError: If gateway name already exists
174 []: When ExceptionGroup found
175 """
176 try:
177 # Check for name conflicts (both active and inactive)
178 existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()
180 if existing_gateway:
181 raise GatewayNameConflictError(
182 gateway.name,
183 is_active=existing_gateway.is_active,
184 gateway_id=existing_gateway.id,
185 )
187 auth_type = getattr(gateway, "auth_type", None)
188 auth_value = getattr(gateway, "auth_value", {})
190 capabilities, tools = await self._initialize_gateway(str(gateway.url), auth_value, gateway.transport)
192 all_names = [td.name for td in tools]
194 existing_tools = db.execute(select(DbTool).where(DbTool.name.in_(all_names))).scalars().all()
195 existing_tool_names = [tool.name for tool in existing_tools]
197 tools = [
198 DbTool(
199 name=tool.name,
200 url=str(gateway.url),
201 description=tool.description,
202 integration_type=tool.integration_type,
203 request_type=tool.request_type,
204 headers=tool.headers,
205 input_schema=tool.input_schema,
206 jsonpath_filter=tool.jsonpath_filter,
207 auth_type=auth_type,
208 auth_value=auth_value,
209 )
210 for tool in tools
211 ]
213 existing_tools = [tool for tool in tools if tool.name in existing_tool_names]
214 new_tools = [tool for tool in tools if tool.name not in existing_tool_names]
216 # Create DB model
217 db_gateway = DbGateway(
218 name=gateway.name,
219 url=str(gateway.url),
220 description=gateway.description,
221 transport=gateway.transport,
222 capabilities=capabilities,
223 last_seen=datetime.now(timezone.utc),
224 auth_type=auth_type,
225 auth_value=auth_value,
226 tools=new_tools,
227 # federated_tools=existing_tools + new_tools
228 )
230 # Add to DB
231 db.add(db_gateway)
232 db.commit()
233 db.refresh(db_gateway)
235 # Update tracking
236 self._active_gateways.add(db_gateway.url)
238 # Notify subscribers
239 await self._notify_gateway_added(db_gateway)
241 return GatewayRead.model_validate(gateway)
242 except* GatewayConnectionError as ge:
243 logger.error("GatewayConnectionError in group: %s", ge.exceptions)
244 raise ge.exceptions[0]
245 except* ValueError as ve:
246 logger.error("ValueErrors in group: %s", ve.exceptions)
247 raise ve.exceptions[0]
248 except* RuntimeError as re:
249 logger.error("RuntimeErrors in group: %s", re.exceptions)
250 raise re.exceptions[0]
251 except* BaseException as other: # catches every other sub-exception
252 logger.error("Other grouped errors: %s", other.exceptions)
253 raise other.exceptions[0]
async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
    """Return the registered gateways, active ones only unless asked otherwise.

    Args:
        db: Database session
        include_inactive: Whether to include inactive gateways

    Returns:
        List of registered gateways
    """
    statement = select(DbGateway)
    if not include_inactive:
        statement = statement.where(DbGateway.is_active)

    rows = db.execute(statement).scalars().all()
    return list(map(GatewayRead.model_validate, rows))
async def update_gateway(self, db: Session, gateway_id: int, gateway_update: GatewayUpdate) -> GatewayRead:
    """Update a gateway.

    Applies every field of ``gateway_update`` that is not None, re-initializes
    the connection when a URL is supplied, commits, and publishes a
    ``gateway_updated`` event.

    Args:
        db: Database session
        gateway_id: Gateway ID to update
        gateway_update: Updated gateway data

    Returns:
        Updated gateway information

    Raises:
        GatewayError: For every failure. The not-found, inactive and
            name-conflict errors raised inside this method are caught by the
            broad handler at the end and re-raised as a ``GatewayError``
            whose message starts with "Failed to update gateway: "; the
            original exception is available as ``__cause__``.
    """
    try:
        # Find gateway
        gateway = db.get(DbGateway, gateway_id)
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        if not gateway.is_active:
            raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive")

        # Check for name conflicts if name is being changed
        if gateway_update.name is not None and gateway_update.name != gateway.name:
            existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none()

            if existing_gateway:
                raise GatewayNameConflictError(
                    gateway_update.name,
                    is_active=existing_gateway.is_active,
                    gateway_id=existing_gateway.id,
                )

        # Remember the URL currently tracked so it can be replaced below.
        previous_url = gateway.url

        # Update fields if provided
        if gateway_update.name is not None:
            gateway.name = gateway_update.name
        if gateway_update.url is not None:
            gateway.url = str(gateway_update.url)
        if gateway_update.description is not None:
            gateway.description = gateway_update.description
        if gateway_update.transport is not None:
            gateway.transport = gateway_update.transport

        # NOTE(review): these conditions inspect the stored gateway, not
        # gateway_update, so auth can only be replaced (or cleared) on a
        # gateway that already has auth configured - confirm this is intended.
        if getattr(gateway, "auth_type", None) is not None:
            gateway.auth_type = gateway_update.auth_type

            # if auth_type is not None and only then check auth_value
            if getattr(gateway, "auth_value", {}) != {}:
                gateway.auth_value = gateway_update.auth_value

        # Try to reinitialize the connection whenever a URL was supplied
        if gateway_update.url is not None:
            try:
                capabilities, _ = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
                gateway.capabilities = capabilities
                gateway.last_seen = datetime.now(timezone.utc)

                # Update tracking: drop the old URL, track the new one
                self._active_gateways.discard(previous_url)
                self._active_gateways.add(gateway.url)
            except Exception as e:
                # Best effort: the update is still saved if the gateway is unreachable.
                logger.warning(f"Failed to initialize updated gateway: {e}")

        gateway.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(gateway)

        # Notify subscribers
        await self._notify_gateway_updated(gateway)

        logger.info(f"Updated gateway: {gateway.name}")
        return GatewayRead.model_validate(gateway)

    except Exception as e:
        db.rollback()
        raise GatewayError(f"Failed to update gateway: {str(e)}") from e
async def get_gateway(self, db: Session, gateway_id: int, include_inactive: bool = False) -> GatewayRead:
    """Fetch one gateway by its ID.

    Args:
        db: Database session
        gateway_id: Gateway ID
        include_inactive: Whether to include inactive gateways

    Returns:
        Gateway information

    Raises:
        GatewayNotFoundError: If the ID is unknown, or the gateway is
            inactive and ``include_inactive`` is False
    """
    record = db.get(DbGateway, gateway_id)
    if not record:
        raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

    visible = record.is_active or include_inactive
    if not visible:
        raise GatewayNotFoundError(f"Gateway '{record.name}' exists but is inactive")

    return GatewayRead.model_validate(record)
async def toggle_gateway_status(self, db: Session, gateway_id: int, activate: bool) -> GatewayRead:
    """Toggle gateway active status.

    When the status actually changes, the gateway's tools are toggled the
    same way and an activation/deactivation event is published. Activation
    also tries to re-initialize the connection; a failure there is only
    logged.

    Args:
        db: Database session
        gateway_id: Gateway ID to toggle
        activate: True to activate, False to deactivate

    Returns:
        Updated gateway information

    Raises:
        GatewayError: For every failure, including an unknown gateway ID:
            the ``GatewayNotFoundError`` raised inside is caught by the broad
            handler and re-raised as a ``GatewayError`` whose message starts
            with "Failed to toggle gateway status: " (original exception in
            ``__cause__``).
    """
    try:
        gateway = db.get(DbGateway, gateway_id)
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        # Update status if it's different
        if gateway.is_active != activate:
            gateway.is_active = activate
            gateway.updated_at = datetime.now(timezone.utc)

            # Update tracking
            if activate:
                self._active_gateways.add(gateway.url)
                # Try to initialize if activating
                try:
                    capabilities, _ = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
                    # _initialize_gateway already returns a plain dict.
                    gateway.capabilities = capabilities
                    gateway.last_seen = datetime.now(timezone.utc)
                except Exception as e:
                    logger.warning(f"Failed to initialize reactivated gateway: {e}")
            else:
                self._active_gateways.discard(gateway.url)

            db.commit()
            db.refresh(gateway)

            # Keep the gateway's tools in the same state as the gateway.
            tools = db.query(DbTool).filter(DbTool.gateway_id == gateway_id).all()
            for tool in tools:
                await self.tool_service.toggle_tool_status(db, tool.id, activate)

            # Notify subscribers
            if activate:
                await self._notify_gateway_activated(gateway)
            else:
                await self._notify_gateway_deactivated(gateway)

            logger.info(f"Gateway {gateway.name} {'activated' if activate else 'deactivated'}")

        return GatewayRead.model_validate(gateway)

    except Exception as e:
        db.rollback()
        raise GatewayError(f"Failed to toggle gateway status: {str(e)}") from e
async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_updated`` event describing the gateway.

    Args:
        gateway: Gateway that was updated
    """
    snapshot = {field: getattr(gateway, field) for field in ("id", "name", "url", "description", "is_active")}
    stamped_at = datetime.utcnow().isoformat()
    await self._publish_event({"type": "gateway_updated", "data": snapshot, "timestamp": stamped_at})
async def delete_gateway(self, db: Session, gateway_id: int) -> None:
    """Permanently delete a gateway.

    Removes the row, stops tracking the gateway URL and publishes a
    ``gateway_deleted`` event.

    Args:
        db: Database session
        gateway_id: Gateway ID to delete

    Raises:
        GatewayError: For every failure, including an unknown gateway ID:
            the ``GatewayNotFoundError`` raised inside is caught by the broad
            handler and re-raised as a ``GatewayError`` whose message starts
            with "Failed to delete gateway: " (original exception in
            ``__cause__``).
    """
    try:
        # Find gateway
        gateway = db.get(DbGateway, gateway_id)
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        # Snapshot the fields needed after deletion: once the delete is
        # committed the ORM instance must not be read any more.
        gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}

        # Hard delete gateway
        db.delete(gateway)
        db.commit()

        # Update tracking
        self._active_gateways.discard(gateway_info["url"])

        # Notify subscribers
        await self._notify_gateway_deleted(gateway_info)

        logger.info(f"Permanently deleted gateway: {gateway_info['name']}")

    except Exception as e:
        db.rollback()
        raise GatewayError(f"Failed to delete gateway: {str(e)}") from e
async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
    """Forward a JSON-RPC request to a gateway's ``/rpc`` endpoint.

    Args:
        gateway: Gateway to forward to
        method: RPC method name
        params: Optional method parameters

    Returns:
        The ``result`` member of the gateway's JSON-RPC response

    Raises:
        GatewayConnectionError: If the gateway is inactive, the request
            fails, or the gateway answers with a JSON-RPC error. The
            ``GatewayError`` built for a JSON-RPC error is raised inside the
            ``try`` block, so it too reaches the caller wrapped in a
            ``GatewayConnectionError`` (original exception in ``__cause__``).
    """
    if not gateway.is_active:
        raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}")

    try:
        # Build RPC request
        request = {"jsonrpc": "2.0", "id": 1, "method": method}
        if params:
            request["params"] = params

        # Strip a trailing slash so the endpoint never becomes ".../​/rpc".
        endpoint = f"{gateway.url.rstrip('/')}/rpc"

        # Directly use the persistent HTTP client (no async with)
        response = await self._http_client.post(endpoint, json=request, headers=self._get_auth_headers())
        response.raise_for_status()
        result = response.json()

        # Update last seen timestamp
        gateway.last_seen = datetime.now(timezone.utc)

        if "error" in result:
            raise GatewayError(f"Gateway error: {result['error'].get('message')}")
        return result.get("result")

    except Exception as e:
        raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}") from e
async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
    """
    Tracks and handles gateway failures during health checks.
    Once the failure count reaches the threshold, the gateway is deactivated
    and its counter is reset.

    Args:
        gateway: The gateway object that failed its health check; its ``id``
            and ``name`` attributes are read.

    Returns:
        None
    """
    if GW_FAILURE_THRESHOLD == -1:
        return  # Gateway failure action disabled

    count = self._gateway_failure_counts.get(gateway.id, 0) + 1
    self._gateway_failure_counts[gateway.id] = count

    logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).")

    if count >= GW_FAILURE_THRESHOLD:
        logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
        # Use a dedicated session: this runs from the background health-check task.
        with SessionLocal() as db:
            await self.toggle_gateway_status(db, gateway.id, False)
            self._gateway_failure_counts[gateway.id] = 0  # Reset after deactivation
async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
    """Health check for a list of gateways.

    Every active gateway is probed over its transport; a failed probe is
    handed to ``_handle_gateway_failure``, which may deactivate the gateway.

    Args:
        gateways (List[DbGateway]): List of gateways to check if healthy

    Returns:
        bool: Always True once every gateway has been processed; individual
        failures are reported through ``_handle_gateway_failure`` rather than
        the return value.
    """
    # Reuse a single HTTP client for all requests, honouring the same TLS
    # verification setting as the service's main client.
    async with httpx.AsyncClient(verify=not settings.skip_ssl_verify) as client:
        for gateway in gateways:
            # Inactive gateways are skipped
            if not gateway.is_active:
                continue

            try:
                # Ensure auth_value is a dict
                headers = decode_auth(gateway.auth_value or {})
                transport = gateway.transport.lower()

                if transport == "sse":
                    timeout = httpx.Timeout(settings.health_check_timeout)
                    async with client.stream("GET", gateway.url, headers=headers, timeout=timeout) as response:
                        # This will raise immediately if status is 4xx/5xx
                        response.raise_for_status()
                elif transport == "streamablehttp":
                    async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, _get_session_id):
                        async with ClientSession(read_stream, write_stream) as session:
                            # A successful initialize handshake counts as healthy
                            await session.initialize()

                # Mark successful check
                gateway.last_seen = datetime.now(timezone.utc)

            except Exception as e:
                # Keep the reason visible for debugging before counting the failure.
                logger.debug(f"Health check failed for gateway {gateway.name}: {e}")
                await self._handle_gateway_failure(gateway)

    # All gateways processed
    return True
async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
    """Aggregate capabilities from all active gateways.

    Starts from this service's built-in capability set and merges in the
    capabilities stored on each active gateway. A key seen for the first time
    is added; for a key that already exists, dict values are merged into the
    existing dict and non-dict values are ignored.

    Args:
        db: Database session

    Returns:
        Combined capabilities
    """
    capabilities = {
        "prompts": {"listChanged": True},
        "resources": {"subscribe": True, "listChanged": True},
        "tools": {"listChanged": True},
        "logging": {},
    }

    # Get all active gateways
    gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()

    # Combine capabilities
    for gateway in gateways:
        if not gateway.capabilities:
            continue
        for key, value in gateway.capabilities.items():
            if key not in capabilities:
                # Copy dict values so later merges never write into the
                # dict owned by the gateway record.
                capabilities[key] = dict(value) if isinstance(value, dict) else value
            elif isinstance(value, dict) and isinstance(capabilities[key], dict):
                capabilities[key].update(value)

    return capabilities
async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream gateway events to one subscriber.

    A private queue is registered for the lifetime of the generator and is
    unregistered again when the generator is closed.

    Yields:
        Gateway event messages
    """
    inbox: asyncio.Queue = asyncio.Queue()
    self._event_subscribers.append(inbox)
    try:
        while True:
            yield await inbox.get()
    finally:
        self._event_subscribers.remove(inbox)
async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE") -> tuple[Dict[str, Any], List[ToolCreate]]:
    """Initialize connection to a gateway and retrieve its capabilities and tools.

    Args:
        url: Gateway URL
        authentication: Optional encoded authentication data; it is passed
            through ``decode_auth`` to obtain the request headers
        transport: Transport to use, compared case-insensitively; "sse" and
            "streamablehttp" are supported

    Returns:
        A ``(capabilities, tools)`` pair: the capabilities dictionary as
        provided by the gateway and the tools it lists, validated as
        ``ToolCreate`` objects. Tools discovered over Streamable HTTP get
        ``request_type`` set to "STREAMABLEHTTP".

    Raises:
        GatewayConnectionError: If initialization fails for any reason,
            including an unsupported transport.
    """

    async def describe(session: ClientSession) -> tuple[Dict[str, Any], List[ToolCreate]]:
        """Run the MCP handshake on an open session and list its tools."""
        handshake = await session.initialize()
        capabilities = handshake.capabilities.model_dump(by_alias=True, exclude_none=True)

        listing = await session.list_tools()
        tools = [ToolCreate.model_validate(tool.model_dump(by_alias=True, exclude_none=True)) for tool in listing.tools]
        return capabilities, tools

    try:
        if authentication is None:
            authentication = {}
        headers = decode_auth(authentication)

        kind = transport.lower()
        if kind == "sse":
            async with sse_client(url=url, headers=headers) as streams:
                async with ClientSession(*streams) as session:
                    capabilities, tools = await describe(session)
        elif kind == "streamablehttp":
            async with streamablehttp_client(url=url, headers=headers) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    capabilities, tools = await describe(session)
            for tool in tools:
                tool.request_type = "STREAMABLEHTTP"
        else:
            # Fail with a clear reason instead of hitting unbound locals below.
            raise ValueError(f"Unsupported transport: {transport!r}")

        return capabilities, tools
    except Exception as e:
        raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}") from e
def _get_active_gateways(self) -> list[DbGateway]:
    """Load the active gateways using a short-lived session (runs in a thread).

    Returns:
        List[DbGateway]: List of active gateways
    """
    active_only = select(DbGateway).where(DbGateway.is_active)
    with SessionLocal() as session:
        return session.execute(active_only).scalars().all()
async def _run_health_checks(self) -> None:
    """Run health checks periodically.

    Three modes, chosen on every iteration:

    - Redis: only the instance whose ID is stored under the leader key runs
      checks and refreshes the key's TTL. An instance that finds another
      leader returns, ending this task.
    - cache type "none": single worker mode, checks run directly.
    - otherwise: the instance holding the file lock runs checks in an inner
      loop; instances that cannot take the lock retry after one interval.
    """
    while True:
        try:
            if self._redis_client and settings.cache_type == "redis":
                # Redis-based leader check
                current_leader = self._redis_client.get(self._leader_key)
                if current_leader != self._instance_id.encode():
                    return
                self._redis_client.expire(self._leader_key, self._leader_ttl)

                # Run health checks
                gateways = await asyncio.to_thread(self._get_active_gateways)
                if gateways:
                    await self.check_health_of_gateways(gateways)

                await asyncio.sleep(self._health_check_interval)

            elif settings.cache_type == "none":
                try:
                    # For single worker mode, run health checks directly
                    gateways = await asyncio.to_thread(self._get_active_gateways)

                    if gateways:
                        await self.check_health_of_gateways(gateways)
                except Exception as e:
                    logger.error(f"Health check run failed: {str(e)}")

                await asyncio.sleep(self._health_check_interval)

            else:
                # FileLock-based leader fallback
                failed = False
                try:
                    self._file_lock.acquire(timeout=0)
                    logger.info("File lock acquired. Running health checks.")

                    while True:
                        gateways = await asyncio.to_thread(self._get_active_gateways)
                        if gateways:
                            await self.check_health_of_gateways(gateways)
                        await asyncio.sleep(self._health_check_interval)

                except Timeout:
                    logger.debug("File lock already held. Retrying later.")
                    await asyncio.sleep(self._health_check_interval)

                except Exception as e:
                    logger.error(f"FileLock health check failed: {str(e)}")
                    failed = True

                finally:
                    if self._file_lock.is_locked:
                        try:
                            self._file_lock.release()
                            logger.info("Released file lock.")
                        except Exception as e:
                            logger.warning(f"Failed to release file lock: {str(e)}")

                if failed:
                    # Back off (after releasing the lock) so a persistent
                    # error does not turn the outer loop into a busy loop.
                    await asyncio.sleep(self._health_check_interval)

        except Exception as e:
            logger.error(f"Unexpected error in health check loop: {str(e)}")
            await asyncio.sleep(self._health_check_interval)
def _get_auth_headers(self) -> Dict[str, str]:
    """
    Get headers for gateway authentication.

    The credentials come from ``settings.basic_auth_user`` and
    ``settings.basic_auth_password``. The ``Authorization`` header carries
    them base64-encoded, as the HTTP Basic scheme (RFC 7617) requires;
    ``X-API-Key`` keeps the plain ``user:password`` form.

    Returns:
        dict: Authorization, X-API-Key and Content-Type headers
    """
    import base64  # local import: only this helper needs it

    api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
    basic_token = base64.b64encode(api_key.encode("utf-8")).decode("ascii")
    return {"Authorization": f"Basic {basic_token}", "X-API-Key": api_key, "Content-Type": "application/json"}
async def _notify_gateway_added(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_added`` event describing the gateway.

    Args:
        gateway: Gateway that was added
    """
    details = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "description": gateway.description,
        "is_active": gateway.is_active,
    }
    stamped_at = datetime.utcnow().isoformat()
    await self._publish_event({"type": "gateway_added", "data": details, "timestamp": stamped_at})
async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_activated`` event for the gateway.

    Args:
        gateway: Gateway that was activated
    """
    details = {"id": gateway.id, "name": gateway.name, "url": gateway.url, "is_active": True}
    stamped_at = datetime.utcnow().isoformat()
    await self._publish_event({"type": "gateway_activated", "data": details, "timestamp": stamped_at})
async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_deactivated`` event for the gateway.

    Args:
        gateway: Gateway database object
    """
    details = {"id": gateway.id, "name": gateway.name, "url": gateway.url, "is_active": False}
    stamped_at = datetime.utcnow().isoformat()
    await self._publish_event({"type": "gateway_deactivated", "data": details, "timestamp": stamped_at})
async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
    """
    Publish a ``gateway_deleted`` event carrying the given gateway details.

    Args:
        gateway_info: Dict containing information about the deleted gateway
    """
    stamped_at = datetime.utcnow().isoformat()
    await self._publish_event({"type": "gateway_deleted", "data": gateway_info, "timestamp": stamped_at})
async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_removed`` event for the gateway.

    Args:
        gateway: Gateway that was removed
    """
    details = {"id": gateway.id, "name": gateway.name, "is_active": False}
    stamped_at = datetime.utcnow().isoformat()
    await self._publish_event({"type": "gateway_removed", "data": details, "timestamp": stamped_at})
async def _publish_event(self, event: Dict[str, Any]) -> None:
    """
    Deliver an event to the queue of every current subscriber.

    Args:
        event: event dictionary
    """
    subscribers = self._event_subscribers
    for inbox in subscribers:
        await inbox.put(event)