Coverage for mcpgateway/transports/streamablehttp_transport.py: 67%
142 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 16:31 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 16:31 +0100
1# -*- coding: utf-8 -*-
2"""Streamable HTTP Transport Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Keval Mahajan
8This module implements Streamable Http transport for MCP
10Key components include:
11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12- Configuration options for:
13 1. stateful/stateless operation
14 2. JSON response mode or SSE streams
15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
17"""
19import contextvars
20import logging
21import re
22from collections import deque
23from contextlib import AsyncExitStack, asynccontextmanager
24from dataclasses import dataclass
25from typing import List, Union
26from uuid import uuid4
28from fastapi.security.utils import get_authorization_scheme_param
29from mcp import types
30from mcp.server.lowlevel import Server
31from mcp.server.streamable_http import (
32 EventCallback,
33 EventId,
34 EventMessage,
35 EventStore,
36 StreamId,
37)
38from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
39from mcp.types import JSONRPCMessage
40from starlette.datastructures import Headers
41from starlette.responses import JSONResponse
42from starlette.status import HTTP_401_UNAUTHORIZED
43from starlette.types import Receive, Scope, Send
45from mcpgateway.config import settings
46from mcpgateway.db import SessionLocal
47from mcpgateway.services.tool_service import ToolService
48from mcpgateway.utils.verify_credentials import verify_credentials
# Module-level logger. basicConfig only installs a root handler (at INFO) if the
# root logger has no handlers yet.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Initialize ToolService and MCP Server
# Shared ToolService used by the call_tool / list_tools handlers in this module.
tool_service = ToolService()
# Low-level MCP server; the tool handlers in this module are registered on it via
# its call_tool() / list_tools() decorators.
mcp_app = Server("mcp-streamable-http-stateless")

# Server ID parsed from a "/servers/<id>/mcp" request path by
# SessionManagerWrapper.handle_streamable_http; None when the request is not
# scoped to a specific server (list_tools then lists every tool).
server_id_var: contextvars.ContextVar[str | None] = contextvars.ContextVar("server_id", default=None)
59## ------------------------------ Event store ------------------------------
@dataclass
class EventEntry:
    """
    Represents an event entry in the event store.

    Bundles one stored JSON-RPC message with the ID it was stored under and the
    stream it belongs to. Instances are created by InMemoryEventStore.store_event.
    """

    # ID assigned when the event is stored (InMemoryEventStore uses a uuid4 string).
    event_id: EventId
    # Stream the event belongs to; events are grouped per stream for replay.
    stream_id: StreamId
    # The JSON-RPC message that is re-sent when the event is replayed.
    message: JSONRPCMessage
class InMemoryEventStore(EventStore):
    """
    Simple in-memory implementation of the EventStore interface for resumability.

    This is primarily intended for examples and testing, not for production use
    where a persistent storage solution would be more appropriate.

    Only the last ``max_events_per_stream`` events of each stream are kept; evicted
    events are removed from both the per-stream buffer and the event-ID index.
    Stream buffers themselves are never removed.
    """

    def __init__(self, max_events_per_stream: int = 100):
        """Initialize the event store.

        Args:
            max_events_per_stream: Maximum number of events to keep per stream.
                Must be at least 1.

        Raises:
            ValueError: If ``max_events_per_stream`` is less than 1.
        """
        # With a bound of 0 the first store_event would raise IndexError when reading
        # the oldest entry of an empty deque, and a negative bound makes deque() raise;
        # reject both up front with a clear message.
        if max_events_per_stream < 1:
            raise ValueError(f"max_events_per_stream must be >= 1, got {max_events_per_stream}")
        self.max_events_per_stream = max_events_per_stream
        # stream_id -> bounded deque holding the last N events of that stream
        self.streams: dict[StreamId, deque[EventEntry]] = {}
        # event_id -> EventEntry for quick lookup
        self.event_index: dict[EventId, EventEntry] = {}

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """Store an event under a freshly generated event ID.

        Args:
            stream_id: Stream the event belongs to.
            message: JSON-RPC message to store.

        Returns:
            The generated event ID (a uuid4 string).
        """
        event_id = str(uuid4())
        event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message)

        # Get or create the bounded deque for this stream
        stream = self.streams.get(stream_id)
        if stream is None:
            stream = deque(maxlen=self.max_events_per_stream)
            self.streams[stream_id] = stream

        # A full deque silently drops its oldest entry on append, so remove that entry
        # from the ID index first to keep both structures in sync. The deque's own
        # maxlen is used so the check stays correct even if max_events_per_stream is
        # changed after the deque was created.
        if len(stream) == stream.maxlen:
            self.event_index.pop(stream[0].event_id, None)

        # Add new event
        stream.append(event_entry)
        self.event_index[event_id] = event_entry

        return event_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        """Replay the events that were stored after the specified event ID.

        Args:
            last_event_id: ID of the last event the client already received.
            send_callback: Awaitable callback invoked with an EventMessage for each
                later event of the same stream, in storage order.

        Returns:
            The stream ID of ``last_event_id``, or None if that ID is not in the
            store (never stored, or already evicted).
        """
        last_event = self.event_index.get(last_event_id)
        if last_event is None:
            logger.warning("Event ID %s not found in store", last_event_id)
            return None

        stream_id = last_event.stream_id
        # Snapshot the stream: send_callback is awaited inside the loop, and a
        # concurrent store_event appending to this deque during iteration would
        # otherwise raise "RuntimeError: deque mutated during iteration".
        stream_events = list(self.streams.get(stream_id, ()))

        # Events are stored in chronological order; send everything after the match.
        found_last = False
        for event in stream_events:
            if found_last:
                await send_callback(EventMessage(event.message, event.event_id))
            elif event.event_id == last_event_id:
                found_last = True

        return stream_id
141## ------------------------------ Streamable HTTP Transport ------------------------------
@asynccontextmanager
async def get_db():
    """
    Provide a database session for the duration of an ``async with`` block.

    A fresh session is obtained from SessionLocal on entry and closed on exit,
    whether the block completes normally or raises.

    Yields:
        A database session instance from SessionLocal.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Release the session unconditionally, including when the caller's block raised.
        session.close()
@mcp_app.call_tool()
async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Handles tool invocation via the MCP Server.

    Args:
        name (str): The name of the tool to invoke.
        arguments (dict): A dictionary of arguments to pass to the tool.

    Returns:
        One TextContent per text item in the tool response, in order. Items whose
        ``type`` is not "text" are skipped with a warning. Logs and returns an
        empty list on failure or when the tool returns no content.
    """
    try:
        async with get_db() as db:
            result = await tool_service.invoke_tool(db, name, arguments)
            if not result or not result.content:
                logger.warning(f"No content returned by tool: {name}")
                return []

            contents: List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]] = []
            for item in result.content:
                item_type = getattr(item, "type", None)
                # NOTE(review): only text items are forwarded; mapping image/resource items
                # from the gateway's result model to MCP types is not implemented here.
                if item_type != "text":
                    logger.warning("Skipping non-text content of type %r returned by tool: %s", item_type, name)
                    continue
                contents.append(types.TextContent(type="text", text=item.text))
            return contents
    except Exception as e:
        logger.exception(f"Error calling tool '{name}': {e}")
        return []
@mcp_app.list_tools()
async def list_tools() -> List[types.Tool]:
    """
    Lists all tools available to the MCP Server.

    If a server ID was recorded in ``server_id_var`` for the current request, only
    that server's tools are listed; otherwise every tool known to the ToolService
    is listed.

    Returns:
        A list of Tool objects carrying each tool's name, description, and input schema.
        Logs and returns an empty list on failure.
    """
    server_id = server_id_var.get()
    try:
        async with get_db() as db:
            if server_id:
                tools = await tool_service.list_server_tools(db, server_id)
            else:
                tools = await tool_service.list_tools(db)
            return [types.Tool(name=t.name, description=t.description, inputSchema=t.input_schema) for t in tools]
    except Exception as e:
        logger.exception(f"Error listing tools:{e}")
        return []
class SessionManagerWrapper:
    """
    Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.

    Provides start (initialize), stop (shutdown), and request handling methods.
    """

    # Finds "/servers/<numeric id>/mcp" anywhere in the request path and captures the ID.
    _SERVER_PATH_RE = re.compile(r"/servers/(?P<server_id>\d+)/mcp")

    def __init__(self) -> None:
        """
        Initializes the session manager and the exit stack used for managing its lifecycle.

        When ``settings.use_stateful_sessions`` is truthy, an InMemoryEventStore is attached
        and the manager runs stateful; otherwise it runs stateless with no event store.
        """
        if settings.use_stateful_sessions:
            event_store = InMemoryEventStore()
            stateless = False
        else:
            event_store = None
            stateless = True

        self.session_manager = StreamableHTTPSessionManager(
            app=mcp_app,
            event_store=event_store,
            json_response=settings.json_response_enabled,
            stateless=stateless,
        )
        self.stack = AsyncExitStack()

    async def initialize(self) -> None:
        """
        Starts the Streamable HTTP session manager context.

        The manager's ``run()`` context is entered on the internal exit stack and stays
        open until ``shutdown`` is called.
        """
        logger.info("Initializing Streamable HTTP service")
        await self.stack.enter_async_context(self.session_manager.run())

    async def shutdown(self) -> None:
        """
        Gracefully shuts down the Streamable HTTP session manager.

        Closes the exit stack, which exits the ``run()`` context entered by ``initialize``.
        """
        logger.info("Stopping Streamable HTTP Session Manager...")
        await self.stack.aclose()

    async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Forwards an incoming ASGI request to the streamable HTTP session manager.

        If the request path contains ``/servers/<id>/mcp``, ``server_id_var`` is set to
        that ID while the request is handled and restored to its previous value afterwards.

        Args:
            scope (Scope): ASGI scope object containing connection information.
            receive (Receive): ASGI receive callable.
            send (Send): ASGI send callable.

        Raises:
            Exception: Any exception raised by the session manager is logged and re-raised.
        """
        # Prefer the gateway-provided "modified_path", but fall back to the raw ASGI
        # path instead of failing with KeyError when it has not been set.
        path = scope.get("modified_path", scope.get("path", ""))
        match = self._SERVER_PATH_RE.search(path)
        token = server_id_var.set(match.group("server_id")) if match else None

        try:
            await self.session_manager.handle_request(scope, receive, send)
        except Exception as e:
            logger.exception("Error handling streamable HTTP request: %s", e)
            raise
        finally:
            # Restore the previous value so the ID does not outlive this request in a
            # context that is shared with other work.
            if token is not None:
                server_id_var.reset(token)
280## ------------------------- Authentication for /mcp routes ------------------------------
async def streamable_http_auth(scope, receive, send):
    """
    Perform authentication check in middleware context (ASGI scope).

    Requests whose path does not end with "/mcp" or "/mcp/" are not checked, and
    True is returned immediately.

    For MCP paths, the Bearer token is taken from the Authorization header. It is
    None when the header is missing or uses a different scheme. The token is always
    passed to verify_credentials, even when it is None; whether a missing token is
    accepted is decided by verify_credentials, not here.

    Returns:
        True if the request may proceed. False if verify_credentials raised; in that
        case a 401 JSON response carrying a "WWW-Authenticate: Bearer" header has
        already been sent.
    """
    if not scope.get("path", "").endswith(("/mcp", "/mcp/")):
        # No auth needed for other paths in this middleware usage
        return True

    authorization = Headers(scope=scope).get("authorization")
    scheme, credentials = get_authorization_scheme_param(authorization) if authorization else ("", "")
    token = credentials if credentials and scheme.lower() == "bearer" else None

    try:
        await verify_credentials(token)
    except Exception:
        unauthorized = JSONResponse(
            {"detail": "Authentication failed"},
            status_code=HTTP_401_UNAUTHORIZED,
            headers={"WWW-Authenticate": "Bearer"},
        )
        await unauthorized(scope, receive, send)
        return False

    return True