Coverage for mcpgateway/services/completion_service.py: 92%
70 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
1# -*- coding: utf-8 -*-
2"""Completion Service Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements argument completion according to the MCP specification.
9It handles completion suggestions for prompt arguments and resource URIs.
10"""
12import logging
13from typing import Any, Dict, List
15from sqlalchemy import select
16from sqlalchemy.orm import Session
18from mcpgateway.db import Prompt as DbPrompt
19from mcpgateway.db import Resource as DbResource
20from mcpgateway.types import CompleteResult
22logger = logging.getLogger(__name__)
class CompletionError(Exception):
    """Base class for completion errors.

    Raised by the completion service in this module when a completion request
    is malformed (for example, a missing reference type or argument name) or
    when the referenced prompt or argument cannot be found.
    """
class CompletionService:
    """MCP completion service.

    Handles argument completion for:
    - Prompt arguments based on schema
    - Resource URIs with templates
    - Custom completion sources

    Custom sources are held in memory, keyed by argument name, and are managed
    through ``register_completions`` / ``unregister_completions``.
    """

    # Maximum number of suggestions placed in a single result. The full match
    # count is still reported through the "total" and "hasMore" fields.
    _MAX_VALUES = 100

    def __init__(self):
        """Initialize completion service."""
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shutdown completion service and drop all registered custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(self, db: Session, request: Dict[str, Any]) -> CompleteResult:
        """Handle completion request.

        Args:
            db: Database session
            request: Completion request; reads ``ref`` (with ``type``) and
                ``argument`` (with ``name`` and optional ``value``)

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If the request is incomplete, the reference type
                is not supported, or any underlying step fails
        """
        try:
            # ``or {}`` also covers keys that are present but explicitly null.
            ref = request.get("ref") or {}
            ref_type = ref.get("type")
            arg = request.get("argument") or {}
            arg_name = arg.get("name")
            raw_value = arg.get("value")
            # A missing/null value matches everything; other values are
            # compared as text.
            arg_value = "" if raw_value is None else str(raw_value)

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            # Dispatch on the reference type
            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except CompletionError as e:
            # Already the right type: log and propagate without re-wrapping so
            # the original traceback is kept.
            logger.error("Completion error: %s", e)
            raise
        except Exception as e:
            logger.error("Completion error: %s", e)
            raise CompletionError(str(e)) from e

    @staticmethod
    def _find_argument_schema(argument_schema: Any, arg_name: str) -> Any:
        """Locate the schema describing one prompt argument.

        Args:
            argument_schema: Prompt argument schema (expected to be a dict
                with a ``properties`` mapping); anything else yields ``None``
            arg_name: Argument name

        Returns:
            The matching property schema dict, or ``None`` when not found
        """
        if not isinstance(argument_schema, dict):
            return None
        properties = argument_schema.get("properties")
        if not isinstance(properties, dict):
            return None

        # Original lookup: a property schema that carries a matching "name".
        for candidate in properties.values():
            if isinstance(candidate, dict) and candidate.get("name") == arg_name:
                return candidate

        # JSON Schema keys "properties" by property name, so also accept a
        # direct key match (such schemas usually have no "name" field).
        keyed = properties.get(arg_name)
        return keyed if isinstance(keyed, dict) else None

    @classmethod
    def _build_completion(cls, candidates: Any, arg_value: Any) -> Dict[str, Any]:
        """Filter candidates by case-insensitive substring and build the payload.

        Args:
            candidates: Iterable of candidate values (``None`` is treated as empty)
            arg_value: Current argument value; ``None`` matches every candidate

        Returns:
            Dict with ``values`` (at most ``_MAX_VALUES`` items), ``total``
            (number of matches) and ``hasMore``
        """
        needle = "" if arg_value is None else str(arg_value).lower()
        matches = [value for value in (candidates or ()) if needle in str(value).lower()]
        return {
            "values": matches[: cls._MAX_VALUES],
            "total": len(matches),
            "hasMore": len(matches) > cls._MAX_VALUES,
        }

    async def _complete_prompt_argument(self, db: Session, ref: Dict[str, Any], arg_name: str, arg_value: str) -> CompleteResult:
        """Complete prompt argument value.

        Args:
            db: Database session
            ref: Prompt reference
            arg_name: Argument name
            arg_value: Current argument value

        Returns:
            Completion suggestions

        Raises:
            CompletionError: If the prompt name is missing, the prompt is not
                found, or the argument is not part of the prompt's schema
        """
        # Get prompt
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.is_active)).scalar_one_or_none()
        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        # Find argument in schema; an empty schema dict still counts as found.
        arg_schema = self._find_argument_schema(prompt.argument_schema, arg_name)
        if arg_schema is None:
            raise CompletionError(f"Argument not found: {arg_name}")

        # Enum values take precedence over registered custom completions
        if "enum" in arg_schema:
            return CompleteResult(completion=self._build_completion(arg_schema["enum"], arg_value))

        if arg_name in self._custom_completions:
            return CompleteResult(completion=self._build_completion(self._custom_completions[arg_name], arg_value))

        # No completions available
        return CompleteResult(completion=self._build_completion((), arg_value))

    async def _complete_resource_uri(self, db: Session, ref: Dict[str, Any], arg_value: str) -> CompleteResult:
        """Complete resource URI.

        Args:
            db: Database session
            ref: Resource reference
            arg_value: Current URI value

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing
        """
        # The template is only checked for presence; matching is done against
        # the typed value below.
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")

        # List active resources and keep URIs containing the typed value
        resources = db.execute(select(DbResource).where(DbResource.is_active)).scalars().all()
        uris = [resource.uri for resource in resources if resource.uri is not None]
        return CompleteResult(completion=self._build_completion(uris, arg_value))

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        Replaces any values previously registered for the same argument. The
        values are copied, so later changes to the caller's list have no effect.

        Args:
            arg_name: Argument name
            values: Completion values
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values.

        Does nothing when no values are registered for the argument.

        Args:
            arg_name: Argument name
        """
        self._custom_completions.pop(arg_name, None)