Coverage for mcpgateway/wrapper.py: 86%

217 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 13:28 +0100

1# -*- coding: utf-8 -*- 

2""" 

3MCP Gateway Wrapper server. 

4 

5Copyright 2025 

6SPDX-License-Identifier: Apache-2.0 

7Authors: Keval Mahajan, Mihai Criveti, Madhav Kandukuri 

8 

9This module implements a wrapper bridge that facilitates 

10interaction between the MCP client and the MCP gateway. 

11It provides several functionalities, including listing tools, 

12invoking tools, managing resources, retrieving prompts, 

13and handling tool calls via the MCP gateway. 

14 

15A **stdio** bridge that exposes a remote MCP Gateway 

16(HTTP-/JSON-RPC APIs) as a local MCP server. All JSON-RPC 

17traffic is written to **stdout**; every log or trace message 

18is emitted on **stderr** so that protocol messages and 

19diagnostics never mix. 

20 

21Environment variables: 

22- MCP_SERVER_CATALOG_URLS: Comma-separated list of gateway catalog URLs (required) 

23- MCP_AUTH_TOKEN: Bearer token for the gateway (optional) 

24- MCP_TOOL_CALL_TIMEOUT: Seconds to wait for a gateway RPC call (default 90) 

25- MCP_WRAPPER_LOG_LEVEL: Python log level name or OFF/NONE to disable logging (default INFO) 

26 

27Example: 

28 $ export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin --exp 10080 --secret my-test-key) 

29 $ export MCP_AUTH_TOKEN=${MCPGATEWAY_BEARER_TOKEN} 

30 $ export MCP_SERVER_CATALOG_URLS='http://localhost:4444/servers/1' 

31 $ export MCP_TOOL_CALL_TIMEOUT=120 

32 $ export MCP_WRAPPER_LOG_LEVEL=DEBUG # OFF to disable logging 

33 $ python3 -m mcpgateway.wrapper 

34""" 

35 

36import asyncio 

37import logging 

38import os 

39import sys 

40from typing import Any, Dict, List, Optional, Union 

41from urllib.parse import urlparse 

42 

43import httpx 

44import mcp.server.stdio 

45from mcp import types 

46from mcp.server import NotificationOptions, Server 

47from mcp.server.models import InitializationOptions 

48from pydantic import AnyUrl 

49 

50from mcpgateway import __version__ 

51 

52# ----------------------------------------------------------------------------- 

53# Configuration 

54# ----------------------------------------------------------------------------- 

55ENV_SERVER_CATALOGS = "MCP_SERVER_CATALOG_URLS" 

56ENV_AUTH_TOKEN = "MCP_AUTH_TOKEN" # nosec B105 – this is an *environment variable name*, not a secret 

57ENV_TIMEOUT = "MCP_TOOL_CALL_TIMEOUT" 

58ENV_LOG_LEVEL = "MCP_WRAPPER_LOG_LEVEL" 

59 

60RAW_CATALOGS: str = os.getenv(ENV_SERVER_CATALOGS, "") 

61SERVER_CATALOG_URLS: List[str] = [u.strip() for u in RAW_CATALOGS.split(",") if u.strip()] 

62 

63AUTH_TOKEN: str = os.getenv(ENV_AUTH_TOKEN, "") 

64TOOL_CALL_TIMEOUT: int = int(os.getenv(ENV_TIMEOUT, "90")) 

65 

66# Validate required configuration 

67if not SERVER_CATALOG_URLS: 67 ↛ 68line 67 didn't jump to line 68 because the condition on line 67 was never true

68 print(f"Error: {ENV_SERVER_CATALOGS} environment variable is required", file=sys.stderr) 

69 sys.exit(1) 

70 

71 

72# ----------------------------------------------------------------------------- 

73# Base URL Extraction 

74# ----------------------------------------------------------------------------- 

75def _extract_base_url(url: str) -> str: 

76 """Return the gateway-level base URL. 

77 

78 The function keeps any application root path (`APP_ROOT_PATH`) that the 

79 remote gateway is mounted under (for example `/gateway`) while removing 

80 the `/servers/<id>` suffix that appears in catalog endpoints. It also 

81 discards any query string or fragment. 

82 

83 Args: 

84 url (str): Full catalog URL, e.g. 

85 `https://host.com/gateway/servers/1`. 

86 

87 Returns: 

88 str: Clean base URL suitable for building `/tools/`, `/prompts/`, 

89 or `/resources/` endpoints—for example 

90 `https://host.com/gateway`. 

91 

92 Raises: 

93 ValueError: If *url* lacks a scheme or network location. 

94 

95 Examples: 

96 >>> _extract_base_url("https://host.com/servers/2") 

97 'https://host.com' 

98 >>> _extract_base_url("https://host.com/gateway/servers/2") 

99 'https://host.com/gateway' 

100 >>> _extract_base_url("https://host.com/gateway/servers") 

101 'https://host.com/gateway' 

102 >>> _extract_base_url("https://host.com/gateway") 

103 'https://host.com/gateway' 

104 

105 Note: 

106 If the target server was started with `APP_ROOT_PATH=/gateway`, the 

107 resulting catalog URLs include that prefix. This helper preserves the 

108 prefix so the wrapper's follow-up calls remain correctly scoped. 

109 """ 

110 parsed = urlparse(url) 

111 if not parsed.scheme or not parsed.netloc: 

112 raise ValueError(f"Invalid URL provided: {url}") 

113 

114 path = parsed.path or "" 

115 if "/servers/" in path: 

116 path = path.split("/servers")[0] # ".../servers/123" -> "..." 

117 elif path.endswith("/servers"): 

118 path = path[: -len("/servers")] # ".../servers" -> "..." 

119 # otherwise keep the existing path (supports APP_ROOT_PATH) 

120 

121 return f"{parsed.scheme}://{parsed.netloc}{path}" 

122 

123 

124BASE_URL: str = _extract_base_url(SERVER_CATALOG_URLS[0]) if SERVER_CATALOG_URLS else "" 

125 

126# ----------------------------------------------------------------------------- 

127# Logging Setup 

128# ----------------------------------------------------------------------------- 

129_log_level = os.getenv(ENV_LOG_LEVEL, "INFO").upper() 

130if _log_level in {"OFF", "NONE", "DISABLE", "FALSE", "0"}: 130 ↛ 131line 130 didn't jump to line 131 because the condition on line 130 was never true

131 logging.disable(logging.CRITICAL) 

132else: 

133 logging.basicConfig( 

134 level=getattr(logging, _log_level, logging.INFO), 

135 format="%(asctime)s %(levelname)-8s %(name)s: %(message)s", 

136 stream=sys.stderr, 

137 ) 

138 

139logger = logging.getLogger("mcpgateway.wrapper") 

140logger.info(f"Starting MCP wrapper {__version__}: base_url={BASE_URL}, timeout={TOOL_CALL_TIMEOUT}") 

141 

142 

143# ----------------------------------------------------------------------------- 

144# HTTP Helpers 

145# ----------------------------------------------------------------------------- 

146async def fetch_url(url: str) -> httpx.Response: 

147 """ 

148 Perform an asynchronous HTTP GET request and return the response. 

149 

150 Args: 

151 url: The target URL to fetch. 

152 

153 Returns: 

154 The successful ``httpx.Response`` object. 

155 

156 Raises: 

157 httpx.RequestError: If a network problem occurs while making the request. 

158 httpx.HTTPStatusError: If the server returns a 4xx or 5xx response. 

159 """ 

160 headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} if AUTH_TOKEN else {} 

161 async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as client: 

162 try: 

163 response = await client.get(url, headers=headers) 

164 response.raise_for_status() 

165 return response 

166 except httpx.RequestError as err: 

167 logger.error(f"Network error while fetching {url}: {err}") 

168 raise 

169 except httpx.HTTPStatusError as err: 

170 logger.error(f"HTTP {err.response.status_code} returned for {url}: {err}") 

171 raise 

172 

173 

174# ----------------------------------------------------------------------------- 

175# Metadata Helpers 

176# ----------------------------------------------------------------------------- 

177async def get_tools_from_mcp_server(catalog_urls: List[str]) -> List[str]: 

178 """ 

179 Retrieve associated tool IDs from the MCP gateway server catalogs. 

180 

181 Args: 

182 catalog_urls (List[str]): List of catalog endpoint URLs. 

183 

184 Returns: 

185 List[str]: A list of tool ID strings extracted from the server catalog. 

186 """ 

187 server_ids = [url.split("/")[-1] for url in catalog_urls] 

188 url = f"{BASE_URL}/servers/" 

189 response = await fetch_url(url) 

190 catalog = response.json() 

191 tool_ids: List[str] = [] 

192 for entry in catalog: 

193 if str(entry.get("id")) in server_ids: 

194 tool_ids.extend(entry.get("associatedTools", [])) 

195 return tool_ids 

196 

197 

198async def tools_metadata(tool_ids: List[str]) -> List[Dict[str, Any]]: 

199 """ 

200 Fetch metadata for a list of MCP tools by their IDs. 

201 

202 Args: 

203 tool_ids (List[str]): List of tool ID strings. 

204 

205 Returns: 

206 List[Dict[str, Any]]: A list of metadata dictionaries for each tool. 

207 """ 

208 if not tool_ids: 208 ↛ 209line 208 didn't jump to line 209 because the condition on line 208 was never true

209 return [] 

210 url = f"{BASE_URL}/tools/" 

211 response = await fetch_url(url) 

212 data: List[Dict[str, Any]] = response.json() 

213 if tool_ids == ["0"]: 

214 return data 

215 

216 return [tool for tool in data if tool["id"] in tool_ids] 

217 

218 

219async def get_prompts_from_mcp_server(catalog_urls: List[str]) -> List[str]: 

220 """ 

221 Retrieve associated prompt IDs from the MCP gateway server catalogs. 

222 

223 Args: 

224 catalog_urls (List[str]): List of catalog endpoint URLs. 

225 

226 Returns: 

227 List[str]: A list of prompt ID strings. 

228 """ 

229 server_ids = [url.split("/")[-1] for url in catalog_urls] 

230 url = f"{BASE_URL}/servers/" 

231 response = await fetch_url(url) 

232 catalog = response.json() 

233 prompt_ids: List[str] = [] 

234 for entry in catalog: 

235 if str(entry.get("id")) in server_ids: 

236 prompt_ids.extend(entry.get("associatedPrompts", [])) 

237 return prompt_ids 

238 

239 

240async def prompts_metadata(prompt_ids: List[str]) -> List[Dict[str, Any]]: 

241 """ 

242 Fetch metadata for a list of MCP prompts by their IDs. 

243 

244 Args: 

245 prompt_ids (List[str]): List of prompt ID strings. 

246 

247 Returns: 

248 List[Dict[str, Any]]: A list of metadata dictionaries for each prompt. 

249 """ 

250 if not prompt_ids: 250 ↛ 251line 250 didn't jump to line 251 because the condition on line 250 was never true

251 return [] 

252 url = f"{BASE_URL}/prompts/" 

253 response = await fetch_url(url) 

254 data: List[Dict[str, Any]] = response.json() 

255 if prompt_ids == ["0"]: 

256 return data 

257 return [pr for pr in data if str(pr.get("id")) in prompt_ids] 

258 

259 

260async def get_resources_from_mcp_server(catalog_urls: List[str]) -> List[str]: 

261 """ 

262 Retrieve associated resource IDs from the MCP gateway server catalogs. 

263 

264 Args: 

265 catalog_urls (List[str]): List of catalog endpoint URLs. 

266 

267 Returns: 

268 List[str]: A list of resource ID strings. 

269 """ 

270 server_ids = [url.split("/")[-1] for url in catalog_urls] 

271 url = f"{BASE_URL}/servers/" 

272 response = await fetch_url(url) 

273 catalog = response.json() 

274 resource_ids: List[str] = [] 

275 for entry in catalog: 

276 if str(entry.get("id")) in server_ids: 

277 resource_ids.extend(entry.get("associatedResources", [])) 

278 return resource_ids 

279 

280 

281async def resources_metadata(resource_ids: List[str]) -> List[Dict[str, Any]]: 

282 """ 

283 Fetch metadata for a list of MCP resources by their IDs. 

284 

285 Args: 

286 resource_ids (List[str]): List of resource ID strings. 

287 

288 Returns: 

289 List[Dict[str, Any]]: A list of metadata dictionaries for each resource. 

290 """ 

291 if not resource_ids: 291 ↛ 292line 291 didn't jump to line 292 because the condition on line 291 was never true

292 return [] 

293 url = f"{BASE_URL}/resources/" 

294 response = await fetch_url(url) 

295 data: List[Dict[str, Any]] = response.json() 

296 if resource_ids == ["0"]: 

297 return data 

298 return [res for res in data if str(res.get("id")) in resource_ids] 

299 

300 

301# ----------------------------------------------------------------------------- 

302# Server Handlers 

303# ----------------------------------------------------------------------------- 

304server: Server = Server("mcpgateway-wrapper") 

305 

306 

307@server.list_tools() 

308async def handle_list_tools() -> List[types.Tool]: 

309 """ 

310 List all available MCP tools exposed by the gateway. 

311 

312 Queries the configured server catalogs to retrieve tool IDs and then 

313 fetches metadata for each tool to construct a list of Tool objects. 

314 

315 Returns: 

316 List[types.Tool]: A list of Tool instances including name, description, and input schema. 

317 

318 Raises: 

319 RuntimeError: If an error occurs during fetching or processing. 

320 """ 

321 try: 

322 tool_ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_tools_from_mcp_server(SERVER_CATALOG_URLS) 

323 metadata = await tools_metadata(tool_ids) 

324 tools = [] 

325 for tool in metadata: 

326 tool_name = tool.get("name") 

327 if tool_name: # Only include tools with valid names 327 ↛ 325line 327 didn't jump to line 325 because the condition on line 327 was always true

328 tools.append( 

329 types.Tool( 

330 name=str(tool_name), 

331 description=tool.get("description", ""), 

332 inputSchema=tool.get("inputSchema", {}), 

333 ) 

334 ) 

335 return tools 

336 except Exception as exc: 

337 logger.exception("Error listing tools") 

338 raise RuntimeError(f"Error listing tools: {exc}") 

339 

340 

341@server.call_tool() 

342async def handle_call_tool(name: str, arguments: Optional[Dict[str, Any]] = None) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 

343 """ 

344 Invoke a named MCP tool via the gateway's RPC endpoint. 

345 

346 Args: 

347 name (str): The name of the tool to invoke. 

348 arguments (Optional[Dict[str, Any]]): The arguments to pass to the tool method. 

349 

350 Returns: 

351 List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 

352 A list of content objects returned by the tool. 

353 

354 Raises: 

355 ValueError: If tool call fails. 

356 RuntimeError: If the HTTP request fails or returns an error. 

357 """ 

358 if arguments is None: 358 ↛ 359line 358 didn't jump to line 359 because the condition on line 358 was never true

359 arguments = {} 

360 

361 logger.info(f"Calling tool {name} with args {arguments}") 

362 payload = {"jsonrpc": "2.0", "id": 2, "method": name, "params": arguments} 

363 headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} if AUTH_TOKEN else {} 

364 

365 try: 

366 async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as client: 

367 resp = await client.post(f"{BASE_URL}/rpc/", json=payload, headers=headers) 

368 resp.raise_for_status() 

369 result = resp.json() 

370 

371 if "error" in result: 

372 error_msg = result["error"].get("message", "Unknown error") 

373 raise ValueError(f"Tool call failed: {error_msg}") 

374 

375 tool_result = result.get("result", result) 

376 return [types.TextContent(type="text", text=str(tool_result))] 

377 

378 except httpx.TimeoutException as exc: 

379 logger.error(f"Timeout calling tool {name}: {exc}") 

380 raise RuntimeError(f"Tool call timeout: {exc}") 

381 except Exception as exc: 

382 logger.exception(f"Error calling tool {name}") 

383 raise RuntimeError(f"Error calling tool: {exc}") 

384 

385 

386@server.list_resources() 

387async def handle_list_resources() -> List[types.Resource]: 

388 """ 

389 List all available MCP resources exposed by the gateway. 

390 

391 Fetches resource IDs from the configured catalogs and retrieves 

392 metadata to construct Resource instances. 

393 

394 Returns: 

395 List[types.Resource]: A list of Resource objects including URI, name, description, and MIME type. 

396 

397 Raises: 

398 RuntimeError: If an error occurs during fetching or processing. 

399 """ 

400 try: 

401 ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_resources_from_mcp_server(SERVER_CATALOG_URLS) 

402 meta = await resources_metadata(ids) 

403 resources = [] 

404 for r in meta: 

405 uri = r.get("uri") 

406 if not uri: 406 ↛ 407line 406 didn't jump to line 407 because the condition on line 406 was never true

407 logger.warning(f"Resource missing URI, skipping: {r}") 

408 continue 

409 try: 

410 resources.append( 

411 types.Resource( 

412 uri=AnyUrl(uri), 

413 name=r.get("name", ""), 

414 description=r.get("description", ""), 

415 mimeType=r.get("mimeType", "text/plain"), 

416 ) 

417 ) 

418 except Exception as e: 

419 logger.warning(f"Invalid resource URI {uri}: {e}") 

420 continue 

421 return resources 

422 except Exception as exc: 

423 logger.exception("Error listing resources") 

424 raise RuntimeError(f"Error listing resources: {exc}") 

425 

426 

427@server.read_resource() 

428async def handle_read_resource(uri: AnyUrl) -> str: 

429 """ 

430 Read and return the content of a resource by its URI. 

431 

432 Args: 

433 uri (AnyUrl): The URI of the resource to read. 

434 

435 Returns: 

436 str: The body text of the fetched resource. 

437 

438 Raises: 

439 ValueError: If the resource cannot be fetched. 

440 """ 

441 try: 

442 response = await fetch_url(str(uri)) 

443 return response.text 

444 except Exception as exc: 

445 logger.exception(f"Error reading resource {uri}") 

446 raise ValueError(f"Failed to read resource at {uri}: {exc}") 

447 

448 

449@server.list_prompts() 

450async def handle_list_prompts() -> List[types.Prompt]: 

451 """ 

452 List all available MCP prompts exposed by the gateway. 

453 

454 Retrieves prompt IDs from the catalogs and fetches metadata 

455 to create Prompt instances. 

456 

457 Returns: 

458 List[types.Prompt]: A list of Prompt objects including name, description, and arguments. 

459 

460 Raises: 

461 RuntimeError: If an error occurs during fetching or processing. 

462 """ 

463 try: 

464 ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_prompts_from_mcp_server(SERVER_CATALOG_URLS) 

465 meta = await prompts_metadata(ids) 

466 prompts = [] 

467 for p in meta: 

468 prompt_name = p.get("name") 

469 if prompt_name: # Only include prompts with valid names 469 ↛ 467line 469 didn't jump to line 467 because the condition on line 469 was always true

470 prompts.append( 

471 types.Prompt( 

472 name=str(prompt_name), 

473 description=p.get("description", ""), 

474 arguments=p.get("arguments", []), 

475 ) 

476 ) 

477 return prompts 

478 except Exception as exc: 

479 logger.exception("Error listing prompts") 

480 raise RuntimeError(f"Error listing prompts: {exc}") 

481 

482 

483@server.get_prompt() 

484async def handle_get_prompt(name: str, arguments: Optional[Dict[str, str]] = None) -> types.GetPromptResult: 

485 """ 

486 Retrieve and format a single prompt template with provided arguments. 

487 

488 Args: 

489 name (str): The unique name of the prompt to fetch. 

490 arguments (Optional[Dict[str, str]]): A mapping of placeholder names to replacement values. 

491 

492 Returns: 

493 types.GetPromptResult: Contains description and list of formatted PromptMessage instances. 

494 

495 Raises: 

496 ValueError: If fetching or formatting fails. 

497 

498 Example: 

499 >>> await handle_get_prompt("greet", {"username": "Alice"}) 

500 """ 

501 try: 

502 url = f"{BASE_URL}/prompts/{name}" 

503 response = await fetch_url(url) 

504 prompt_data = response.json() 

505 

506 template = prompt_data.get("template", "") 

507 try: 

508 formatted = template.format(**(arguments or {})) 

509 except KeyError as exc: 

510 raise ValueError(f"Missing placeholder in arguments: {exc}") 

511 except Exception as exc: 

512 raise ValueError(f"Error formatting prompt: {exc}") 

513 

514 return types.GetPromptResult( 

515 description=prompt_data.get("description", ""), 

516 messages=[ 

517 types.PromptMessage( 

518 role="user", 

519 content=types.TextContent(type="text", text=formatted), 

520 ) 

521 ], 

522 ) 

523 except ValueError: 

524 raise 

525 except Exception as exc: 

526 logger.exception(f"Error getting prompt {name}") 

527 raise ValueError(f"Failed to fetch prompt '{name}': {exc}") 

528 

529 

530async def main() -> None: 

531 """ 

532 Main entry point to start the MCP stdio server. 

533 

534 Initializes the server over standard IO, registers capabilities, 

535 and begins listening for JSON-RPC messages. 

536 

537 This function should only be called in a script context. 

538 

539 Raises: 

540 RuntimeError: If the server fails to start. 

541 

542 Example: 

543 if __name__ == "__main__": 

544 asyncio.run(main()) 

545 """ 

546 try: 

547 async with mcp.server.stdio.stdio_server() as (reader, writer): 

548 await server.run( 

549 reader, 

550 writer, 

551 InitializationOptions( 

552 server_name="mcpgateway-wrapper", 

553 server_version=__version__, 

554 capabilities=server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}), 

555 ), 

556 ) 

557 except Exception as exc: 

558 logger.exception("Server failed to start") 

559 raise RuntimeError(f"Server startup failed: {exc}") 

560 

561 

562if __name__ == "__main__": 

563 try: 

564 asyncio.run(main()) 

565 except KeyboardInterrupt: 

566 logger.info("Server interrupted by user") 

567 except Exception: 

568 logger.exception("Server failed") 

569 sys.exit(1) 

570 finally: 

571 logger.info("Wrapper shutdown complete")