Coverage for mcpgateway/transports/sse_transport.py: 78%
77 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
1# -*- coding: utf-8 -*-
2"""SSE Transport Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements Server-Sent Events (SSE) transport for MCP,
9providing server-to-client streaming with proper session management.
10"""
12import asyncio
13import json
14import logging
15import uuid
16from datetime import datetime
17from typing import Any, AsyncGenerator, Dict
19from fastapi import Request
20from sse_starlette.sse import EventSourceResponse
22from mcpgateway.config import settings
23from mcpgateway.transports.base import Transport
# Module-level logger named after this module, so handlers and levels can target it by name.
logger = logging.getLogger(name=__name__)
class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    Messages passed to :meth:`send_message` are queued and streamed to the client by
    the event generator returned from :meth:`create_sse_response`. This class does not
    read client-to-server messages itself; the initial ``endpoint`` event tells the
    client to POST them to ``{base_url}/message?session_id=<session_id>``.
    """

    def __init__(self, base_url: str = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints. When falsy, defaults to
                ``http://{settings.host}:{settings.port}``.
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        # Outgoing messages waiting to be streamed by the SSE event generator.
        self._message_queue = asyncio.Queue()
        # Set by disconnect(); ends both the receive loop and the event-generator loop.
        self._client_gone = asyncio.Event()
        self._session_id = str(uuid.uuid4())

        logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)

    @staticmethod
    def _sse_event(event: str, data: str) -> Dict[str, Any]:
        """Build one SSE event dict carrying the configured client retry interval.

        Args:
            event: SSE event name (e.g. ``"message"``, ``"keepalive"``).
            data: Event payload, already serialized to a string.

        Returns:
            Dict[str, Any]: Mapping with ``event``, ``data`` and ``retry`` keys.
        """
        return {
            "event": event,
            "data": data,
            "retry": settings.sse_retry_timeout,
        }

    @staticmethod
    def _json_default(obj: Any) -> str:
        """Serialize values the stdlib JSON encoder cannot handle natively.

        Used as the ``default`` hook of ``json.dumps``. The hook must *raise*
        ``TypeError`` for unsupported objects. Returning an exception instance
        instead would make the encoder try to serialize that object too, which
        recurses until ``RecursionError``.

        Args:
            obj: Object the encoder could not serialize.

        Returns:
            str: ``obj`` formatted as ``"%Y-%m-%d %H:%M:%S"`` when it is a datetime.

        Raises:
            TypeError: For any other type.
        """
        if isinstance(obj, datetime):
            return obj.strftime("%Y-%m-%d %H:%M:%S")
        raise TypeError(f"Type not serializable: {type(obj).__name__}")

    async def connect(self) -> None:
        """Mark the SSE transport as connected."""
        self._connected = True
        logger.info("SSE transport connected: %s", self._session_id)

    async def disconnect(self) -> None:
        """Mark the transport as disconnected and signal the streaming loops to stop.

        Only the first call made while connected sets the client-gone event and logs;
        later calls do nothing.
        """
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info("SSE transport disconnected: %s", self._session_id)

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Queue a message for delivery over SSE.

        Args:
            message: Message to send.

        Raises:
            RuntimeError: If transport is not connected.
            Exception: If unable to put message to queue (logged, then re-raised).
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        SSE is a server-to-client channel. Real client messages arrive through a
        separate HTTP POST endpoint, which this method does not handle. The method
        yields one ``initialize`` placeholder and then waits until either:

        1. the transport is disconnected (the client-gone event is set), or
        2. the waiting coroutine is cancelled from outside.

        Yields:
            Dict[str, Any]: The initialize placeholder
            ``{"jsonrpc": "2.0", "method": "initialize", "id": 1}``.

        Raises:
            RuntimeError: If the transport is not connected when this method is called.
            asyncio.CancelledError: When the receive loop is cancelled externally.
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # The placeholder keeps the caller's receive loop running; actual messages
        # are delivered through the POST endpoint.
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        try:
            # Block until disconnect() sets the event, instead of polling once per second.
            await self._client_gone.wait()
        except asyncio.CancelledError:
            logger.info("SSE receive loop cancelled for session %s", self._session_id)
            raise
        finally:
            logger.info("SSE receive loop ended for session %s", self._session_id)

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    async def create_sse_response(self, _request: Request) -> EventSourceResponse:
        """Create SSE response for streaming.

        The stream starts with an ``endpoint`` event containing the POST URL for this
        session, followed by an immediate ``keepalive``. It then relays queued
        messages, sending a ``keepalive`` after every 30 seconds without one.

        Args:
            _request: FastAPI request (unused).

        Returns:
            SSE response object
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator() -> AsyncGenerator[Dict[str, Any], None]:
            """Generate SSE events for this session.

            Yields:
                SSE event dicts with ``event``, ``data`` and ``retry`` keys.

            Raises:
                asyncio.CancelledError: Propagated after logging when the stream is cancelled.
            """
            # Tell the client where to POST its messages.
            yield self._sse_event("endpoint", endpoint_url)

            # Send keepalive immediately to help establish connection
            yield self._sse_event("keepalive", "{}")

            try:
                while not self._client_gone.is_set():
                    try:
                        # Wait for messages with a timeout for keepalives
                        # (some tools require more time for execution).
                        message = await asyncio.wait_for(self._message_queue.get(), timeout=30.0)

                        data = json.dumps(message, default=self._json_default)
                        logger.debug("Sending SSE message: %s", data)

                        yield self._sse_event("message", data)
                    except asyncio.TimeoutError:
                        # No message within the timeout: send keepalive.
                        yield self._sse_event("keepalive", "{}")
                    except Exception as e:
                        # Report per-message failures (e.g. unserializable payloads) to the
                        # client and keep streaming.
                        logger.error("Error processing SSE message: %s", e)
                        yield self._sse_event("error", json.dumps({"error": str(e)}))
            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                # Propagate so the task driving this generator actually observes the cancellation.
                raise
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                # We intentionally don't set client_gone here to allow queued messages to be processed

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object (unused).

        Returns:
            bool: True if client disconnected
        """
        # We only check our internal client_gone flag.
        # We intentionally don't check connection_lost on the request,
        # as it can be unreliable and cause premature closures.
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """
        Get the session ID for this transport.

        Returns:
            str: session_id
        """
        return self._session_id