Coverage for mcpgateway/transports/websocket_transport.py: 74%

71 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 12:53 +0100

1# -*- coding: utf-8 -*- 

2"""WebSocket Transport Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements WebSocket transport for MCP, providing 

9full-duplex communication between client and server. 

10""" 

11 

12import asyncio 

13import logging 

14from typing import Any, AsyncGenerator, Dict, Optional 

15 

16from fastapi import WebSocket, WebSocketDisconnect 

17 

18from mcpgateway.config import settings 

19from mcpgateway.transports.base import Transport 

20 

21logger = logging.getLogger(__name__) 

22 

23 

class WebSocketTransport(Transport):
    """Transport implementation using WebSocket.

    Wraps a FastAPI ``WebSocket`` and exchanges JSON messages over it. When
    ``settings.websocket_ping_interval`` is positive, a background task sends
    ``b"ping"`` frames at that interval and expects ``b"pong"`` in reply.
    """

    def __init__(self, websocket: WebSocket):
        """Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self._websocket = websocket
        self._connected = False
        self._ping_task: Optional[asyncio.Task] = None

    async def connect(self) -> None:
        """Accept the WebSocket handshake and start the keepalive task if enabled."""
        await self._websocket.accept()
        self._connected = True

        # Start ping task only when a positive interval is configured.
        if settings.websocket_ping_interval > 0:
            self._ping_task = asyncio.create_task(self._ping_loop())

        logger.info("WebSocket transport connected")

    async def disconnect(self) -> None:
        """Clean up WebSocket connection.

        Safe to call more than once, and from inside the ping task itself
        (the ping loop calls this from its ``finally``). Errors raised while
        closing the socket are logged rather than propagated, so the transport
        always ends up marked as disconnected.
        """
        # Detach the task first so a re-entrant call (e.g. from the ping
        # loop's own ``finally``) does not try to cancel it a second time.
        ping_task, self._ping_task = self._ping_task, None
        # A task must never cancel and await itself; when the ping loop is the
        # caller it is already on its way out, so just let it finish.
        if ping_task is not None and ping_task is not asyncio.current_task():
            ping_task.cancel()
            try:
                await ping_task
            except asyncio.CancelledError:
                pass

        if self._connected:
            # Clear the flag before awaiting close() so a concurrent
            # disconnect() cannot close the socket twice.
            self._connected = False
            try:
                await self._websocket.close()
            except Exception as e:
                logger.warning("Error closing WebSocket: %s", e)
            logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over WebSocket.

        Args:
            message: Message to send (serialized with ``send_json``)

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to send json to websocket (re-raised after logging)
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._websocket.send_json(message)
        except Exception as e:
            logger.error("Failed to send message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from WebSocket.

        Yields decoded JSON messages until the client disconnects or an error
        occurs; either way the transport is disconnected before the generator
        finishes.

        Yields:
            Received messages

        Raises:
            RuntimeError: If transport is not connected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            while True:
                message = await self._websocket.receive_json()
                yield message

        except WebSocketDisconnect:
            logger.info("WebSocket client disconnected")
            # The peer already closed the socket; skip close() in disconnect().
            self._connected = False
        except Exception as e:
            # Leave _connected set so disconnect() below really closes the
            # socket; clearing it here skipped close() and leaked the connection.
            logger.error("Error receiving message: %s", e)
        finally:
            await self.disconnect()

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    async def _ping_loop(self) -> None:
        """Send periodic ping frames to keep the connection alive.

        Each iteration sleeps ``settings.websocket_ping_interval`` seconds,
        sends ``b"ping"`` and waits up to half the interval for a reply. The
        loop ends when the transport disconnects, a reply times out, or an
        error occurs; in every case it finishes by calling disconnect().
        """
        try:
            while self._connected:
                await asyncio.sleep(settings.websocket_ping_interval)
                # The transport may have been closed while we were sleeping.
                if not self._connected:
                    break
                await self._websocket.send_bytes(b"ping")
                # NOTE(review): receive_bytes() here reads from the same socket
                # as receive_json() in receive_message(), so either reader may
                # consume the other's frame; confirm the intended pong protocol.
                try:
                    resp = await asyncio.wait_for(
                        self._websocket.receive_bytes(),
                        timeout=settings.websocket_ping_interval / 2,
                    )
                    if resp != b"pong":
                        logger.warning("Invalid ping response")
                except asyncio.TimeoutError:
                    logger.warning("Ping timeout")
                    break
        except Exception as e:
            logger.error("Ping loop error: %s", e)
        finally:
            await self.disconnect()

    async def send_ping(self) -> None:
        """Send a manual ping message if the transport is connected."""
        if self._connected:
            await self._websocket.send_bytes(b"ping")