Coverage for mcpgateway/services/tool_service.py: 55%
321 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
1# -*- coding: utf-8 -*-
2"""Tool Service Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements tool management and invocation according to the MCP specification.
9It handles:
10- Tool registration and validation
11- Tool invocation with schema validation
12- Tool federation across gateways
13- Event notifications for tool changes
14- Active/inactive tool management
15"""
17import asyncio
18import base64
19import json
20import logging
21import time
22from datetime import datetime
23from typing import Any, AsyncGenerator, Dict, List, Optional
25import httpx
26from mcp import ClientSession
27from mcp.client.sse import sse_client
28from mcp.client.streamable_http import streamablehttp_client
29from sqlalchemy import delete, func, not_, select
30from sqlalchemy.exc import IntegrityError
31from sqlalchemy.orm import Session
33from mcpgateway.config import settings
34from mcpgateway.db import Gateway as DbGateway
35from mcpgateway.db import Tool as DbTool
36from mcpgateway.db import ToolMetric, server_tool_association
37from mcpgateway.schemas import (
38 ToolCreate,
39 ToolRead,
40 ToolUpdate,
41)
42from mcpgateway.types import TextContent, ToolResult
43from mcpgateway.utils.services_auth import decode_auth
45from ..config import extract_using_jq
# Module-scoped logger, named after this module so log records can be filtered per service.
logger = logging.getLogger(name=__name__)
class ToolError(Exception):
    """Root of the tool-service exception hierarchy; catch this to handle any tool failure."""
class ToolNotFoundError(ToolError):
    """Signals that no tool matches the requested identifier or name."""
class ToolNameConflictError(ToolError):
    """Signals that a tool name is already taken by another tool.

    The tool holding the name may be active or inactive; the message mentions
    the inactive case explicitly so the caller can suggest reactivation.
    """

    def __init__(self, name: str, is_active: bool = True, tool_id: Optional[int] = None):
        """Record details about the clashing tool and build the error message.

        Args:
            name: The tool name that is already in use.
            is_active: Whether the tool currently holding the name is active.
            tool_id: Database ID of the tool holding the name, if known.
        """
        self.name = name
        self.is_active = is_active
        self.tool_id = tool_id
        # Only inactive conflicts get the extra hint with the existing tool's ID.
        suffix = "" if is_active else f" (currently inactive, ID: {tool_id})"
        super().__init__(f"Tool already exists with name: {name}{suffix}")
class ToolValidationError(ToolError):
    """Signals that a tool definition or its endpoint failed validation."""
class ToolInvocationError(ToolError):
    """Signals that calling a registered tool did not complete successfully."""
class ToolService:
    """Service for managing and invoking tools.

    Handles:
    - Tool registration and deregistration.
    - Tool invocation and validation.
    - Tool federation.
    - Event notifications.
    - Active/inactive tool management.
    """

    def __init__(self):
        """Initialize the tool service with an empty subscriber list and a shared HTTP client."""
        self._event_subscribers: List[asyncio.Queue] = []
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

    async def initialize(self) -> None:
        """Initialize the service."""
        logger.info("Initializing tool service")

    async def shutdown(self) -> None:
        """Shutdown the service, closing the shared HTTP client."""
        await self._http_client.aclose()
        logger.info("Tool service shutdown complete")

    def _convert_tool_to_read(self, tool: DbTool) -> ToolRead:
        """
        Converts a DbTool instance into a ToolRead model, including aggregated metrics and
        new API gateway fields: request_type and authentication credentials (masked).

        Args:
            tool (DbTool): The ORM instance of the tool.

        Returns:
            ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
        """
        tool_dict = tool.__dict__.copy()
        tool_dict.pop("_sa_instance_state", None)
        tool_dict["execution_count"] = tool.execution_count
        tool_dict["metrics"] = tool.metrics_summary
        tool_dict["request_type"] = tool.request_type

        decoded_auth_value = decode_auth(tool.auth_value)
        if tool.auth_type == "basic":
            encoded_credentials = decoded_auth_value["Authorization"].split("Basic ")[1]
            decoded_bytes = base64.b64decode(encoded_credentials)
            # Split on the first ':' only: the password itself may contain colons.
            username, password = decoded_bytes.decode("utf-8").split(":", 1)
            tool_dict["auth"] = {
                "auth_type": "basic",
                "username": username,
                "password": "********" if password else None,
            }
        elif tool.auth_type == "bearer":
            tool_dict["auth"] = {
                "auth_type": "bearer",
                "token": "********" if decoded_auth_value["Authorization"] else None,
            }
        elif tool.auth_type == "authheaders":
            # Only the first custom header is surfaced; its value is masked.
            header_key = next(iter(decoded_auth_value))
            tool_dict["auth"] = {
                "auth_type": "authheaders",
                "auth_header_key": header_key,
                "auth_header_value": "********" if decoded_auth_value[header_key] else None,
            }
        else:
            tool_dict["auth"] = None
        return ToolRead.model_validate(tool_dict)

    async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
        """
        Records a metric for a tool invocation.

        This function calculates the response time using the provided start time and records
        the metric details (including whether the invocation was successful and any error message)
        into the database. The metric is then committed to the database.

        Args:
            db (Session): The SQLAlchemy database session.
            tool (DbTool): The tool that was invoked.
            start_time (float): The monotonic start time of the invocation.
            success (bool): True if the invocation succeeded; otherwise, False.
            error_message (Optional[str]): The error message if the invocation failed, otherwise None.
        """
        end_time = time.monotonic()
        response_time = end_time - start_time
        metric = ToolMetric(
            tool_id=tool.id,
            response_time=response_time,
            is_success=success,
            error_message=error_message,
        )
        db.add(metric)
        db.commit()

    async def register_tool(self, db: Session, tool: ToolCreate) -> ToolRead:
        """Register a new tool.

        Args:
            db: Database session.
            tool: Tool creation schema.

        Returns:
            Created tool information.

        Raises:
            ToolNameConflictError: If tool name already exists.
            ToolError: For other tool registration errors.
        """
        try:
            existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name)).scalar_one_or_none()
            if existing_tool:
                raise ToolNameConflictError(
                    tool.name,
                    is_active=existing_tool.is_active,
                    tool_id=existing_tool.id,
                )

            if tool.auth is None:
                auth_type = None
                auth_value = None
            else:
                auth_type = tool.auth.auth_type
                auth_value = tool.auth.auth_value

            db_tool = DbTool(
                name=tool.name,
                url=str(tool.url),
                description=tool.description,
                integration_type=tool.integration_type,
                request_type=tool.request_type,
                headers=tool.headers,
                input_schema=tool.input_schema,
                jsonpath_filter=tool.jsonpath_filter,
                auth_type=auth_type,
                auth_value=auth_value,
                gateway_id=tool.gateway_id,
            )
            db.add(db_tool)
            db.commit()
            db.refresh(db_tool)
            await self._notify_tool_added(db_tool)
            logger.info(f"Registered tool: {tool.name}")
            return self._convert_tool_to_read(db_tool)
        except ToolNameConflictError:
            # Propagate the documented conflict type unchanged (it is still a ToolError).
            db.rollback()
            raise
        except IntegrityError as ie:
            db.rollback()
            raise ToolError(f"Tool already exists: {tool.name}") from ie
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to register tool: {str(e)}") from e

    async def list_tools(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
        """
        Retrieve a list of registered tools from the database.

        Args:
            db (Session): The SQLAlchemy database session.
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
                this parameter is ignored. Defaults to None.

        Returns:
            List[ToolRead]: A list of registered tools represented as ToolRead objects.
        """
        query = select(DbTool)
        cursor = None  # Placeholder for pagination; ignore for now
        logger.debug(f"Listing tools with include_inactive={include_inactive}, cursor={cursor}")
        if not include_inactive:
            query = query.where(DbTool.is_active)
        tools = db.execute(query).scalars().all()
        return [self._convert_tool_to_read(t) for t in tools]

    async def list_server_tools(self, db: Session, server_id: int, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
        """
        Retrieve the tools associated with a specific server.

        Args:
            db (Session): The SQLAlchemy database session.
            server_id (int): ID of the server whose tools are listed.
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
                this parameter is ignored. Defaults to None.

        Returns:
            List[ToolRead]: The server's tools represented as ToolRead objects.
        """
        query = select(DbTool).join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)
        cursor = None  # Placeholder for pagination; ignore for now
        logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")
        if not include_inactive:
            query = query.where(DbTool.is_active)
        tools = db.execute(query).scalars().all()
        return [self._convert_tool_to_read(t) for t in tools]

    async def get_tool(self, db: Session, tool_id: int) -> ToolRead:
        """Get a specific tool by ID.

        Args:
            db: Database session.
            tool_id: Tool ID to retrieve.

        Returns:
            Tool information.

        Raises:
            ToolNotFoundError: If tool not found.
        """
        tool = db.get(DbTool, tool_id)
        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")
        return self._convert_tool_to_read(tool)

    async def delete_tool(self, db: Session, tool_id: int) -> None:
        """Permanently delete a tool from the database.

        Args:
            db: Database session.
            tool_id: Tool ID to delete.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other deletion errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            tool_info = {"id": tool.id, "name": tool.name}
            db.delete(tool)
            db.commit()
            await self._notify_tool_deleted(tool_info)
            logger.info(f"Permanently deleted tool: {tool_info['name']}")
        except ToolNotFoundError:
            # Keep the documented not-found type instead of wrapping it in a generic ToolError.
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to delete tool: {str(e)}") from e

    async def toggle_tool_status(self, db: Session, tool_id: int, activate: bool) -> ToolRead:
        """Toggle tool active status.

        Args:
            db: Database session.
            tool_id: Tool ID to toggle.
            activate: True to activate, False to deactivate.

        Returns:
            Updated tool information.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            # Only write and notify when the status actually changes.
            if tool.is_active != activate:
                tool.is_active = activate
                tool.updated_at = datetime.utcnow()
                db.commit()
                db.refresh(tool)
                if activate:
                    await self._notify_tool_activated(tool)
                else:
                    await self._notify_tool_deactivated(tool)
                logger.info(f"Tool {tool.name} {'activated' if activate else 'deactivated'}")
            return self._convert_tool_to_read(tool)
        except ToolNotFoundError:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to toggle tool status: {str(e)}") from e

    async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -> ToolResult:
        """
        Invoke a registered tool and record execution metrics.

        Args:
            db: Database session.
            name: Name of tool to invoke.
            arguments: Tool arguments.

        Returns:
            Tool invocation result.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolInvocationError: If invocation fails.
        """
        tool = db.execute(select(DbTool).where(DbTool.name == name).where(DbTool.is_active)).scalar_one_or_none()
        if not tool:
            inactive_tool = db.execute(select(DbTool).where(DbTool.name == name).where(not_(DbTool.is_active))).scalar_one_or_none()
            if inactive_tool:
                raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
            raise ToolNotFoundError(f"Tool not found: {name}")
        start_time = time.monotonic()
        success = False
        error_message = None
        try:
            if tool.integration_type == "REST":
                tool_result = await self._invoke_rest_tool(tool, arguments)
            elif tool.integration_type == "MCP":
                tool_result = await self._invoke_mcp_tool(db, tool, name, arguments)
            else:
                raise ToolInvocationError("Invalid tool type")
            # Mark success only once the full result (including filtering) was produced.
            success = True
            return tool_result
        except Exception as e:
            error_message = str(e)
            raise ToolInvocationError(f"Tool invocation failed: {error_message}") from e
        finally:
            await self._record_tool_metric(db, tool, start_time, success, error_message)

    async def _invoke_rest_tool(self, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult:
        """Call a REST-integrated tool over HTTP and wrap its response in a ToolResult.

        URL template placeholders such as ``{id}`` are filled from (and removed from)
        the arguments; the remaining arguments are sent as query parameters for GET
        and as a JSON body otherwise.

        Args:
            tool: The active REST tool to call.
            arguments: Tool arguments supplied by the caller (not mutated).

        Returns:
            The tool result built from the HTTP response.

        Raises:
            ToolInvocationError: If a URL placeholder has no matching argument.
        """
        import re  # local import, as in the original implementation

        # Copy the stored headers so auth credentials are never merged into the ORM-managed dict.
        headers = dict(tool.headers or {})
        headers.update(decode_auth(tool.auth_value))

        payload = arguments.copy()

        final_url = tool.url
        if "{" in tool.url and "}" in tool.url:
            # De-duplicate placeholders: str.replace substitutes every occurrence at once.
            for param in dict.fromkeys(re.findall(r"\{(\w+)\}", tool.url)):
                if param not in payload:
                    raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
                final_url = final_url.replace(f"{{{param}}}", str(payload.pop(param)))

        # Use the tool's request_type rather than defaulting to POST.
        method = tool.request_type.upper()
        if method == "GET":
            response = await self._http_client.get(final_url, params=payload, headers=headers)
        else:
            response = await self._http_client.request(method, final_url, json=payload, headers=headers)
        response.raise_for_status()

        # Handle 204 No Content responses that have no body
        if response.status_code == 204:
            return ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])
        if response.status_code not in [200, 201, 202, 206]:
            result = response.json()
            return ToolResult(
                content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")],
                is_error=True,
            )
        result = response.json()
        filtered_response = extract_using_jq(result, tool.jsonpath_filter)
        return ToolResult(content=[TextContent(type="text", text=json.dumps(filtered_response, indent=2))])

    async def _invoke_mcp_tool(self, db: Session, tool: DbTool, name: str, arguments: Dict[str, Any]) -> ToolResult:
        """Call a tool hosted by a federated MCP gateway over SSE or Streamable HTTP.

        Args:
            db: Database session used to look up the tool's gateway.
            tool: The active MCP tool; its request_type selects the transport.
            name: Tool name forwarded to the remote server.
            arguments: Tool arguments forwarded to the remote server.

        Returns:
            The remote call's content, filtered by the tool's jsonpath_filter.

        Raises:
            ToolInvocationError: If the gateway is missing/inactive or the transport is unsupported.
        """
        transport = tool.request_type.lower()
        gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.is_active)).scalar_one_or_none()
        if gateway is None:
            raise ToolInvocationError(f"No active gateway found for tool '{name}'")
        headers = decode_auth(gateway.auth_value) if gateway.auth_type == "bearer" else {}

        if transport == "sse":
            async with sse_client(url=gateway.url, headers=headers) as streams:
                async with ClientSession(*streams) as session:
                    await session.initialize()
                    tool_call_result = await session.call_tool(name, arguments)
        elif transport == "streamablehttp":
            async with streamablehttp_client(url=gateway.url, headers=headers) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    tool_call_result = await session.call_tool(name, arguments)
        else:
            raise ToolInvocationError(f"Unsupported MCP transport: {tool.request_type}")

        content = tool_call_result.model_dump(by_alias=True).get("content", [])
        filtered_response = extract_using_jq(content, tool.jsonpath_filter)
        return ToolResult(content=filtered_response)

    async def update_tool(self, db: Session, tool_id: int, tool_update: ToolUpdate) -> ToolRead:
        """Update an existing tool.

        Args:
            db: Database session.
            tool_id: ID of tool to update.
            tool_update: Updated tool data.

        Returns:
            Updated tool information.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other tool update errors.
            ToolNameConflictError: If tool name conflict occurs
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            if tool_update.name is not None and tool_update.name != tool.name:
                existing_tool = db.execute(select(DbTool).where(DbTool.name == tool_update.name).where(DbTool.id != tool_id)).scalar_one_or_none()
                if existing_tool:
                    raise ToolNameConflictError(
                        tool_update.name,
                        is_active=existing_tool.is_active,
                        tool_id=existing_tool.id,
                    )

            if tool_update.name is not None:
                tool.name = tool_update.name
            if tool_update.url is not None:
                tool.url = str(tool_update.url)
            if tool_update.description is not None:
                tool.description = tool_update.description
            if tool_update.integration_type is not None:
                tool.integration_type = tool_update.integration_type
            if tool_update.request_type is not None:
                tool.request_type = tool_update.request_type
            if tool_update.headers is not None:
                tool.headers = tool_update.headers
            if tool_update.input_schema is not None:
                tool.input_schema = tool_update.input_schema
            if tool_update.jsonpath_filter is not None:
                tool.jsonpath_filter = tool_update.jsonpath_filter

            if tool_update.auth is not None:
                if tool_update.auth.auth_type is not None:
                    tool.auth_type = tool_update.auth.auth_type
                if tool_update.auth.auth_value is not None:
                    tool.auth_value = tool_update.auth.auth_value
            else:
                # NOTE(review): an omitted `auth` clears auth_type but keeps auth_value,
                # unlike other fields where None means "leave unchanged" — confirm intent.
                tool.auth_type = None

            tool.updated_at = datetime.utcnow()
            db.commit()
            db.refresh(tool)
            await self._notify_tool_updated(tool)
            logger.info(f"Updated tool: {tool.name}")
            return self._convert_tool_to_read(tool)
        except (ToolNotFoundError, ToolNameConflictError):
            # Preserve the documented exception types (both are ToolError subclasses).
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to update tool: {str(e)}") from e

    async def _notify_tool_updated(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool update.

        Args:
            tool: Tool updated
        """
        event = {
            "type": "tool_updated",
            "data": {
                "id": tool.id,
                "name": tool.name,
                "url": tool.url,
                "description": tool.description,
                "is_active": tool.is_active,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_activated(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool activation.

        Args:
            tool: Tool activated
        """
        event = {
            "type": "tool_activated",
            "data": {"id": tool.id, "name": tool.name, "is_active": True},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_deactivated(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool deactivation.

        Args:
            tool: Tool deactivated
        """
        event = {
            "type": "tool_deactivated",
            "data": {"id": tool.id, "name": tool.name, "is_active": False},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
        """
        Notify subscribers of tool deletion.

        Args:
            tool_info: Dictionary on tool deleted
        """
        event = {
            "type": "tool_deleted",
            "data": tool_info,
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to tool events.

        Yields:
            Tool event messages.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    async def _notify_tool_added(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool addition.

        Args:
            tool: Tool added
        """
        event = {
            "type": "tool_added",
            "data": {
                "id": tool.id,
                "name": tool.name,
                "url": tool.url,
                "description": tool.description,
                "is_active": tool.is_active,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_removed(self, tool: DbTool) -> None:
        """
        Notify subscribers of tool removal (soft delete/deactivation).

        Args:
            tool: Tool removed
        """
        event = {
            "type": "tool_removed",
            "data": {"id": tool.id, "name": tool.name, "is_active": False},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._publish_event(event)

    async def _publish_event(self, event: Dict[str, Any]) -> None:
        """
        Publish event to all subscribers.

        Args:
            event: Event to publish
        """
        for queue in self._event_subscribers:
            await queue.put(event)

    async def _validate_tool_url(self, url: str) -> None:
        """Validate tool URL is accessible.

        Args:
            url: URL to validate.

        Raises:
            ToolValidationError: If URL validation fails.
        """
        try:
            response = await self._http_client.get(url)
            response.raise_for_status()
        except Exception as e:
            raise ToolValidationError(f"Failed to validate tool URL: {str(e)}") from e

    async def _check_tool_health(self, tool: DbTool) -> bool:
        """Check if tool endpoint is healthy.

        Args:
            tool: Tool to check.

        Returns:
            True if tool is healthy.
        """
        try:
            response = await self._http_client.get(tool.url)
            return response.is_success
        except Exception:
            # Best-effort probe: any transport error simply means "unhealthy".
            return False

    async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Generate tool events for SSE.

        Yields:
            Tool events.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    # --- Metrics ---
    async def aggregate_metrics(self, db: Session) -> Dict[str, Any]:
        """
        Aggregate metrics for all tool invocations.

        Args:
            db: Database session

        Returns:
            A dictionary with keys:
                - total_executions
                - successful_executions
                - failed_executions
                - failure_rate
                - min_response_time
                - max_response_time
                - avg_response_time
                - last_execution_time
        """
        total = db.execute(select(func.count(ToolMetric.id))).scalar() or 0  # pylint: disable=not-callable
        successful = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success)).scalar() or 0  # pylint: disable=not-callable
        failed = db.execute(select(func.count(ToolMetric.id)).where(not_(ToolMetric.is_success))).scalar() or 0  # pylint: disable=not-callable
        failure_rate = failed / total if total > 0 else 0.0
        # An aggregate query without GROUP BY always yields exactly one row (NULLs when empty).
        min_rt, max_rt, avg_rt, last_time = db.execute(
            select(
                func.min(ToolMetric.response_time),
                func.max(ToolMetric.response_time),
                func.avg(ToolMetric.response_time),
                func.max(ToolMetric.timestamp),
            )
        ).one()

        return {
            "total_executions": total,
            "successful_executions": successful,
            "failed_executions": failed,
            "failure_rate": failure_rate,
            "min_response_time": min_rt,
            "max_response_time": max_rt,
            "avg_response_time": avg_rt,
            "last_execution_time": last_time,
        }

    async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
        """
        Reset metrics for tool invocations.

        If tool_id is provided, only the metrics for that specific tool will be deleted.
        Otherwise, all tool metrics will be deleted (global reset).

        Args:
            db (Session): The SQLAlchemy database session.
            tool_id (Optional[int]): Specific tool ID to reset metrics for.
        """
        # Compare against None so a falsy ID (e.g. 0) never triggers a global wipe.
        if tool_id is not None:
            db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
        else:
            db.execute(delete(ToolMetric))
        db.commit()