Coverage for mcpgateway/cache/session_registry.py: 26%

416 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 15:47 +0100

1# -*- coding: utf-8 -*- 

2"""Session Registry with optional distributed state. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module provides a registry for SSE sessions with support for distributed deployment 

9using Redis or SQLAlchemy as optional backends for shared state between workers. 

10""" 

11 

12import asyncio 

13import json 

14import logging 

15import time 

16from typing import Any, Dict, Optional 

17 

18import httpx 

19from fastapi import HTTPException, status 

20 

21from mcpgateway.config import settings 

22from mcpgateway.db import SessionMessageRecord, SessionRecord, get_db 

23from mcpgateway.services import PromptService, ResourceService, ToolService 

24from mcpgateway.transports import SSETransport 

25from mcpgateway.types import Implementation, InitializeResult, ServerCapabilities 

26 

logger = logging.getLogger(__name__)

# Service singletons shared by every registry; generate_response() uses them to
# answer tools/list, resources/list and prompts/list requests.
tool_service, resource_service, prompt_service = ToolService(), ResourceService(), PromptService()

# Optional backend dependencies. Each flag records whether the package imported,
# so SessionBackend can reject a backend whose package is missing.
try:
    from redis.asyncio import Redis
except ImportError:
    REDIS_AVAILABLE = False
else:
    REDIS_AVAILABLE = True

try:
    from sqlalchemy import func
except ImportError:
    SQLALCHEMY_AVAILABLE = False
else:
    SQLALCHEMY_AVAILABLE = True

46 

47 

class SessionBackend:
    """Validate the session backend choice and set up backend-specific state.

    Backends:
    - ``memory``: keeps the most recent broadcast message in ``self._session_message``.
    - ``none``: session tracking disabled.
    - ``redis``: creates an asyncio Redis client (``self._redis``) and a PubSub
      object (``self._pubsub``).
    - ``database``: only checks that SQLAlchemy is importable and a URL was given.
    """

    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Initialize the session backend.

        Args:
            backend: "memory", "redis", "database", or "none" (case-insensitive)
            redis_url: Redis connection URL (required for redis backend)
            database_url: Database connection URL (required for database backend)
            session_ttl: Session time-to-live in seconds
            message_ttl: Message time-to-live in seconds

        Raises:
            ValueError: If backend is invalid, its package is not installed,
                or the required URL is missing
        """
        self._backend = backend.lower()
        self._session_ttl = session_ttl
        self._message_ttl = message_ttl

        # Set up backend-specific components
        if self._backend == "memory":
            # Last message passed to broadcast(); consumed by respond().
            self._session_message = None

        elif self._backend == "none":
            # No session tracking - this is just a dummy registry
            logger.info("Session registry initialized with 'none' backend - session tracking disabled")

        elif self._backend == "redis":
            if not REDIS_AVAILABLE:
                raise ValueError("Redis backend requested but redis package not installed")
            if not redis_url:
                raise ValueError("Redis backend requires redis_url")

            self._redis = Redis.from_url(redis_url)
            # redis.asyncio's PubSub.subscribe() is a coroutine and cannot be awaited
            # from this synchronous constructor. The previous unawaited
            # subscribe("mcp_session_events") call never subscribed; it only produced
            # a "coroutine was never awaited" warning. Subscriptions are awaited where
            # they are needed (respond() subscribes to the session's channel).
            self._pubsub = self._redis.pubsub()

        elif self._backend == "database":
            if not SQLALCHEMY_AVAILABLE:
                raise ValueError("Database backend requested but SQLAlchemy not installed")
            if not database_url:
                raise ValueError("Database backend requires database_url")
            # database_url is only validated here; DB sessions are obtained via get_db().
        else:
            raise ValueError(f"Invalid backend: {backend}")

102 

103 

class SessionRegistry(SessionBackend):
    """Registry for SSE sessions with optional distributed state.

    Supports these backend modes:
    - memory: In-memory storage (default, no dependencies)
    - redis: Redis-backed shared storage
    - database: SQLAlchemy-backed shared storage
    - none: session tracking disabled; registry operations are no-ops

    In distributed mode (redis/database), session existence is tracked in the shared
    backend while transports themselves remain local to each worker process.
    """

    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Initialize session registry.

        Args:
            backend: "memory", "redis", "database", or "none"
            redis_url: Redis connection URL (required for redis backend)
            database_url: Database connection URL (required for database backend)
            session_ttl: Session time-to-live in seconds
            message_ttl: Message time-to-live in seconds

        Raises:
            ValueError: Propagated from SessionBackend for an invalid backend or missing URL
        """
        super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
        self._sessions: Dict[str, Any] = {}  # Local transport cache: session_id -> transport
        self._lock = asyncio.Lock()  # Taken by the async methods that read/write self._sessions
        self._cleanup_task: Optional[asyncio.Task] = None
        # Strong references to polling tasks started by respond(). asyncio keeps only
        # weak references to tasks, so an unreferenced task may be garbage collected
        # before it finishes.
        self._background_tasks: set = set()

    async def initialize(self) -> None:
        """Initialize the registry with async setup.

        Call this during application startup. Starts the periodic cleanup task for
        the database and memory backends; redis and none start nothing here.
        """
        logger.info(f"Initializing session registry with backend: {self._backend}")

        if self._backend == "database":
            self._cleanup_task = asyncio.create_task(self._db_cleanup_task())
            logger.info("Database cleanup task started")
        elif self._backend == "memory":
            self._cleanup_task = asyncio.create_task(self._memory_cleanup_task())
            logger.info("Memory cleanup task started")

    async def shutdown(self) -> None:
        """Shutdown the registry.

        Call this during application shutdown. Cancels the cleanup task and any
        respond() polling tasks, then closes Redis connections for the redis backend.
        """
        logger.info("Shutting down session registry")

        tasks = [task for task in (self._cleanup_task, *self._background_tasks) if task]
        for task in tasks:
            task.cancel()
        if tasks:
            # return_exceptions=True collects each task's CancelledError instead of raising it.
            await asyncio.gather(*tasks, return_exceptions=True)

        if self._backend == "redis":
            try:
                # redis.asyncio close() methods return awaitables; without awaiting
                # them the connections are never actually closed.
                await self._pubsub.close()
                await self._redis.close()
            except Exception as e:
                logger.error(f"Error closing Redis connection: {e}")

    async def add_session(self, session_id: str, transport: SSETransport) -> None:
        """Add a session to the registry.

        The transport is cached locally. For redis/database backends a session marker
        is also written to the shared backend; backend errors are logged, not raised.

        Args:
            session_id: Unique session identifier
            transport: Transport session
        """
        if self._backend == "none":
            return

        async with self._lock:
            self._sessions[session_id] = transport

        if self._backend == "redis":
            try:
                # Session marker expires after session_ttl seconds.
                await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1")
                # Publish event to notify other workers
                await self._redis.publish("mcp_session_events", json.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error adding session {session_id}: {e}")

        elif self._backend == "database":

            def _db_add() -> None:
                db_session = next(get_db())
                try:
                    db_session.add(SessionRecord(session_id=session_id))
                    db_session.commit()
                except Exception:
                    db_session.rollback()
                    raise
                finally:
                    db_session.close()

            try:
                # Blocking ORM work runs in a worker thread to keep the event loop free.
                await asyncio.to_thread(_db_add)
            except Exception as e:
                logger.error(f"Database error adding session {session_id}: {e}")

        logger.info(f"Added session: {session_id}")

    async def get_session(self, session_id: str) -> Any:
        """Get session by ID.

        Only transports held by this worker can be returned. For redis/database
        backends the shared backend is checked solely to log that the session
        exists on another worker.

        Args:
            session_id: Session identifier

        Returns:
            Transport object or None if not found locally
        """
        if self._backend == "none":
            return None

        # First check local cache
        async with self._lock:
            transport = self._sessions.get(session_id)
            if transport:
                logger.info(f"Session {session_id} exists in local cache")
                return transport

        if self._backend == "redis":
            try:
                if await self._redis.exists(f"mcp:session:{session_id}"):
                    logger.info(f"Session {session_id} exists in Redis but not in local cache")
            except Exception as e:
                logger.error(f"Redis error checking session {session_id}: {e}")

        elif self._backend == "database":

            def _db_check() -> bool:
                db_session = next(get_db())
                try:
                    return db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() is not None
                finally:
                    db_session.close()

            try:
                if await asyncio.to_thread(_db_check):
                    logger.info(f"Session {session_id} exists in database but not in local cache")
            except Exception as e:
                logger.error(f"Database error checking session {session_id}: {e}")

        # The transport is not held by this worker.
        return None

    async def remove_session(self, session_id: str) -> None:
        """Remove a session from the registry.

        Disconnects the local transport (if any) and deletes the shared-backend
        marker. Errors are logged, not raised.

        Args:
            session_id: Session identifier
        """
        if self._backend == "none":
            return

        async with self._lock:
            transport = self._sessions.pop(session_id, None)

        if transport:
            try:
                await transport.disconnect()
            except Exception as e:
                logger.error(f"Error disconnecting transport for session {session_id}: {e}")

        if self._backend == "redis":
            try:
                await self._redis.delete(f"mcp:session:{session_id}")
                # Notify other workers
                await self._redis.publish("mcp_session_events", json.dumps({"type": "remove", "session_id": session_id, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error removing session {session_id}: {e}")

        elif self._backend == "database":

            def _db_remove() -> None:
                db_session = next(get_db())
                try:
                    db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
                    db_session.commit()
                except Exception:
                    db_session.rollback()
                    raise
                finally:
                    db_session.close()

            try:
                await asyncio.to_thread(_db_remove)
            except Exception as e:
                logger.error(f"Database error removing session {session_id}: {e}")

        logger.info(f"Removed session: {session_id}")

    @staticmethod
    def _serialize_message(message: Any) -> str:
        """Serialize a broadcast payload to a JSON string.

        dicts and lists are JSON-encoded directly; anything else is converted with
        ``str()`` first, so it becomes a JSON string literal.
        """
        if isinstance(message, (dict, list)):
            return json.dumps(message)
        return json.dumps(str(message))

    async def broadcast(self, session_id: str, message: dict) -> None:
        """Publish a message addressed to ``session_id`` for respond() to pick up.

        - memory: stored in ``self._session_message``; only the latest message is kept.
        - redis: published on the channel named after ``session_id``.
        - database: inserted as a SessionMessageRecord row.

        Args:
            session_id: Session ID
            message: Message to broadcast
        """
        if self._backend == "none":
            return

        if self._backend == "memory":
            self._session_message = {"session_id": session_id, "message": self._serialize_message(message)}

        elif self._backend == "redis":
            try:
                msg_json = self._serialize_message(message)
                await self._redis.publish(session_id, json.dumps({"type": "message", "message": msg_json, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error during broadcast: {e}")

        elif self._backend == "database":
            try:
                msg_json = self._serialize_message(message)

                def _db_add() -> None:
                    db_session = next(get_db())
                    try:
                        db_session.add(SessionMessageRecord(session_id=session_id, message=msg_json))
                        db_session.commit()
                    except Exception:
                        db_session.rollback()
                        raise
                    finally:
                        db_session.close()

                await asyncio.to_thread(_db_add)
            except Exception as e:
                logger.error(f"Database error during broadcast: {e}")

    def get_session_sync(self, session_id: str) -> Any:
        """Get session synchronously (not checking shared backend).

        This is a non-blocking method for handlers that need quick access.
        It only checks the local cache, not the shared backend.

        Args:
            session_id: Session identifier

        Returns:
            Transport object or None if not found
        """
        if self._backend == "none":
            return None

        return self._sessions.get(session_id)

    async def respond(
        self,
        server_id: Optional[str],
        user: Dict[str, Any],
        session_id: str,
        base_url: str,
    ) -> None:
        """Answer broadcast messages for ``session_id`` if its transport is held locally.

        - memory: handles the stored broadcast message once, if it targets ``session_id``.
        - redis: listens on the ``session_id`` channel until cancelled.
        - database: starts a background task that polls for this session's messages
          and stops once the session record is gone.

        Args:
            server_id: Server ID
            user: User information
            session_id: Session ID
            base_url: Base URL for the FastAPI request
        """
        if self._backend == "none":
            return

        if self._backend == "memory":
            pending = self._session_message
            # Nothing broadcast yet, or the latest broadcast targets another session.
            if not pending or pending.get("session_id") != session_id:
                return
            transport = self.get_session_sync(session_id)
            if transport:
                message = json.loads(pending.get("message"))
                await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)

        elif self._backend == "redis":
            # A dedicated PubSub per responder. With one shared PubSub, every
            # concurrent listener would receive every subscribed session's messages
            # and deliver them to its own transport.
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(session_id)
            try:
                async for msg in pubsub.listen():
                    if msg["type"] != "message":
                        continue
                    data = json.loads(msg["data"])
                    message = data.get("message", {})
                    if isinstance(message, str):
                        message = json.loads(message)
                    transport = self.get_session_sync(session_id)
                    if transport:
                        await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
            except asyncio.CancelledError:
                logger.info(f"PubSub listener for session {session_id} cancelled")
                # Re-raise so the caller/task observes the cancellation.
                raise
            finally:
                try:
                    await pubsub.unsubscribe(session_id)
                finally:
                    await pubsub.close()
                logger.info(f"Cleaned up pubsub for session {session_id}")

        elif self._backend == "database":

            def _db_read(sid: str) -> Any:
                """Return one pending SessionMessageRecord for ``sid`` (or None)."""
                db_session = next(get_db())
                try:
                    return db_session.query(SessionMessageRecord).filter_by(session_id=sid).first()
                finally:
                    db_session.close()

            def _db_session_exists(sid: str) -> bool:
                """Return True while a SessionRecord for ``sid`` exists."""
                db_session = next(get_db())
                try:
                    return db_session.query(SessionRecord).filter_by(session_id=sid).first() is not None
                finally:
                    db_session.close()

            def _db_remove(sid: str, message_text: Any) -> None:
                """Delete this session's message rows whose text equals ``message_text``."""
                db_session = next(get_db())
                try:
                    db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == sid).filter(SessionMessageRecord.message == message_text).delete()
                    db_session.commit()
                    logger.info("Removed message from mcp_messages table")
                except Exception:
                    db_session.rollback()
                    raise
                finally:
                    db_session.close()

            async def message_check_loop(sid: str) -> None:
                """Poll for messages every 0.1s until the session record disappears."""
                while True:
                    try:
                        record = await asyncio.to_thread(_db_read, sid)
                        if record:
                            try:
                                message = json.loads(record.message)
                                transport = self.get_session_sync(sid)
                                if transport:
                                    logger.info("Ready to respond")
                                    await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
                            finally:
                                # Consume the message even if responding failed, so a bad
                                # message is not retried forever.
                                await asyncio.to_thread(_db_remove, sid, record.message)

                        if not await asyncio.to_thread(_db_session_exists, sid):
                            break
                    except Exception as e:
                        # Keep polling; previously any error silently killed this task.
                        logger.error(f"Error processing messages for session {sid}: {e}")

                    await asyncio.sleep(0.1)

            task = asyncio.create_task(message_check_loop(session_id))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    async def _refresh_redis_sessions(self) -> None:
        """Refresh TTLs for Redis sessions and clean up disconnected sessions.

        NOTE(review): initialize() does not schedule this method, so Redis session
        markers are not refreshed unless something else calls it.
        """
        try:
            async with self._lock:
                local_transports = self._sessions.copy()

            for session_id, transport in local_transports.items():
                try:
                    if await transport.is_connected():
                        # Refresh TTL in Redis
                        await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl)
                    else:
                        # Remove disconnected session
                        await self.remove_session(session_id)
                except Exception as e:
                    logger.error(f"Error refreshing session {session_id}: {e}")

        except Exception as e:
            logger.error(f"Error in Redis session refresh: {e}")

    async def _db_cleanup_task(self) -> None:
        """Periodically clean up expired database sessions.

        Every 5 minutes:
        - delete SessionRecords whose ``last_accessed`` is older than ``session_ttl``;
        - drop local sessions whose transport disconnected or whose record is gone;
        - otherwise bump the record's ``last_accessed``.
        """
        from datetime import datetime, timedelta, timezone  # pylint: disable=import-outside-toplevel

        logger.info("Starting database cleanup task")

        def _db_cleanup() -> int:
            db_session = next(get_db())
            try:
                # Delete sessions that haven't been accessed for TTL seconds. The cutoff
                # is computed in Python so the query is portable. The previous
                # func.make_interval(seconds=...) was PostgreSQL-only, and SQLAlchemy
                # does not pass keyword arguments through to SQL functions.
                expiry_time = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl)
                deleted = db_session.query(SessionRecord).filter(SessionRecord.last_accessed < expiry_time).delete()
                db_session.commit()
                return deleted
            except Exception:
                db_session.rollback()
                raise
            finally:
                db_session.close()

        def _refresh_session(sid: str) -> bool:
            """Touch the record's last_accessed; return False if it no longer exists."""
            db_session = next(get_db())
            try:
                session = db_session.query(SessionRecord).filter(SessionRecord.session_id == sid).first()
                if not session:
                    return False
                session.last_accessed = func.now()  # pylint: disable=not-callable
                db_session.commit()
                return True
            except Exception:
                db_session.rollback()
                raise
            finally:
                db_session.close()

        while True:
            try:
                deleted = await asyncio.to_thread(_db_cleanup)
                if deleted > 0:
                    logger.info(f"Cleaned up {deleted} expired database sessions")

                # Check local sessions against database
                async with self._lock:
                    local_transports = self._sessions.copy()

                for session_id, transport in local_transports.items():
                    try:
                        if not await transport.is_connected():
                            await self.remove_session(session_id)
                            continue

                        if not await asyncio.to_thread(_refresh_session, session_id):
                            # Session no longer in database, remove locally
                            await self.remove_session(session_id)
                    except Exception as e:
                        logger.error(f"Error checking session {session_id}: {e}")

                await asyncio.sleep(300)  # Run every 5 minutes

            except asyncio.CancelledError:
                logger.info("Database cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in database cleanup task: {e}")
                await asyncio.sleep(600)  # Sleep longer on error

    async def _memory_cleanup_task(self) -> None:
        """Periodically (every minute) remove sessions whose transport disconnected."""
        logger.info("Starting memory cleanup task")
        while True:
            try:
                async with self._lock:
                    local_transports = self._sessions.copy()

                for session_id, transport in local_transports.items():
                    try:
                        if not await transport.is_connected():
                            await self.remove_session(session_id)
                    except Exception as e:
                        logger.error(f"Error checking session {session_id}: {e}")
                        await self.remove_session(session_id)

                await asyncio.sleep(60)  # Run every minute

            except asyncio.CancelledError:
                logger.info("Memory cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in memory cleanup task: {e}")
                await asyncio.sleep(300)  # Sleep longer on error

    async def handle_initialize_logic(self, body: dict) -> InitializeResult:
        """Validate the protocol version and return the server's InitializeResult.

        Args:
            body (dict): The incoming request body (``protocol_version`` or ``protocolVersion``).

        Raises:
            HTTPException: If the protocol version is missing or differs from settings.protocol_version.

        Returns:
            InitializeResult: Initialization result with protocol version, capabilities, and server info.
        """
        protocol_version = body.get("protocol_version") or body.get("protocolVersion")

        if not protocol_version:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing protocol version",
                headers={"MCP-Error-Code": "-32002"},
            )

        if protocol_version != settings.protocol_version:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported protocol version: {protocol_version}",
                headers={"MCP-Error-Code": "-32003"},
            )

        return InitializeResult(
            protocolVersion=settings.protocol_version,
            capabilities=ServerCapabilities(
                prompts={"listChanged": True},
                resources={"subscribe": True, "listChanged": True},
                tools={"listChanged": True},
                logging={},
                roots={"listChanged": True},
                sampling={},
            ),
            serverInfo=Implementation(name=settings.app_name, version="1.0.0"),
            instructions=("MCP Gateway providing federated tools, resources and prompts. Use /admin interface for configuration."),
        )

    async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any], base_url: str) -> None:
        """Generate and send the JSON-RPC response for one message over SSE.

        Only requests (messages with both "method" and "id") are answered:
        - "initialize": sends the result followed by server notifications.
        - list methods: query the tool/resource/prompt services, scoped to ``server_id`` when given.
        - "tools/call": forwarded to ``{base_url}/rpc``.
        - any other method: receives an empty result.

        Args:
            message: Decoded JSON-RPC message
            transport: Transport where message should be responded in
            server_id: Server ID used to scope list requests, or None for all
            user: User information; ``user["token"]`` is sent as a Bearer token for tools/call
            base_url: Base URL for the FastAPI request

        Raises:
            HTTPException: From handle_initialize_logic for a missing/unsupported protocol version.
        """
        if "method" not in message or "id" not in message:
            return

        method = message["method"]
        params = message.get("params") or {}
        req_id = message["id"]

        db = next(get_db())
        try:
            if method == "initialize":
                init_result = await self.handle_initialize_logic(params)
                await transport.send_message(
                    {
                        "jsonrpc": "2.0",
                        "result": init_result.model_dump(by_alias=True, exclude_none=True),
                        "id": req_id,
                    }
                )
                # NOTE(review): in MCP, notifications/initialized is normally sent by the
                # client; confirm clients tolerate receiving it from the server.
                await transport.send_message({"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}})
                for notification in ("tools/list_changed", "resources/list_changed", "prompts/list_changed"):
                    await transport.send_message({"jsonrpc": "2.0", "method": f"notifications/{notification}", "params": {}})
                # initialize is fully answered above; falling through would send a
                # second, empty response with the same id.
                return

            if method == "tools/list":
                tools = await tool_service.list_server_tools(db, server_id=server_id) if server_id else await tool_service.list_tools(db)
                result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
            elif method == "resources/list":
                resources = await resource_service.list_server_resources(db, server_id=server_id) if server_id else await resource_service.list_resources(db)
                result = {"resources": [r.model_dump(by_alias=True, exclude_none=True) for r in resources]}
            elif method == "prompts/list":
                prompts = await prompt_service.list_server_prompts(db, server_id=server_id) if server_id else await prompt_service.list_prompts(db)
                result = {"prompts": [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]}
            elif method == "tools/call":
                rpc_input = {
                    "jsonrpc": "2.0",
                    "method": params["name"],
                    # "arguments" is optional in tools/call requests.
                    "params": params.get("arguments") or {},
                    "id": 1,
                }
                headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
                rpc_url = base_url + "/rpc"
                async with httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) as client:
                    rpc_response = await client.post(url=rpc_url, json=rpc_input, headers=headers)
                result = rpc_response.json()
            else:
                # "ping" and unknown methods both get an empty result.
                result = {}
        finally:
            # Release the DB session; it was previously never closed.
            db.close()

        response = {"jsonrpc": "2.0", "result": result, "id": req_id}
        logger.info(f"Sending sse message:{response}")
        await transport.send_message(response)