Coverage for mcpgateway/transports/sse_transport.py: 78%

77 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 12:53 +0100

1# -*- coding: utf-8 -*- 

2"""SSE Transport Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements Server-Sent Events (SSE) transport for MCP, 

9providing server-to-client streaming with proper session management. 

10""" 

11 

12import asyncio 

13import json 

14import logging 

15import uuid 

16from datetime import datetime 

17from typing import Any, AsyncGenerator, Dict 

18 

19from fastapi import Request 

20from sse_starlette.sse import EventSourceResponse 

21 

22from mcpgateway.config import settings 

23from mcpgateway.transports.base import Transport 

24 

25logger = logging.getLogger(__name__) 

26 

27 

class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    Outbound (server-to-client) messages are placed on an in-memory asyncio
    queue by :meth:`send_message` and drained by the event generator built in
    :meth:`create_sse_response`. Each transport instance owns a unique session
    id which is embedded in the message endpoint URL announced to the client.
    """

    def __init__(self, base_url: str = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints. When omitted (or
                empty), ``http://{settings.host}:{settings.port}`` is used.
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        # Messages waiting to be streamed to the client.
        self._message_queue = asyncio.Queue()
        # Set by disconnect(); tells the receive loop and event generator to stop.
        self._client_gone = asyncio.Event()
        self._session_id = str(uuid.uuid4())

        logger.info(f"Creating SSE transport with base_url={self._base_url}, session_id={self._session_id}")

    @staticmethod
    def _json_default(obj: Any) -> str:
        """Serialize values that ``json.dumps`` cannot encode natively.

        Args:
            obj: Object the JSON encoder could not serialize on its own.

        Returns:
            str: ``obj`` formatted as ``YYYY-MM-DD HH:MM:SS`` when it is a datetime.

        Raises:
            TypeError: If ``obj`` is not a datetime.
        """
        if isinstance(obj, datetime):
            return obj.strftime("%Y-%m-%d %H:%M:%S")
        # The exception must be *raised*, not returned: returning it makes the
        # encoder try to serialize the exception object, call this hook again,
        # and recurse until RecursionError.
        raise TypeError(f"Type {type(obj).__name__} not serializable")

    async def connect(self) -> None:
        """Set up SSE connection."""
        self._connected = True
        logger.info(f"SSE transport connected: {self._session_id}")

    async def disconnect(self) -> None:
        """Clean up SSE connection.

        Does nothing when the transport is not currently connected.
        """
        if self._connected:
            self._connected = False
            # Wake up / stop the receive loop and the SSE event generator.
            self._client_gone.set()
            logger.info(f"SSE transport disconnected: {self._session_id}")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over SSE.

        The message is only queued here; it is written to the stream by the
        event generator created in :meth:`create_sse_response`.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put message to queue
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug(f"Message queued for SSE: {self._session_id}, method={message.get('method', '(response)')}")
        except Exception as e:
            logger.error(f"Failed to queue message: {e}")
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        This method implements a continuous message-receiving pattern for SSE transport.
        Since SSE is primarily a server-to-client communication channel, this method
        yields an initial initialize placeholder message and then enters a waiting loop.
        The actual client messages are received via a separate HTTP POST endpoint
        (not handled in this method).

        The method will continue running until either:
        1. The connection is explicitly disconnected (client_gone event is set)
        2. The receive loop is cancelled from outside

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
                an initialize placeholder with the format:
                {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # For SSE, we set up a loop to wait for messages which are delivered via POST
        # Most messages come via the POST endpoint, but we yield an initial initialize placeholder
        # to keep the receive loop running
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        # Continue waiting for cancellation
        try:
            while not self._client_gone.is_set():
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            logger.info(f"SSE receive loop cancelled for session {self._session_id}")
            raise
        finally:
            logger.info(f"SSE receive loop ended for session {self._session_id}")

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    async def create_sse_response(self, _request: Request) -> EventSourceResponse:
        """Create SSE response for streaming.

        The stream starts with an ``endpoint`` event carrying the URL the
        client should POST messages to, followed by an immediate ``keepalive``.
        After that, queued messages are emitted as ``message`` events; a
        ``keepalive`` is sent whenever no message arrives within 30 seconds.

        Args:
            _request: FastAPI request (currently unused)

        Returns:
            SSE response object
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event
            """
            # Send the endpoint event first
            yield {
                "event": "endpoint",
                "data": endpoint_url,
                "retry": settings.sse_retry_timeout,
            }

            # Send keepalive immediately to help establish connection
            yield {
                "event": "keepalive",
                "data": "{}",
                "retry": settings.sse_retry_timeout,
            }

            try:
                while not self._client_gone.is_set():
                    try:
                        # Wait for messages with a timeout for keepalives
                        message = await asyncio.wait_for(
                            self._message_queue.get(),
                            timeout=30.0,  # 30 second timeout for keepalives (some tools require more timeout for execution)
                        )

                        # Unsupported types raise TypeError here, which the
                        # generic handler below reports as an "error" event.
                        data = json.dumps(message, default=self._json_default)

                        logger.debug(f"Sending SSE message: {data}")

                        yield {
                            "event": "message",
                            "data": data,
                            "retry": settings.sse_retry_timeout,
                        }
                    except asyncio.TimeoutError:
                        # Send keepalive on timeout
                        yield {
                            "event": "keepalive",
                            "data": "{}",
                            "retry": settings.sse_retry_timeout,
                        }
                    except Exception as e:
                        # Report the failure to the client and keep streaming.
                        logger.error(f"Error processing SSE message: {e}")
                        yield {
                            "event": "error",
                            "data": json.dumps({"error": str(e)}),
                            "retry": settings.sse_retry_timeout,
                        }
            except asyncio.CancelledError:
                logger.info(f"SSE event generator cancelled: {self._session_id}")
            except Exception as e:
                logger.error(f"SSE event generator error: {e}")
            finally:
                logger.info(f"SSE event generator completed: {self._session_id}")
                # We intentionally don't set client_gone here to allow queued messages to be processed

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object (currently unused)

        Returns:
            bool: True if client disconnected
        """
        # We only check our internal client_gone flag
        # We intentionally don't check connection_lost on the request
        # as it can be unreliable and cause premature closures
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """
        Get the session ID for this transport.

        Returns:
            str: session_id
        """
        return self._session_id