Coverage for mcpgateway/services/server_service.py: 71%

234 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 12:53 +0100

1# -*- coding: utf-8 -*- 

2""" 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8MCP Gateway Server Service 

9 

10This module implements server management for the MCP Servers Catalog. 

11It handles server registration, listing, retrieval, updates, activation toggling, and deletion. 

12It also publishes event notifications for server changes. 

13""" 

14 

15import asyncio 

16import logging 

17from datetime import datetime 

18from typing import Any, AsyncGenerator, Dict, List, Optional 

19 

20import httpx 

21from sqlalchemy import delete, func, not_, select 

22from sqlalchemy.exc import IntegrityError 

23from sqlalchemy.orm import Session 

24 

25from mcpgateway.config import settings 

26from mcpgateway.db import Prompt as DbPrompt 

27from mcpgateway.db import Resource as DbResource 

28from mcpgateway.db import Server as DbServer 

29from mcpgateway.db import ServerMetric 

30from mcpgateway.db import Tool as DbTool 

31from mcpgateway.schemas import ServerCreate, ServerMetrics, ServerRead, ServerUpdate 

32 

33logger = logging.getLogger(__name__) 

34 

35 

class ServerError(Exception):
    """Root of the server-service exception hierarchy.

    Catch this type to handle any failure raised by the server service.
    """

38 

39 

class ServerNotFoundError(ServerError):
    """Signals that no server matches the identifier that was requested."""

42 

43 

class ServerNameConflictError(ServerError):
    """Signals that the requested server name is already taken.

    Attributes:
        name: The conflicting server name.
        is_active: Whether the existing server holding the name is active.
        server_id: ID of the existing server, when known.
    """

    def __init__(self, name: str, is_active: bool = True, server_id: Optional[int] = None):
        self.name, self.is_active, self.server_id = name, is_active, server_id
        # Only inactive holders get the extra hint pointing at the existing record.
        suffix = "" if is_active else f" (currently inactive, ID: {server_id})"
        super().__init__(f"Server already exists with name: {name}{suffix}")

55 

56 

class ServerService:
    """Service for managing MCP Servers in the catalog.

    Provides methods to create, list, retrieve, update, toggle status, and delete server records,
    publishes an event to every subscriber queue after each successful change, and exposes
    aggregate/reset operations over the ServerMetric table.

    Error contract for the mutating methods: the database session is always rolled back on failure.
    ServerNotFoundError and ServerNameConflictError propagate unchanged (both are ServerError
    subclasses); every other failure is wrapped in a plain ServerError chained to its cause.
    """

    def __init__(self) -> None:
        # One asyncio.Queue per active subscribe_events() consumer.
        self._event_subscribers: List[asyncio.Queue] = []
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

    async def initialize(self) -> None:
        """Initialize the server service."""
        logger.info("Initializing server service")

    async def shutdown(self) -> None:
        """Shutdown the server service and close its HTTP client."""
        await self._http_client.aclose()
        logger.info("Server service shutdown complete")

    # --- Internal helpers ---
    @staticmethod
    def _load_items(
        db: Session,
        model: Any,
        raw_ids: List[str],
        label: str,
        *,
        skip_blank: bool,
        must_exist: bool,
    ) -> List[Any]:
        """Fetch the ORM rows of ``model`` referenced by ``raw_ids``.

        Args:
            db: Database session.
            model: ORM class to look up (tool, resource, or prompt).
            raw_ids: Primary keys as strings; each is converted with ``int()``.
            label: Human-readable item kind used in the error message ("Tool", ...).
            skip_blank: When True, IDs that are empty after stripping are ignored.
            must_exist: When True, a missing row raises; when False it is silently dropped.

        Returns:
            The rows that were found, in input order.

        Raises:
            ServerError: If ``must_exist`` is True and an ID has no matching row.
            ValueError: If an ID is not a valid integer.
        """
        found: List[Any] = []
        for raw_id in raw_ids:
            if skip_blank and raw_id.strip() == "":
                continue
            item = db.get(model, int(raw_id))
            if item:
                found.append(item)
            elif must_exist:
                raise ServerError(f"{label} with id {raw_id} does not exist.")
        return found

    @staticmethod
    def _log_server_data(server: DbServer, *, stringify_ids: bool = False) -> None:
        """Emit the debug-level "Server Data" snapshot for ``server``.

        Args:
            server: ORM instance to describe.
            stringify_ids: Render associated IDs as strings instead of raw values.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return

        def _ids(items: Any) -> List[Any]:
            return [str(item.id) if stringify_ids else item.id for item in items]

        server_data = {
            "id": server.id,
            "name": server.name,
            "description": server.description,
            "icon": server.icon,
            "created_at": server.created_at,
            "updated_at": server.updated_at,
            "is_active": server.is_active,
            "associated_tools": _ids(server.tools),
            "associated_resources": _ids(server.resources),
            "associated_prompts": _ids(server.prompts),
        }
        logger.debug(f"Server Data: {server_data}")

    @staticmethod
    def _server_event_data(server: DbServer) -> Dict[str, Any]:
        """Build the full event payload (fields plus associated item IDs) for ``server``."""
        return {
            "id": server.id,
            "name": server.name,
            "description": server.description,
            "icon": server.icon,
            "associated_tools": [tool.id for tool in server.tools] if server.tools else [],
            "associated_resources": [res.id for res in server.resources] if server.resources else [],
            "associated_prompts": [prompt.id for prompt in server.prompts] if server.prompts else [],
            "is_active": server.is_active,
        }

    async def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
        """Wrap ``data`` in a typed, timestamped envelope and publish it."""
        await self._publish_event(
            {
                "type": event_type,
                "data": data,
                "timestamp": datetime.utcnow().isoformat(),
            }
        )

    def _convert_server_to_read(self, server: DbServer) -> ServerRead:
        """
        Converts a DbServer instance into a ServerRead model, including aggregated metrics.

        Args:
            server (DbServer): The ORM instance of the server.

        Returns:
            ServerRead: The Pydantic model representing the server, including aggregated metrics.
        """
        # Copy the column state first; the relationship-derived keys are filled in explicitly below.
        server_dict = server.__dict__.copy()
        server_dict.pop("_sa_instance_state", None)

        # Snapshot the metric rows once; a missing or None relationship counts as "no records".
        records = list(getattr(server, "metrics", None) or [])
        total = len(records)
        failed = sum(1 for m in records if not m.is_success)
        response_times = [m.response_time for m in records]

        server_dict["metrics"] = {
            "total_executions": total,
            "successful_executions": total - failed,
            "failed_executions": failed,
            "failure_rate": (failed / total) if total > 0 else 0.0,
            "min_response_time": min(response_times, default=None),
            "max_response_time": max(response_times, default=None),
            "avg_response_time": (sum(response_times) / total) if total > 0 else None,
            "last_execution_time": max((m.timestamp for m in records), default=None),
        }
        # Also update associated IDs (if not already done)
        server_dict["associated_tools"] = [tool.id for tool in server.tools] if server.tools else []
        server_dict["associated_resources"] = [res.id for res in server.resources] if server.resources else []
        server_dict["associated_prompts"] = [prompt.id for prompt in server.prompts] if server.prompts else []
        return ServerRead.model_validate(server_dict)

    def _assemble_associated_items(
        self,
        tools: Optional[List[str]],
        resources: Optional[List[str]],
        prompts: Optional[List[str]],
    ) -> Dict[str, Any]:
        """
        Assemble the associated items dictionary from the separate fields.

        Args:
            tools: List of tool IDs.
            resources: List of resource IDs.
            prompts: List of prompt IDs.

        Returns:
            A dictionary with keys "tools", "resources", and "prompts"; missing inputs become [].
        """
        return {
            "tools": tools or [],
            "resources": resources or [],
            "prompts": prompts or [],
        }

    # --- CRUD ---
    async def register_server(self, db: Session, server_in: ServerCreate) -> ServerRead:
        """
        Register a new server in the catalog and validate that all associated items exist.

        Steps: reject a duplicate name, create the record, resolve every ID in associated_tools,
        associated_resources and associated_prompts (blank IDs are skipped, unknown IDs are an
        error), commit, refresh, notify subscribers, and return the validated response.

        Args:
            db (Session): The SQLAlchemy database session.
            server_in (ServerCreate): The server creation schema containing server details and lists of
                associated tool, resource, and prompt IDs (as strings).

        Returns:
            ServerRead: The newly created server, with associated item IDs.

        Raises:
            ServerNameConflictError: If a server with the same name already exists.
            ServerError: If any associated tool, resource, or prompt does not exist, or if any other
                registration error occurs.
        """
        try:
            # Check for an existing server with the same name.
            existing = db.execute(select(DbServer).where(DbServer.name == server_in.name)).scalar_one_or_none()
            if existing:
                raise ServerNameConflictError(server_in.name, is_active=existing.is_active, server_id=existing.id)

            # Create the new server record.
            db_server = DbServer(
                name=server_in.name,
                description=server_in.description,
                icon=server_in.icon,
                is_active=True,
            )
            db.add(db_server)

            # Associate tools, resources and prompts, verifying each exists.
            for relation, model, label, raw_ids in (
                ("tools", DbTool, "Tool", server_in.associated_tools),
                ("resources", DbResource, "Resource", server_in.associated_resources),
                ("prompts", DbPrompt, "Prompt", server_in.associated_prompts),
            ):
                if raw_ids:
                    items = self._load_items(db, model, raw_ids, label, skip_blank=True, must_exist=True)
                    getattr(db_server, relation).extend(items)

            # Commit the new record and refresh.
            db.commit()
            db.refresh(db_server)
            # Force load the relationship attributes.
            _ = db_server.tools, db_server.resources, db_server.prompts

            self._log_server_data(db_server, stringify_ids=True)
            await self._notify_server_added(db_server)
            logger.info(f"Registered server: {server_in.name}")
            return self._convert_server_to_read(db_server)
        except ServerNameConflictError:
            # Documented, caller-actionable error: surface it as-is instead of re-wrapping.
            db.rollback()
            raise
        except IntegrityError as exc:
            db.rollback()
            raise ServerError(f"Server already exists: {server_in.name}") from exc
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to register server: {str(exc)}") from exc

    async def list_servers(self, db: Session, include_inactive: bool = False) -> List[ServerRead]:
        """List all registered servers.

        Args:
            db: Database session.
            include_inactive: Whether to include inactive servers.

        Returns:
            A list of ServerRead objects.
        """
        query = select(DbServer)
        if not include_inactive:
            query = query.where(DbServer.is_active)
        servers = db.execute(query).scalars().all()
        return [self._convert_server_to_read(s) for s in servers]

    async def get_server(self, db: Session, server_id: int) -> ServerRead:
        """Retrieve server details by ID.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.

        Returns:
            The corresponding ServerRead object.

        Raises:
            ServerNotFoundError: If no server with the given ID exists.
        """
        server = db.get(DbServer, server_id)
        if not server:
            raise ServerNotFoundError(f"Server not found: {server_id}")
        self._log_server_data(server)
        return self._convert_server_to_read(server)

    async def update_server(self, db: Session, server_id: int, server_update: ServerUpdate) -> ServerRead:
        """Update an existing server.

        Only fields that are not None on ``server_update`` are applied. A provided association
        list replaces the current one; IDs that match no row are silently dropped.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.
            server_update: Server update schema with new data.

        Returns:
            The updated ServerRead object.

        Raises:
            ServerNotFoundError: If the server is not found.
            ServerNameConflictError: If a new name conflicts with an existing server.
            ServerError: For other update errors.
        """
        try:
            server = db.get(DbServer, server_id)
            if not server:
                raise ServerNotFoundError(f"Server not found: {server_id}")

            # Check for name conflict if name is being changed
            if server_update.name and server_update.name != server.name:
                conflict = db.execute(select(DbServer).where(DbServer.name == server_update.name).where(DbServer.id != server_id)).scalar_one_or_none()
                if conflict:
                    raise ServerNameConflictError(
                        server_update.name,
                        is_active=conflict.is_active,
                        server_id=conflict.id,
                    )

            # Update simple fields
            for field in ("name", "description", "icon"):
                value = getattr(server_update, field)
                if value is not None:
                    setattr(server, field, value)

            # Replace each association list that was provided
            for relation, model, label, raw_ids in (
                ("tools", DbTool, "Tool", server_update.associated_tools),
                ("resources", DbResource, "Resource", server_update.associated_resources),
                ("prompts", DbPrompt, "Prompt", server_update.associated_prompts),
            ):
                if raw_ids is not None:
                    items = self._load_items(db, model, raw_ids, label, skip_blank=False, must_exist=False)
                    setattr(server, relation, items)

            server.updated_at = datetime.utcnow()
            db.commit()
            db.refresh(server)
            # Force loading relationships
            _ = server.tools, server.resources, server.prompts

            await self._notify_server_updated(server)
            logger.info(f"Updated server: {server.name}")

            self._log_server_data(server)
            return self._convert_server_to_read(server)
        except (ServerNotFoundError, ServerNameConflictError):
            # Documented, caller-actionable errors: surface them as-is instead of re-wrapping.
            db.rollback()
            raise
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to update server: {str(exc)}") from exc

    async def toggle_server_status(self, db: Session, server_id: int, activate: bool) -> ServerRead:
        """Toggle the activation status of a server.

        Nothing is written and no event is published when the server is already in the
        requested state.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.
            activate: True to activate, False to deactivate.

        Returns:
            The updated ServerRead object.

        Raises:
            ServerNotFoundError: If the server is not found.
            ServerError: For other errors.
        """
        try:
            server = db.get(DbServer, server_id)
            if not server:
                raise ServerNotFoundError(f"Server not found: {server_id}")

            if server.is_active != activate:
                server.is_active = activate
                server.updated_at = datetime.utcnow()
                db.commit()
                db.refresh(server)
                if activate:
                    await self._notify_server_activated(server)
                else:
                    await self._notify_server_deactivated(server)
                logger.info(f"Server {server.name} {'activated' if activate else 'deactivated'}")

            self._log_server_data(server)
            return self._convert_server_to_read(server)
        except ServerNotFoundError:
            # Documented, caller-actionable error: surface it as-is instead of re-wrapping.
            db.rollback()
            raise
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to toggle server status: {str(exc)}") from exc

    async def delete_server(self, db: Session, server_id: int) -> None:
        """Permanently delete a server.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.

        Raises:
            ServerNotFoundError: If the server is not found.
            ServerError: For other deletion errors.
        """
        try:
            server = db.get(DbServer, server_id)
            if not server:
                raise ServerNotFoundError(f"Server not found: {server_id}")

            # Capture the identifying fields before the row is deleted.
            server_info = {"id": server.id, "name": server.name}
            db.delete(server)
            db.commit()

            await self._notify_server_deleted(server_info)
            logger.info(f"Deleted server: {server_info['name']}")
        except ServerNotFoundError:
            # Documented, caller-actionable error: surface it as-is instead of re-wrapping.
            db.rollback()
            raise
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to delete server: {str(exc)}") from exc

    # --- Events ---
    async def _publish_event(self, event: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribed queues.

        Args:
            event: Event to publish
        """
        # Iterate over a snapshot so a subscriber leaving mid-publish cannot disturb the loop.
        for queue in tuple(self._event_subscribers):
            await queue.put(event)

    async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to server events.

        Registers a dedicated queue for this consumer and removes it again when the
        generator is closed.

        Yields:
            Server event messages.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    async def _notify_server_added(self, server: DbServer) -> None:
        """
        Notify subscribers that a new server has been added.

        Args:
            server: Server that was added
        """
        await self._emit("server_added", self._server_event_data(server))

    async def _notify_server_updated(self, server: DbServer) -> None:
        """
        Notify subscribers that a server has been updated.

        Args:
            server: Server that was updated
        """
        await self._emit("server_updated", self._server_event_data(server))

    async def _notify_server_activated(self, server: DbServer) -> None:
        """
        Notify subscribers that a server has been activated.

        Args:
            server: Server that was activated
        """
        await self._emit("server_activated", {"id": server.id, "name": server.name, "is_active": True})

    async def _notify_server_deactivated(self, server: DbServer) -> None:
        """
        Notify subscribers that a server has been deactivated.

        Args:
            server: Server that was deactivated
        """
        await self._emit("server_deactivated", {"id": server.id, "name": server.name, "is_active": False})

    async def _notify_server_deleted(self, server_info: Dict[str, Any]) -> None:
        """
        Notify subscribers that a server has been deleted.

        Args:
            server_info: Dictionary describing the deleted server (its id and name)
        """
        await self._emit("server_deleted", server_info)

    # --- Metrics ---
    async def aggregate_metrics(self, db: Session) -> ServerMetrics:
        """
        Aggregate metrics for all server invocations across all servers.

        Args:
            db: Database session

        Returns:
            ServerMetrics: Aggregated metrics computed from all ServerMetric records.
        """

        def _scalar(statement: Any) -> Any:
            return db.execute(statement).scalar()

        count_all = select(func.count()).select_from(ServerMetric)  # pylint: disable=not-callable

        total_executions = _scalar(count_all) or 0
        successful_executions = _scalar(count_all.where(ServerMetric.is_success)) or 0
        failed_executions = _scalar(count_all.where(not_(ServerMetric.is_success))) or 0

        return ServerMetrics(
            total_executions=total_executions,
            successful_executions=successful_executions,
            failed_executions=failed_executions,
            failure_rate=(failed_executions / total_executions) if total_executions > 0 else 0.0,
            min_response_time=_scalar(select(func.min(ServerMetric.response_time))),
            max_response_time=_scalar(select(func.max(ServerMetric.response_time))),
            avg_response_time=_scalar(select(func.avg(ServerMetric.response_time))),
            last_execution_time=_scalar(select(func.max(ServerMetric.timestamp))),
        )

    async def reset_metrics(self, db: Session) -> None:
        """
        Reset all server metrics by deleting all records from the server metrics table.

        Args:
            db: Database session
        """
        db.execute(delete(ServerMetric))
        db.commit()