Coverage for mcpgateway/transports/streamablehttp_transport.py: 67%

142 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 16:31 +0100

1# -*- coding: utf-8 -*- 

2"""Streamable HTTP Transport Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Keval Mahajan 

7 

8This module implements the Streamable HTTP transport for MCP.

9 

10Key components include: 

11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions 

12- Configuration options for: 

13 1. stateful/stateless operation 

14 2. JSON response mode or SSE streams 

15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state 

16 

17""" 

18 

19import contextvars 

20import logging 

21import re 

22from collections import deque 

23from contextlib import AsyncExitStack, asynccontextmanager 

24from dataclasses import dataclass 

25from typing import List, Union 

26from uuid import uuid4 

27 

28from fastapi.security.utils import get_authorization_scheme_param 

29from mcp import types 

30from mcp.server.lowlevel import Server 

31from mcp.server.streamable_http import ( 

32 EventCallback, 

33 EventId, 

34 EventMessage, 

35 EventStore, 

36 StreamId, 

37) 

38from mcp.server.streamable_http_manager import StreamableHTTPSessionManager 

39from mcp.types import JSONRPCMessage 

40from starlette.datastructures import Headers 

41from starlette.responses import JSONResponse 

42from starlette.status import HTTP_401_UNAUTHORIZED 

43from starlette.types import Receive, Scope, Send 

44 

45from mcpgateway.config import settings 

46from mcpgateway.db import SessionLocal 

47from mcpgateway.services.tool_service import ToolService 

48from mcpgateway.utils.verify_credentials import verify_credentials 

49 

# Module-level logger. NOTE: basicConfig() configures the root logger as an
# import-time side effect of this module.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Initialize ToolService and MCP Server
tool_service = ToolService()
mcp_app = Server("mcp-streamable-http-stateless")

# Server id extracted from the request path by
# SessionManagerWrapper.handle_streamable_http and read by list_tools().
# The default is None, so the value type is "str or None".
server_id_var: contextvars.ContextVar[str | None] = contextvars.ContextVar("server_id", default=None)

58 

59## ------------------------------ Event store ------------------------------ 

60 

61 

@dataclass
class EventEntry:
    """
    Represents an event entry in the event store.

    Instances are created by InMemoryEventStore.store_event and are held both in
    the per-stream deque and in the event-id lookup index.
    """

    # Identifier generated for the event when it is stored.
    event_id: EventId
    # Stream the event was stored under.
    stream_id: StreamId
    # The JSON-RPC message carried by the event.
    message: JSONRPCMessage

71 

72 

class InMemoryEventStore(EventStore):
    """
    Simple in-memory implementation of the EventStore interface for resumability.
    This is primarily intended for examples and testing, not for production use
    where a persistent storage solution would be more appropriate.

    This implementation keeps only the last N events per stream for memory efficiency.
    """

    def __init__(self, max_events_per_stream: int = 100):
        """Initialize the event store.

        Args:
            max_events_per_stream: Maximum number of events to keep per stream.
                Must be at least 1.

        Raises:
            ValueError: If ``max_events_per_stream`` is smaller than 1. A size of 0
                previously made ``store_event`` fail with an IndexError (and a
                negative size failed inside ``deque``), so the error is now
                reported up front where the misconfiguration happens.
        """
        # None is passed through untouched (deque treats it as "unbounded"), which
        # keeps the previous behaviour for callers that relied on it.
        if max_events_per_stream is not None and max_events_per_stream < 1:
            raise ValueError("max_events_per_stream must be at least 1")

        self.max_events_per_stream = max_events_per_stream
        # for maintaining last N events per stream
        self.streams: dict[StreamId, deque[EventEntry]] = {}
        # event_id -> EventEntry for quick lookup
        self.event_index: dict[EventId, EventEntry] = {}

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """Stores an event with a generated event ID.

        Args:
            stream_id: Stream the event belongs to.
            message: JSON-RPC message to store.

        Returns:
            The newly generated event ID (a UUID4 string).
        """
        event_id = str(uuid4())
        event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message)

        # Get or create deque for this stream
        stream = self.streams.get(stream_id)
        if stream is None:
            stream = self.streams[stream_id] = deque(maxlen=self.max_events_per_stream)

        # A full deque silently drops its oldest element on append, so the matching
        # index entry has to be removed by hand. Compare against the deque's own
        # maxlen: that is the limit that actually governs the eviction.
        if stream and len(stream) == stream.maxlen:
            self.event_index.pop(stream[0].event_id, None)

        # Add new event
        stream.append(event_entry)
        self.event_index[event_id] = event_entry

        return event_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        """Replays events that occurred after the specified event ID.

        Args:
            last_event_id: ID of the last event the client has seen.
            send_callback: Awaitable callback invoked with an EventMessage for each
                later event of the same stream, in stored order.

        Returns:
            The stream ID the event belongs to, or None when the event ID is not
            (or no longer) present in the store.
        """
        last_event = self.event_index.get(last_event_id)
        if last_event is None:
            logger.warning("Event ID %s not found in store", last_event_id)
            return None

        stream_id = last_event.stream_id

        # Events in deque are already in chronological order: skip everything up to
        # and including the last seen event, then forward the remainder.
        found_last = False
        for event in self.streams.get(stream_id, ()):
            if found_last:
                await send_callback(EventMessage(event.message, event.event_id))
            elif event.event_id == last_event_id:
                found_last = True

        return stream_id

139 

140 

141## ------------------------------ Streamable HTTP Transport ------------------------------ 

142 

143 

@asynccontextmanager
async def get_db():
    """
    Provide a database session for the duration of an ``async with`` block.

    Yields:
        The object returned by calling ``SessionLocal()``.
        Its ``close()`` method is called when the block exits, whether it
        finished normally or raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()

158 

159 

@mcp_app.call_tool()
async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Invoke a tool through the tool service on behalf of the MCP Server.

    Args:
        name (str): Name of the tool to invoke.
        arguments (dict): Arguments forwarded to the tool.

    Returns:
        A one-element list holding a TextContent built from the ``type`` and
        ``text`` of the first content item of the tool result.
        An empty list when the result is falsy, has no content, or any
        exception is raised (the exception is logged, not propagated).
    """
    try:
        async with get_db() as db:
            outcome = await tool_service.invoke_tool(db, name, arguments)
            if outcome and outcome.content:
                # Only the first content item is forwarded to the client.
                first_item = outcome.content[0]
                return [types.TextContent(type=first_item.type, text=first_item.text)]

            logger.warning(f"No content returned by tool: {name}")
            return []
    except Exception as e:
        logger.exception(f"Error calling tool '{name}': {e}")
        return []

184 

185 

@mcp_app.list_tools()
async def list_tools() -> List[types.Tool]:
    """
    List the tools exposed through the MCP Server.

    When ``server_id_var`` holds a truthy server id, only that server's tools are
    requested (``list_server_tools``); otherwise all tools are requested
    (``list_tools``).

    Returns:
        A list of Tool objects built from each tool's name, description and
        input schema. An empty list when any exception is raised (the exception
        is logged, not propagated).
    """
    scoped_server_id = server_id_var.get()

    try:
        async with get_db() as db:
            if scoped_server_id:
                found = await tool_service.list_server_tools(db, scoped_server_id)
            else:
                found = await tool_service.list_tools(db)
            return [types.Tool(name=entry.name, description=entry.description, inputSchema=entry.input_schema) for entry in found]
    except Exception as e:
        logger.exception(f"Error listing tools:{e}")
        return []

213 

214 

class SessionManagerWrapper:
    """
    Owns a StreamableHTTPSessionManager and drives its lifecycle.
    Exposes initialize, shutdown, and a request-forwarding entry point.
    """

    def __init__(self) -> None:
        """
        Build the session manager from settings and create the exit stack that
        will later hold its running context.

        With ``settings.use_stateful_sessions`` enabled the manager gets an
        InMemoryEventStore and runs stateful; otherwise it gets no event store
        and runs stateless.
        """
        stateful = bool(settings.use_stateful_sessions)

        self.session_manager = StreamableHTTPSessionManager(
            app=mcp_app,
            event_store=InMemoryEventStore() if stateful else None,
            json_response=settings.json_response_enabled,
            stateless=not stateful,
        )
        self.stack = AsyncExitStack()

    async def initialize(self) -> None:
        """
        Enter the session manager's ``run()`` context on the exit stack.
        """
        logger.info("Initializing Streamable HTTP service")
        running_context = self.session_manager.run()
        await self.stack.enter_async_context(running_context)

    async def shutdown(self) -> None:
        """
        Close the exit stack, which exits the session manager's context.
        """
        logger.info("Stopping Streamable HTTP Session Manager...")
        await self.stack.aclose()

    async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Hand an incoming ASGI request over to the streamable HTTP session manager.

        The ``"modified_path"`` entry of the scope is searched for
        ``/servers/<digits>/mcp``; on a match the server id is stored in
        ``server_id_var`` before the request is forwarded.

        Args:
            scope (Scope): ASGI scope; must contain a ``"modified_path"`` key
                (presumably set by upstream middleware - verify against caller).
            receive (Receive): ASGI receive callable.
            send (Send): ASGI send callable.

        Raises:
            KeyError: If the scope has no ``"modified_path"`` entry.
            Exception: Anything raised while handling the request is logged and
                re-raised.
        """
        found = re.search(r"/servers/(?P<server_id>\d+)/mcp", scope["modified_path"])
        if found is not None:
            server_id_var.set(found.group("server_id"))

        try:
            await self.session_manager.handle_request(scope, receive, send)
        except Exception as e:
            logger.exception(f"Error handling streamable HTTP request: {e}")
            raise

278 

279 

280## ------------------------- Authentication for /mcp routes ------------------------------ 

281 

282 

async def streamable_http_auth(scope, receive, send):
    """
    Authenticate requests to ``/mcp`` routes at the ASGI level.

    Paths that do not end with "/mcp" or "/mcp/" are not checked (returns True).

    For checked paths the Bearer credentials are taken from the Authorization
    header; if the header is missing or is not a Bearer header, the token is
    None. ``verify_credentials`` is always called with that token.

    Args:
        scope: ASGI scope of the request.
        receive: ASGI receive callable.
        send: ASGI send callable.

    Returns:
        True when the path is not checked or verification did not raise.
        False when ``verify_credentials`` raised; in that case a 401 JSON
        response with a ``WWW-Authenticate: Bearer`` header has already been sent.
    """
    request_path = scope.get("path", "")
    if not request_path.endswith(("/mcp", "/mcp/")):
        # No auth needed for other paths in this middleware usage
        return True

    auth_header = Headers(scope=scope).get("authorization")

    bearer_token = None
    if auth_header:
        scheme, credentials = get_authorization_scheme_param(auth_header)
        if credentials and scheme.lower() == "bearer":
            bearer_token = credentials

    try:
        await verify_credentials(bearer_token)
    except Exception:
        rejection = JSONResponse(
            {"detail": "Authentication failed"},
            status_code=HTTP_401_UNAUTHORIZED,
            headers={"WWW-Authenticate": "Bearer"},
        )
        await rejection(scope, receive, send)
        return False

    return True