Coverage for mcpgateway/services/tool_service.py: 55%

321 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 12:53 +0100

1# -*- coding: utf-8 -*- 

2"""Tool Service Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements tool management and invocation according to the MCP specification. 

9It handles: 

10- Tool registration and validation 

11- Tool invocation with schema validation 

12- Tool federation across gateways 

13- Event notifications for tool changes 

14- Active/inactive tool management 

15""" 

16 

17import asyncio 

18import base64 

19import json 

20import logging 

21import time 

22from datetime import datetime 

23from typing import Any, AsyncGenerator, Dict, List, Optional 

24 

25import httpx 

26from mcp import ClientSession 

27from mcp.client.sse import sse_client 

28from mcp.client.streamable_http import streamablehttp_client 

29from sqlalchemy import delete, func, not_, select 

30from sqlalchemy.exc import IntegrityError 

31from sqlalchemy.orm import Session 

32 

33from mcpgateway.config import settings 

34from mcpgateway.db import Gateway as DbGateway 

35from mcpgateway.db import Tool as DbTool 

36from mcpgateway.db import ToolMetric, server_tool_association 

37from mcpgateway.schemas import ( 

38 ToolCreate, 

39 ToolRead, 

40 ToolUpdate, 

41) 

42from mcpgateway.types import TextContent, ToolResult 

43from mcpgateway.utils.services_auth import decode_auth 

44 

45from ..config import extract_using_jq 

46 

# Module-level logger used throughout the tool service for lifecycle and invocation messages.
logger = logging.getLogger(__name__)

48 

49 

class ToolError(Exception):
    """Common ancestor of the tool-service exceptions declared in this module.

    Catching this type handles not-found, name-conflict, validation and
    invocation failures raised by ``ToolService`` alike.
    """

52 

53 

class ToolNotFoundError(ToolError):
    """Signals that no tool matches the requested ID or name (or that it is inactive)."""

56 

57 

class ToolNameConflictError(ToolError):
    """Signals that a tool name is already taken by another tool, active or not."""

    def __init__(self, name: str, is_active: bool = True, tool_id: Optional[int] = None):
        """Record the conflict details and build a human-readable message.

        Args:
            name: The tool name that is already in use.
            is_active: Whether the tool currently holding the name is active.
            tool_id: ID of the tool currently holding the name, when known.
        """
        self.name = name
        self.is_active = is_active
        self.tool_id = tool_id
        # Inactive holders get their ID appended so the caller can identify them.
        suffix = "" if is_active else f" (currently inactive, ID: {tool_id})"
        super().__init__(f"Tool already exists with name: {name}{suffix}")

76 

77 

class ToolValidationError(ToolError):
    """Signals that a tool definition or its endpoint failed validation."""

80 

81 

class ToolInvocationError(ToolError):
    """Signals that calling a tool (REST or MCP) did not complete successfully."""

84 

85 

class ToolService:
    """Service for managing and invoking tools.

    Handles:
    - Tool registration and deregistration.
    - Tool invocation and validation.
    - Tool federation.
    - Event notifications.
    - Active/inactive tool management.
    """

    def __init__(self):
        """Initialize the tool service."""
        self._event_subscribers: List[asyncio.Queue] = []
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

    async def initialize(self) -> None:
        """Initialize the service."""
        logger.info("Initializing tool service")

    async def shutdown(self) -> None:
        """Shutdown the service, closing the shared HTTP client."""
        await self._http_client.aclose()
        logger.info("Tool service shutdown complete")

    def _convert_tool_to_read(self, tool: DbTool) -> ToolRead:
        """
        Converts a DbTool instance into a ToolRead model, including aggregated metrics and
        new API gateway fields: request_type and authentication credentials (masked).

        Args:
            tool (DbTool): The ORM instance of the tool.

        Returns:
            ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
        """
        tool_dict = tool.__dict__.copy()
        tool_dict.pop("_sa_instance_state", None)
        tool_dict["execution_count"] = tool.execution_count
        tool_dict["metrics"] = tool.metrics_summary
        tool_dict["request_type"] = tool.request_type

        decoded_auth_value = decode_auth(tool.auth_value)
        if tool.auth_type == "basic":
            encoded_credentials = decoded_auth_value["Authorization"].split("Basic ")[1]
            # Split on the first colon only: RFC 7617 forbids ':' in the user-id but allows it in the password.
            username, _, password = base64.b64decode(encoded_credentials).decode("utf-8").partition(":")
            tool_dict["auth"] = {
                "auth_type": "basic",
                "username": username,
                "password": "********" if password else None,
            }
        elif tool.auth_type == "bearer":
            tool_dict["auth"] = {
                "auth_type": "bearer",
                "token": "********" if decoded_auth_value["Authorization"] else None,
            }
        elif tool.auth_type == "authheaders":
            # Only the first stored header is reported (presumably there is exactly one — confirm with decode_auth).
            # A default avoids a bare StopIteration escaping into the calling coroutine when the dict is empty.
            header_key = next(iter(decoded_auth_value), None)
            tool_dict["auth"] = {
                "auth_type": "authheaders",
                "auth_header_key": header_key,
                "auth_header_value": "********" if header_key is not None and decoded_auth_value[header_key] else None,
            }
        else:
            tool_dict["auth"] = None
        return ToolRead.model_validate(tool_dict)

    async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
        """
        Records a metric for a tool invocation.

        The response time is computed from the provided monotonic start time, and the
        metric (success flag and optional error message) is added and committed.

        Args:
            db (Session): The SQLAlchemy database session.
            tool (DbTool): The tool that was invoked.
            start_time (float): The monotonic start time of the invocation.
            success (bool): True if the invocation succeeded; otherwise, False.
            error_message (Optional[str]): The error message if the invocation failed, otherwise None.
        """
        response_time = time.monotonic() - start_time
        metric = ToolMetric(
            tool_id=tool.id,
            response_time=response_time,
            is_success=success,
            error_message=error_message,
        )
        db.add(metric)
        db.commit()

    async def register_tool(self, db: Session, tool: ToolCreate) -> ToolRead:
        """Register a new tool.

        Args:
            db: Database session.
            tool: Tool creation schema.

        Returns:
            Created tool information.

        Raises:
            ToolNameConflictError: If tool name already exists.
            ToolError: For other tool registration errors.
        """
        try:
            existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name)).scalar_one_or_none()
            if existing_tool:
                raise ToolNameConflictError(
                    tool.name,
                    is_active=existing_tool.is_active,
                    tool_id=existing_tool.id,
                )

            if tool.auth is None:
                auth_type = None
                auth_value = None
            else:
                auth_type = tool.auth.auth_type
                auth_value = tool.auth.auth_value

            db_tool = DbTool(
                name=tool.name,
                url=str(tool.url),
                description=tool.description,
                integration_type=tool.integration_type,
                request_type=tool.request_type,
                headers=tool.headers,
                input_schema=tool.input_schema,
                jsonpath_filter=tool.jsonpath_filter,
                auth_type=auth_type,
                auth_value=auth_value,
                gateway_id=tool.gateway_id,
            )
            db.add(db_tool)
            db.commit()
            db.refresh(db_tool)
            await self._notify_tool_added(db_tool)
            logger.info(f"Registered tool: {tool.name}")
            return self._convert_tool_to_read(db_tool)
        except ToolNameConflictError:
            # Propagate the specific, documented error instead of wrapping it in a generic ToolError.
            db.rollback()
            raise
        except IntegrityError as e:
            db.rollback()
            raise ToolError(f"Tool already exists: {tool.name}") from e
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to register tool: {str(e)}") from e

    async def list_tools(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
        """
        Retrieve a list of registered tools from the database.

        Args:
            db (Session): The SQLAlchemy database session.
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
                this parameter is ignored. Defaults to None.

        Returns:
            List[ToolRead]: A list of registered tools represented as ToolRead objects.
        """
        query = select(DbTool)
        cursor = None  # Placeholder for pagination; ignore for now
        logger.debug(f"Listing tools with include_inactive={include_inactive}, cursor={cursor}")
        if not include_inactive:
            query = query.where(DbTool.is_active)
        tools = db.execute(query).scalars().all()
        return [self._convert_tool_to_read(t) for t in tools]

    async def list_server_tools(self, db: Session, server_id: int, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
        """
        Retrieve the tools associated with a given server.

        Args:
            db (Session): The SQLAlchemy database session.
            server_id (int): Server ID whose tools are listed (via server_tool_association).
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
                this parameter is ignored. Defaults to None.

        Returns:
            List[ToolRead]: A list of registered tools represented as ToolRead objects.
        """
        query = select(DbTool).join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)
        cursor = None  # Placeholder for pagination; ignore for now
        logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")
        if not include_inactive:
            query = query.where(DbTool.is_active)
        tools = db.execute(query).scalars().all()
        return [self._convert_tool_to_read(t) for t in tools]

    async def get_tool(self, db: Session, tool_id: int) -> ToolRead:
        """Get a specific tool by ID.

        Args:
            db: Database session.
            tool_id: Tool ID to retrieve.

        Returns:
            Tool information.

        Raises:
            ToolNotFoundError: If tool not found.
        """
        tool = db.get(DbTool, tool_id)
        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")
        return self._convert_tool_to_read(tool)

    async def delete_tool(self, db: Session, tool_id: int) -> None:
        """Permanently delete a tool from the database.

        Args:
            db: Database session.
            tool_id: Tool ID to delete.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other deletion errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            # Capture identifying data before the row is deleted, for the notification.
            tool_info = {"id": tool.id, "name": tool.name}
            db.delete(tool)
            db.commit()
            await self._notify_tool_deleted(tool_info)
            logger.info(f"Permanently deleted tool: {tool_info['name']}")
        except ToolNotFoundError:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to delete tool: {str(e)}") from e

    async def toggle_tool_status(self, db: Session, tool_id: int, activate: bool) -> ToolRead:
        """Toggle tool active status.

        Args:
            db: Database session.
            tool_id: Tool ID to toggle.
            activate: True to activate, False to deactivate.

        Returns:
            Updated tool information.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            # Only write and notify when the status actually changes.
            if tool.is_active != activate:
                tool.is_active = activate
                tool.updated_at = datetime.utcnow()
                db.commit()
                db.refresh(tool)
                if activate:
                    await self._notify_tool_activated(tool)
                else:
                    await self._notify_tool_deactivated(tool)
                logger.info(f"Tool {tool.name} {'activated' if activate else 'deactivated'}")
            return self._convert_tool_to_read(tool)
        except ToolNotFoundError:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to toggle tool status: {str(e)}") from e

    async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -> ToolResult:
        """
        Invoke a registered tool and record execution metrics.

        REST tools are called over HTTP using the tool's URL (with ``{param}`` placeholders
        filled from ``arguments``) and request type. MCP tools are called through the tool's
        active gateway using the SSE or Streamable HTTP transport named by ``request_type``.
        A metric row is recorded for every invocation attempt once the tool is found.

        Args:
            db: Database session.
            name: Name of tool to invoke.
            arguments: Tool arguments. Not mutated.

        Returns:
            Tool invocation result.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolInvocationError: If invocation fails.
        """
        tool = db.execute(select(DbTool).where(DbTool.name == name).where(DbTool.is_active)).scalar_one_or_none()
        if not tool:
            inactive_tool = db.execute(select(DbTool).where(DbTool.name == name).where(not_(DbTool.is_active))).scalar_one_or_none()
            if inactive_tool:
                raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
            raise ToolNotFoundError(f"Tool not found: {name}")
        start_time = time.monotonic()
        success = False
        error_message = None
        try:
            # Copy so that merging credentials never mutates the ORM-managed headers of the tool.
            headers = dict(tool.headers or {})
            if tool.integration_type == "REST":
                headers.update(decode_auth(tool.auth_value))

                # Copy so URL path parameters can be removed without touching the caller's dict.
                payload = arguments.copy()

                # Substitute URL path parameters; they are sent in the path, not in the body/query.
                final_url = tool.url
                if "{" in tool.url and "}" in tool.url:
                    import re

                    # De-duplicate (order-preserving): str.replace substitutes every occurrence at once,
                    # and a repeated placeholder would otherwise fail after its argument was popped.
                    for param in dict.fromkeys(re.findall(r"\{(\w+)\}", tool.url)):
                        if param not in payload:
                            raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
                        final_url = final_url.replace(f"{{{param}}}", str(payload.pop(param)))

                # Use the tool's request_type rather than defaulting to POST.
                method = tool.request_type.upper()
                if method == "GET":
                    response = await self._http_client.get(final_url, params=payload, headers=headers)
                else:
                    response = await self._http_client.request(method, final_url, json=payload, headers=headers)
                response.raise_for_status()

                if response.status_code == 204:
                    # 204 No Content has no body to parse.
                    tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])
                elif response.status_code not in [200, 201, 202, 206]:
                    result = response.json()
                    tool_result = ToolResult(
                        content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")],
                        is_error=True,
                    )
                else:
                    result = response.json()
                    filtered_response = extract_using_jq(result, tool.jsonpath_filter)
                    tool_result = ToolResult(content=[TextContent(type="text", text=json.dumps(filtered_response, indent=2))])

                success = True
            elif tool.integration_type == "MCP":
                transport = tool.request_type.lower()
                gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.is_active)).scalar_one_or_none()
                if gateway is None:
                    raise ToolInvocationError(f"No active gateway found for tool: {name}")
                headers = decode_auth(gateway.auth_value) if gateway.auth_type == "bearer" else {}

                async def connect_to_sse_server(server_url: str) -> Any:
                    """
                    Connect to an MCP server running with SSE transport and call the tool.

                    Args:
                        server_url (str): MCP Server SSE URL

                    Returns:
                        The result object returned by ``ClientSession.call_tool``.
                    """
                    async with sse_client(url=server_url, headers=headers) as streams:
                        async with ClientSession(*streams) as session:
                            await session.initialize()
                            return await session.call_tool(name, arguments)

                async def connect_to_streamablehttp_server(server_url: str) -> Any:
                    """
                    Connect to an MCP server running with Streamable HTTP transport and call the tool.

                    Args:
                        server_url (str): MCP Server URL

                    Returns:
                        The result object returned by ``ClientSession.call_tool``.
                    """
                    async with streamablehttp_client(url=server_url, headers=headers) as (read_stream, write_stream, _get_session_id):
                        async with ClientSession(read_stream, write_stream) as session:
                            await session.initialize()
                            return await session.call_tool(name, arguments)

                if transport == "sse":
                    tool_call_result = await connect_to_sse_server(gateway.url)
                elif transport == "streamablehttp":
                    tool_call_result = await connect_to_streamablehttp_server(gateway.url)
                else:
                    raise ToolInvocationError(f"Unsupported MCP transport: {tool.request_type}")
                content = tool_call_result.model_dump(by_alias=True).get("content", [])

                filtered_response = extract_using_jq(content, tool.jsonpath_filter)
                tool_result = ToolResult(content=filtered_response)
                # Mark success only after filtering, so a jq failure is recorded as a failed call.
                success = True
            else:
                return ToolResult(content=[TextContent(type="text", text="Invalid tool type")], is_error=True)

            return tool_result
        except Exception as e:
            error_message = str(e)
            raise ToolInvocationError(f"Tool invocation failed: {error_message}") from e
        finally:
            await self._record_tool_metric(db, tool, start_time, success, error_message)

    async def update_tool(self, db: Session, tool_id: int, tool_update: ToolUpdate) -> ToolRead:
        """Update an existing tool.

        Fields left as None in ``tool_update`` are not changed, except for ``auth`` (see below).

        Args:
            db: Database session.
            tool_id: ID of tool to update.
            tool_update: Updated tool data.

        Returns:
            Updated tool information.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other tool update errors.
            ToolNameConflictError: If tool name conflict occurs
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            if tool_update.name is not None and tool_update.name != tool.name:
                existing_tool = db.execute(select(DbTool).where(DbTool.name == tool_update.name).where(DbTool.id != tool_id)).scalar_one_or_none()
                if existing_tool:
                    raise ToolNameConflictError(
                        tool_update.name,
                        is_active=existing_tool.is_active,
                        tool_id=existing_tool.id,
                    )

            # Copy every provided (non-None) plain field onto the ORM object.
            for field_name in ("name", "description", "integration_type", "request_type", "headers", "input_schema", "jsonpath_filter"):
                value = getattr(tool_update, field_name)
                if value is not None:
                    setattr(tool, field_name, value)
            if tool_update.url is not None:
                tool.url = str(tool_update.url)

            if tool_update.auth is not None:
                if tool_update.auth.auth_type is not None:
                    tool.auth_type = tool_update.auth.auth_type
                if tool_update.auth.auth_value is not None:
                    tool.auth_value = tool_update.auth.auth_value
            else:
                # NOTE(review): omitting `auth` clears auth_type but leaves auth_value untouched;
                # confirm this "remove auth" behaviour is intended for partial updates.
                tool.auth_type = None

            tool.updated_at = datetime.utcnow()
            db.commit()
            db.refresh(tool)
            await self._notify_tool_updated(tool)
            logger.info(f"Updated tool: {tool.name}")
            return self._convert_tool_to_read(tool)
        except (ToolNotFoundError, ToolNameConflictError):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to update tool: {str(e)}") from e

    async def _notify_tool_updated(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool update.

        Args:
            tool: Tool updated
        """
        event = {
            "type": "tool_updated",
            "data": {
                "id": tool.id,
                "name": tool.name,
                "url": tool.url,
                "description": tool.description,
                "is_active": tool.is_active,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_activated(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool activation.

        Args:
            tool: Tool activated
        """
        event = {
            "type": "tool_activated",
            "data": {"id": tool.id, "name": tool.name, "is_active": True},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_deactivated(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool deactivation.

        Args:
            tool: Tool deactivated
        """
        event = {
            "type": "tool_deactivated",
            "data": {"id": tool.id, "name": tool.name, "is_active": False},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
        """
        Notify subscribers of tool deletion.

        Args:
            tool_info: Dictionary on tool deleted
        """
        event = {
            "type": "tool_deleted",
            "data": tool_info,
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to tool events.

        Yields:
            Tool event messages.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            # Unregister when the consumer stops iterating or the generator is closed.
            self._event_subscribers.remove(queue)

    async def _notify_tool_added(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool addition.

        Args:
            tool: Tool added
        """
        event = {
            "type": "tool_added",
            "data": {
                "id": tool.id,
                "name": tool.name,
                "url": tool.url,
                "description": tool.description,
                "is_active": tool.is_active,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_removed(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool removal (soft delete/deactivation).

        Args:
            tool: Tool removed
        """
        event = {
            "type": "tool_removed",
            "data": {"id": tool.id, "name": tool.name, "is_active": False},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _publish_event(self, event: Dict[str, Any]) -> None:
        """
        Publish event to all subscribers.

        Args:
            event: Event to publish
        """
        for queue in self._event_subscribers:
            await queue.put(event)

    async def _validate_tool_url(self, url: str) -> None:
        """Validate tool URL is accessible.

        Args:
            url: URL to validate.

        Raises:
            ToolValidationError: If URL validation fails.
        """
        try:
            response = await self._http_client.get(url)
            response.raise_for_status()
        except Exception as e:
            raise ToolValidationError(f"Failed to validate tool URL: {str(e)}") from e

    async def _check_tool_health(self, tool: DbTool) -> bool:
        """Check if tool endpoint is healthy.

        Args:
            tool: Tool to check.

        Returns:
            True if a GET on the tool URL returns a 2xx response; False otherwise or on any error.
        """
        try:
            response = await self._http_client.get(tool.url)
            return response.is_success
        except Exception:
            return False

    async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Generate tool events for SSE.

        Yields:
            Tool events.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    # --- Metrics ---
    async def aggregate_metrics(self, db: Session) -> Dict[str, Any]:
        """
        Aggregate metrics for all tool invocations.

        Args:
            db: Database session

        Returns:
            A dictionary with keys:
                - total_executions
                - successful_executions
                - failed_executions
                - failure_rate
                - min_response_time
                - max_response_time
                - avg_response_time
                - last_execution_time
        """

        total = db.execute(select(func.count(ToolMetric.id))).scalar() or 0  # pylint: disable=not-callable
        successful = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success)).scalar() or 0  # pylint: disable=not-callable
        failed = db.execute(select(func.count(ToolMetric.id)).where(not_(ToolMetric.is_success))).scalar() or 0  # pylint: disable=not-callable
        failure_rate = failed / total if total > 0 else 0.0
        min_rt = db.execute(select(func.min(ToolMetric.response_time))).scalar()
        max_rt = db.execute(select(func.max(ToolMetric.response_time))).scalar()
        avg_rt = db.execute(select(func.avg(ToolMetric.response_time))).scalar()
        last_time = db.execute(select(func.max(ToolMetric.timestamp))).scalar()

        return {
            "total_executions": total,
            "successful_executions": successful,
            "failed_executions": failed,
            "failure_rate": failure_rate,
            "min_response_time": min_rt,
            "max_response_time": max_rt,
            "avg_response_time": avg_rt,
            "last_execution_time": last_time,
        }

    async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
        """
        Reset metrics for tool invocations.

        If tool_id is provided, only the metrics for that specific tool will be deleted.
        Otherwise, all tool metrics will be deleted (global reset).

        Args:
            db (Session): The SQLAlchemy database session.
            tool_id (Optional[int]): Specific tool ID to reset metrics for.
        """
        # Compare against None explicitly: a falsy ID such as 0 must not trigger a global wipe.
        if tool_id is not None:
            db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
        else:
            db.execute(delete(ToolMetric))
        db.commit()