Coverage for mcpgateway/wrapper.py: 86%
217 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 13:28 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 13:28 +0100
1# -*- coding: utf-8 -*-
2"""
3MCP Gateway Wrapper server.
5Copyright 2025
6SPDX-License-Identifier: Apache-2.0
7Authors: Keval Mahajan, Mihai Criveti, Madhav Kandukuri
9This module implements a wrapper bridge that facilitates
10interaction between the MCP client and the MCP gateway.
11It provides several functionalities, including listing tools,
12invoking tools, managing resources, retrieving prompts,
13and handling tool calls via the MCP gateway.
15A **stdio** bridge that exposes a remote MCP Gateway
16(HTTP-/JSON-RPC APIs) as a local MCP server. All JSON-RPC
17traffic is written to **stdout**; every log or trace message
18is emitted on **stderr** so that protocol messages and
19diagnostics never mix.
21Environment variables:
22- MCP_SERVER_CATALOG_URLS: Comma-separated list of gateway catalog URLs (required)
23- MCP_AUTH_TOKEN: Bearer token for the gateway (optional)
24- MCP_TOOL_CALL_TIMEOUT: Seconds to wait for a gateway RPC call (default 90)
25- MCP_WRAPPER_LOG_LEVEL: Python log level name or OFF/NONE to disable logging (default INFO)
27Example:
28 $ export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin --exp 10080 --secret my-test-key)
29 $ export MCP_AUTH_TOKEN=${MCPGATEWAY_BEARER_TOKEN}
30 $ export MCP_SERVER_CATALOG_URLS='http://localhost:4444/servers/1'
31 $ export MCP_TOOL_CALL_TIMEOUT=120
32 $ export MCP_WRAPPER_LOG_LEVEL=DEBUG # OFF to disable logging
33 $ python3 -m mcpgateway.wrapper
34"""
36import asyncio
37import logging
38import os
39import sys
40from typing import Any, Dict, List, Optional, Union
41from urllib.parse import urlparse
43import httpx
44import mcp.server.stdio
45from mcp import types
46from mcp.server import NotificationOptions, Server
47from mcp.server.models import InitializationOptions
48from pydantic import AnyUrl
50from mcpgateway import __version__
52# -----------------------------------------------------------------------------
53# Configuration
54# -----------------------------------------------------------------------------
# Names of the environment variables this module reads.
ENV_SERVER_CATALOGS = "MCP_SERVER_CATALOG_URLS"
ENV_AUTH_TOKEN = "MCP_AUTH_TOKEN"  # nosec B105 – this is an *environment variable name*, not a secret
ENV_TIMEOUT = "MCP_TOOL_CALL_TIMEOUT"
ENV_LOG_LEVEL = "MCP_WRAPPER_LOG_LEVEL"

# Comma-separated catalog URLs; surrounding whitespace is stripped and blank
# entries are dropped.
RAW_CATALOGS: str = os.getenv(ENV_SERVER_CATALOGS, "")
SERVER_CATALOG_URLS: List[str] = [u.strip() for u in RAW_CATALOGS.split(",") if u.strip()]

# An empty token means requests are sent without an Authorization header.
AUTH_TOKEN: str = os.getenv(ENV_AUTH_TOKEN, "")
# NOTE: int() raises ValueError at import time when the variable holds a
# non-integer string.
TOOL_CALL_TIMEOUT: int = int(os.getenv(ENV_TIMEOUT, "90"))

# Validate required configuration: without at least one catalog URL the
# process reports the problem on stderr and exits with status 1.
if not SERVER_CATALOG_URLS:
    print(f"Error: {ENV_SERVER_CATALOGS} environment variable is required", file=sys.stderr)
    sys.exit(1)
72# -----------------------------------------------------------------------------
73# Base URL Extraction
74# -----------------------------------------------------------------------------
75def _extract_base_url(url: str) -> str:
76 """Return the gateway-level base URL.
78 The function keeps any application root path (`APP_ROOT_PATH`) that the
79 remote gateway is mounted under (for example `/gateway`) while removing
80 the `/servers/<id>` suffix that appears in catalog endpoints. It also
81 discards any query string or fragment.
83 Args:
84 url (str): Full catalog URL, e.g.
85 `https://host.com/gateway/servers/1`.
87 Returns:
88 str: Clean base URL suitable for building `/tools/`, `/prompts/`,
89 or `/resources/` endpoints—for example
90 `https://host.com/gateway`.
92 Raises:
93 ValueError: If *url* lacks a scheme or network location.
95 Examples:
96 >>> _extract_base_url("https://host.com/servers/2")
97 'https://host.com'
98 >>> _extract_base_url("https://host.com/gateway/servers/2")
99 'https://host.com/gateway'
100 >>> _extract_base_url("https://host.com/gateway/servers")
101 'https://host.com/gateway'
102 >>> _extract_base_url("https://host.com/gateway")
103 'https://host.com/gateway'
105 Note:
106 If the target server was started with `APP_ROOT_PATH=/gateway`, the
107 resulting catalog URLs include that prefix. This helper preserves the
108 prefix so the wrapper's follow-up calls remain correctly scoped.
109 """
110 parsed = urlparse(url)
111 if not parsed.scheme or not parsed.netloc:
112 raise ValueError(f"Invalid URL provided: {url}")
114 path = parsed.path or ""
115 if "/servers/" in path:
116 path = path.split("/servers")[0] # ".../servers/123" -> "..."
117 elif path.endswith("/servers"):
118 path = path[: -len("/servers")] # ".../servers" -> "..."
119 # otherwise keep the existing path (supports APP_ROOT_PATH)
121 return f"{parsed.scheme}://{parsed.netloc}{path}"
# Gateway base URL, derived from the first catalog URL only; the remaining
# catalog URLs are used solely for their trailing server IDs.
BASE_URL: str = _extract_base_url(SERVER_CATALOG_URLS[0]) if SERVER_CATALOG_URLS else ""

# -----------------------------------------------------------------------------
# Logging Setup
# -----------------------------------------------------------------------------
_log_level = os.getenv(ENV_LOG_LEVEL, "INFO").upper()
if _log_level in {"OFF", "NONE", "DISABLE", "FALSE", "0"}:
    # Suppress every record up to and including CRITICAL.
    logging.disable(logging.CRITICAL)
else:
    logging.basicConfig(
        # A name that is not an attribute of the logging module falls back to INFO.
        level=getattr(logging, _log_level, logging.INFO),
        format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
        # Diagnostics go to stderr; stdout is reserved for JSON-RPC traffic.
        stream=sys.stderr,
    )

logger = logging.getLogger("mcpgateway.wrapper")
logger.info(f"Starting MCP wrapper {__version__}: base_url={BASE_URL}, timeout={TOOL_CALL_TIMEOUT}")
143# -----------------------------------------------------------------------------
144# HTTP Helpers
145# -----------------------------------------------------------------------------
async def fetch_url(url: str) -> httpx.Response:
    """
    Issue an asynchronous HTTP GET and hand back the successful response.

    A bearer ``Authorization`` header is attached only when ``AUTH_TOKEN`` is
    non-empty. The request uses ``TOOL_CALL_TIMEOUT`` as its timeout. Failures
    are logged and then re-raised unchanged.

    Args:
        url: The target URL to fetch.

    Returns:
        The ``httpx.Response`` whose status passed ``raise_for_status()``.

    Raises:
        httpx.RequestError: If a network problem occurs while making the request.
        httpx.HTTPStatusError: If the server returns a 4xx or 5xx response.
    """
    auth_headers = {}
    if AUTH_TOKEN:
        auth_headers = {"Authorization": f"Bearer {AUTH_TOKEN}"}

    async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as http:
        try:
            reply = await http.get(url, headers=auth_headers)
            reply.raise_for_status()
        except httpx.HTTPStatusError as status_err:
            logger.error(f"HTTP {status_err.response.status_code} returned for {url}: {status_err}")
            raise
        except httpx.RequestError as net_err:
            logger.error(f"Network error while fetching {url}: {net_err}")
            raise
        else:
            return reply
174# -----------------------------------------------------------------------------
175# Metadata Helpers
176# -----------------------------------------------------------------------------
async def get_tools_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated tool IDs from the MCP gateway server catalogs.

    The server ID of each catalog URL is its last path segment. The full
    server list is fetched from ``{BASE_URL}/servers/`` and the
    ``associatedTools`` of every matching entry are collected.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: A list of tool ID strings extracted from the server catalog.
    """
    # Strip a trailing slash first: ".../servers/1/" would otherwise yield an
    # empty server ID and silently match nothing.
    server_ids = {url.rstrip("/").rsplit("/", 1)[-1] for url in catalog_urls}
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    tool_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # "or []" also covers an explicit null in the payload.
            tool_ids.extend(entry.get("associatedTools") or [])
    return tool_ids
async def tools_metadata(tool_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a list of MCP tools by their IDs.

    The complete tool list is fetched from ``{BASE_URL}/tools/`` and filtered
    locally. The special value ``["0"]`` means "no filtering": every tool the
    gateway returns is passed through.

    Args:
        tool_ids (List[str]): List of tool ID strings.

    Returns:
        List[Dict[str, Any]]: A list of metadata dictionaries for each tool.
    """
    if not tool_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/tools/")
    data: List[Dict[str, Any]] = response.json()
    if tool_ids == ["0"]:
        return data

    # Compare IDs as strings on both sides (as the prompt/resource helpers do)
    # and use .get() so an entry without an "id" is skipped instead of raising.
    wanted = {str(tool_id) for tool_id in tool_ids}
    return [tool for tool in data if str(tool.get("id")) in wanted]
async def get_prompts_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated prompt IDs from the MCP gateway server catalogs.

    The server ID of each catalog URL is its last path segment. The full
    server list is fetched from ``{BASE_URL}/servers/`` and the
    ``associatedPrompts`` of every matching entry are collected.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: A list of prompt ID strings.
    """
    # Strip a trailing slash first: ".../servers/1/" would otherwise yield an
    # empty server ID and silently match nothing.
    server_ids = {url.rstrip("/").rsplit("/", 1)[-1] for url in catalog_urls}
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    prompt_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # "or []" also covers an explicit null in the payload.
            prompt_ids.extend(entry.get("associatedPrompts") or [])
    return prompt_ids
async def prompts_metadata(prompt_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a list of MCP prompts by their IDs.

    The complete prompt list is fetched from ``{BASE_URL}/prompts/`` and
    filtered locally. The special value ``["0"]`` means "no filtering": every
    prompt the gateway returns is passed through.

    Args:
        prompt_ids (List[str]): List of prompt ID strings.

    Returns:
        List[Dict[str, Any]]: A list of metadata dictionaries for each prompt.
    """
    if not prompt_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/prompts/")
    data: List[Dict[str, Any]] = response.json()
    if prompt_ids == ["0"]:
        return data

    # Normalise the requested IDs to strings as well, so IDs that arrive from
    # JSON as numbers still match the stringified "id" of each entry.
    wanted = {str(prompt_id) for prompt_id in prompt_ids}
    return [pr for pr in data if str(pr.get("id")) in wanted]
async def get_resources_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated resource IDs from the MCP gateway server catalogs.

    The server ID of each catalog URL is its last path segment. The full
    server list is fetched from ``{BASE_URL}/servers/`` and the
    ``associatedResources`` of every matching entry are collected.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: A list of resource ID strings.
    """
    # Strip a trailing slash first: ".../servers/1/" would otherwise yield an
    # empty server ID and silently match nothing.
    server_ids = {url.rstrip("/").rsplit("/", 1)[-1] for url in catalog_urls}
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    resource_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # "or []" also covers an explicit null in the payload.
            resource_ids.extend(entry.get("associatedResources") or [])
    return resource_ids
async def resources_metadata(resource_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a list of MCP resources by their IDs.

    The complete resource list is fetched from ``{BASE_URL}/resources/`` and
    filtered locally. The special value ``["0"]`` means "no filtering": every
    resource the gateway returns is passed through.

    Args:
        resource_ids (List[str]): List of resource ID strings.

    Returns:
        List[Dict[str, Any]]: A list of metadata dictionaries for each resource.
    """
    if not resource_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/resources/")
    data: List[Dict[str, Any]] = response.json()
    if resource_ids == ["0"]:
        return data

    # Normalise the requested IDs to strings as well, so IDs that arrive from
    # JSON as numbers still match the stringified "id" of each entry.
    wanted = {str(resource_id) for resource_id in resource_ids}
    return [res for res in data if str(res.get("id")) in wanted]
301# -----------------------------------------------------------------------------
302# Server Handlers
303# -----------------------------------------------------------------------------
# MCP server instance; the handler coroutines in this module register on it
# through its decorator methods (list_tools, call_tool, ...).
server: Server = Server("mcpgateway-wrapper")
@server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
    """
    List all available MCP tools exposed by the gateway.

    Queries the configured server catalogs to retrieve tool IDs and then
    fetches metadata for each tool to construct a list of Tool objects. When
    the first catalog URL is the gateway base URL itself (no ``/servers/<id>``
    suffix), every tool of the gateway is listed. Entries without a name are
    skipped.

    Returns:
        List[types.Tool]: A list of Tool instances including name, description, and input schema.

    Raises:
        RuntimeError: If an error occurs during fetching or processing; the
            original exception is attached as ``__cause__``.
    """
    try:
        # ["0"] is the "no filtering" sentinel understood by tools_metadata().
        tool_ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_tools_from_mcp_server(SERVER_CATALOG_URLS)
        metadata = await tools_metadata(tool_ids)
        return [
            types.Tool(
                name=str(tool["name"]),
                description=tool.get("description", ""),
                # "or {}" also replaces an explicit null schema.
                inputSchema=tool.get("inputSchema") or {},
            )
            for tool in metadata
            if tool.get("name")  # Only include tools with valid names
        ]
    except Exception as exc:
        logger.exception("Error listing tools")
        raise RuntimeError(f"Error listing tools: {exc}") from exc
@server.call_tool()
async def handle_call_tool(name: str, arguments: Optional[Dict[str, Any]] = None) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Invoke a named MCP tool via the gateway's RPC endpoint.

    A JSON-RPC 2.0 request with ``method=name`` and ``params=arguments`` is
    POSTed to ``{BASE_URL}/rpc/``. The ``result`` member of the reply (or the
    whole reply when there is no ``result``) is returned as a single text
    content item.

    Args:
        name (str): The name of the tool to invoke.
        arguments (Optional[Dict[str, Any]]): The arguments to pass to the tool method.

    Returns:
        List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
            A list of content objects returned by the tool.

    Raises:
        RuntimeError: If the request times out, the HTTP request fails, or the
            gateway replies with a JSON-RPC error. Every failure, including
            the internal ``ValueError`` raised for a JSON-RPC error, is
            wrapped in ``RuntimeError`` with the original as ``__cause__``.
    """
    if arguments is None:
        arguments = {}

    logger.info(f"Calling tool {name} with args {arguments}")
    payload = {"jsonrpc": "2.0", "id": 2, "method": name, "params": arguments}
    headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} if AUTH_TOKEN else {}

    try:
        async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as client:
            resp = await client.post(f"{BASE_URL}/rpc/", json=payload, headers=headers)
            resp.raise_for_status()
            result = resp.json()

        if "error" in result:
            error = result["error"]
            # Tolerate an error member that is not an object (e.g. a plain
            # string) so its text is reported rather than an AttributeError.
            error_msg = error.get("message", "Unknown error") if isinstance(error, dict) else str(error)
            raise ValueError(f"Tool call failed: {error_msg}")

        tool_result = result.get("result", result)
        return [types.TextContent(type="text", text=str(tool_result))]

    except httpx.TimeoutException as exc:
        logger.error(f"Timeout calling tool {name}: {exc}")
        raise RuntimeError(f"Tool call timeout: {exc}") from exc
    except Exception as exc:
        logger.exception(f"Error calling tool {name}")
        raise RuntimeError(f"Error calling tool: {exc}") from exc
@server.list_resources()
async def handle_list_resources() -> List[types.Resource]:
    """
    List all available MCP resources exposed by the gateway.

    Fetches resource IDs from the configured catalogs and retrieves
    metadata to construct Resource instances. When the first catalog URL is
    the gateway base URL itself, every resource of the gateway is listed.
    Entries without a URI, or whose Resource cannot be constructed, are
    logged and skipped.

    Returns:
        List[types.Resource]: A list of Resource objects including URI, name, description, and MIME type.

    Raises:
        RuntimeError: If an error occurs during fetching or processing; the
            original exception is attached as ``__cause__``.
    """
    try:
        # ["0"] is the "no filtering" sentinel understood by resources_metadata().
        ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_resources_from_mcp_server(SERVER_CATALOG_URLS)
        meta = await resources_metadata(ids)
        resources = []
        for r in meta:
            uri = r.get("uri")
            if not uri:
                logger.warning(f"Resource missing URI, skipping: {r}")
                continue
            try:
                resources.append(
                    types.Resource(
                        uri=AnyUrl(uri),
                        name=r.get("name", ""),
                        description=r.get("description", ""),
                        mimeType=r.get("mimeType", "text/plain"),
                    )
                )
            except Exception as e:
                # Best effort: one bad entry must not hide the valid ones.
                logger.warning(f"Invalid resource URI {uri}: {e}")
        return resources
    except Exception as exc:
        logger.exception("Error listing resources")
        raise RuntimeError(f"Error listing resources: {exc}") from exc
@server.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str:
    """
    Read and return the content of a resource by its URI.

    Args:
        uri (AnyUrl): The URI of the resource to read.

    Returns:
        str: The body text of the fetched resource.

    Raises:
        ValueError: If the resource cannot be fetched; the original exception
            is attached as ``__cause__``.
    """
    try:
        # NOTE(review): fetch_url() attaches the gateway bearer token to
        # whatever URI the client supplies — confirm that URIs outside the
        # gateway host are never expected here.
        response = await fetch_url(str(uri))
        return response.text
    except Exception as exc:
        logger.exception(f"Error reading resource {uri}")
        raise ValueError(f"Failed to read resource at {uri}: {exc}") from exc
@server.list_prompts()
async def handle_list_prompts() -> List[types.Prompt]:
    """
    List all available MCP prompts exposed by the gateway.

    Retrieves prompt IDs from the catalogs and fetches metadata
    to create Prompt instances. When the first catalog URL is the gateway
    base URL itself, every prompt of the gateway is listed. Entries without
    a name are skipped.

    Returns:
        List[types.Prompt]: A list of Prompt objects including name, description, and arguments.

    Raises:
        RuntimeError: If an error occurs during fetching or processing; the
            original exception is attached as ``__cause__``.
    """
    try:
        # ["0"] is the "no filtering" sentinel understood by prompts_metadata().
        ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_prompts_from_mcp_server(SERVER_CATALOG_URLS)
        meta = await prompts_metadata(ids)
        return [
            types.Prompt(
                name=str(p["name"]),
                description=p.get("description", ""),
                arguments=p.get("arguments", []),
            )
            for p in meta
            if p.get("name")  # Only include prompts with valid names
        ]
    except Exception as exc:
        logger.exception("Error listing prompts")
        raise RuntimeError(f"Error listing prompts: {exc}") from exc
@server.get_prompt()
async def handle_get_prompt(name: str, arguments: Optional[Dict[str, str]] = None) -> types.GetPromptResult:
    """
    Retrieve and format a single prompt template with provided arguments.

    The prompt is fetched from ``{BASE_URL}/prompts/{name}``; its ``template``
    is filled in with ``str.format(**arguments)`` and returned as one
    user-role text message.

    Args:
        name (str): The unique name of the prompt to fetch.
        arguments (Optional[Dict[str, str]]): A mapping of placeholder names to replacement values.

    Returns:
        types.GetPromptResult: Contains description and list of formatted PromptMessage instances.

    Raises:
        ValueError: If fetching or formatting fails; the original exception
            is attached as ``__cause__``.

    Example (illustrative; requires a running event loop and gateway):
        await handle_get_prompt("greet", {"username": "Alice"})
    """
    try:
        url = f"{BASE_URL}/prompts/{name}"
        response = await fetch_url(url)
        prompt_data = response.json()

        template = prompt_data.get("template", "")
        try:
            formatted = template.format(**(arguments or {}))
        except KeyError as exc:
            # The template references a placeholder the caller did not supply.
            raise ValueError(f"Missing placeholder in arguments: {exc}") from exc
        except Exception as exc:
            raise ValueError(f"Error formatting prompt: {exc}") from exc

        return types.GetPromptResult(
            description=prompt_data.get("description", ""),
            messages=[
                types.PromptMessage(
                    role="user",
                    content=types.TextContent(type="text", text=formatted),
                )
            ],
        )
    except ValueError:
        # Already carries a specific message; do not re-wrap or log twice.
        raise
    except Exception as exc:
        logger.exception(f"Error getting prompt {name}")
        raise ValueError(f"Failed to fetch prompt '{name}': {exc}") from exc
async def main() -> None:
    """
    Main entry point to start the MCP stdio server.

    Initializes the server over standard IO, registers capabilities,
    and begins listening for JSON-RPC messages.

    This function should only be called in a script context.

    Raises:
        RuntimeError: If the server fails to start or stops with an error;
            the original exception is attached as ``__cause__``.

    Example:
        if __name__ == "__main__":
            asyncio.run(main())
    """
    try:
        async with mcp.server.stdio.stdio_server() as (reader, writer):
            await server.run(
                reader,
                writer,
                InitializationOptions(
                    server_name="mcpgateway-wrapper",
                    server_version=__version__,
                    capabilities=server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}),
                ),
            )
    except Exception as exc:
        logger.exception("Server failed to start")
        raise RuntimeError(f"Server startup failed: {exc}") from exc
# Script entry point: run the stdio server until it stops or is interrupted.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to stop the wrapper; not treated as a failure.
        logger.info("Server interrupted by user")
    except Exception:
        logger.exception("Server failed")
        sys.exit(1)
    finally:
        # Runs on every path, including the sys.exit(1) above.
        logger.info("Wrapper shutdown complete")