Coverage for mcpgateway/config.py: 62%
189 statements
coverage.py v7.9.1, created at 2025-06-22 12:53 +0100
1# -*- coding: utf-8 -*-
2"""MCP Gateway Configuration.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module defines configuration settings for the MCP Gateway using Pydantic.
9It loads configuration from environment variables with sensible defaults.
11Environment variables:
12- APP_NAME: Gateway name (default: "MCP_Gateway")
13- HOST: Host to bind to (default: "127.0.0.1")
14- PORT: Port to listen on (default: 4444)
15- DATABASE_URL: SQLite database URL (default: "sqlite:///./mcp.db")
16- BASIC_AUTH_USER: Admin username (default: "admin")
17- BASIC_AUTH_PASSWORD: Admin password (default: "changeme")
18- LOG_LEVEL: Logging level (default: "INFO")
19- SKIP_SSL_VERIFY: Disable SSL verification (default: False)
20- AUTH_REQUIRED: Require authentication (default: True)
21- TRANSPORT_TYPE: Transport mechanisms (default: "all")
22- FEDERATION_ENABLED: Enable gateway federation (default: True)
23- FEDERATION_DISCOVERY: Enable auto-discovery (default: False)
24- FEDERATION_PEERS: List of peer gateway URLs (default: [])
25- RESOURCE_CACHE_SIZE: Max cached resources (default: 1000)
26- RESOURCE_CACHE_TTL: Cache TTL in seconds (default: 3600)
27- TOOL_TIMEOUT: Tool invocation timeout (default: 60)
28- PROMPT_CACHE_SIZE: Max cached prompts (default: 100)
29- HEALTH_CHECK_INTERVAL: Gateway health check interval (default: 60)
30"""
32import json
33from functools import lru_cache
34from importlib.resources import files
35from pathlib import Path
36from typing import Annotated, Any, Dict, List, Optional, Set, Union
38import jq
39from fastapi import HTTPException
40from jsonpath_ng.ext import parse
41from jsonpath_ng.jsonpath import JSONPath
42from pydantic import Field, field_validator
43from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
46class Settings(BaseSettings):
47 """MCP Gateway configuration settings."""
49 # Basic Settings
50 app_name: str = Field("MCP_Gateway", env="APP_NAME")
51 host: str = Field("127.0.0.1", env="HOST")
52 port: int = Field(4444, env="PORT")
53 database_url: str = "sqlite:///./mcp.db"
54 templates_dir: Path = Path("mcpgateway/templates")
55 # Absolute paths resolved at import time (still overridable via env vars)
56 templates_dir: Path = Field(
57 default=files("mcpgateway") / "templates",
58 env="TEMPLATES_DIR",
59 )
60 static_dir: Path = Field(
61 default=files("mcpgateway") / "static",
62 env="STATIC_DIR",
63 )
64 app_root_path: str = ""
66 # Protocol
67 protocol_version: str = "2025-03-26"
69 # Authentication
70 basic_auth_user: str = "admin"
71 basic_auth_password: str = "changeme"
72 jwt_secret_key: str = "my-test-key"
73 jwt_algorithm: str = "HS256"
74 auth_required: bool = True
75 token_expiry: int = 10080 # minutes
77 # Encryption key phrase for auth storage
78 auth_encryption_secret: str = "my-test-salt"
80 # UI/Admin Feature Flags
81 mcpgateway_ui_enabled: bool = True
82 mcpgateway_admin_api_enabled: bool = True
84 # Security
85 skip_ssl_verify: bool = False
87 # For allowed_origins, strip surrounding quotes so the env value is parsed as valid JSON
88 # Tell pydantic *not* to decode this env var itself – the validator below handles it.
89 allowed_origins: Annotated[Set[str], NoDecode] = {
90 "http://localhost",
91 "http://localhost:4444",
92 }
94 @field_validator("allowed_origins", mode="before")
95 @classmethod
96 def _parse_allowed_origins(cls, v):
97 if isinstance(v, str):
98 v = v.strip()
99 if v[:1] in "\"'" and v[-1:] == v[:1]:  # strip 1 outer quote pair  [coverage: 99 ↛ 100, condition never true]
100 v = v[1:-1]
101 try:
102 parsed = set(json.loads(v))
103 except json.JSONDecodeError:
104 parsed = {s.strip() for s in v.split(",") if s.strip()}
105 return parsed
106 return set(v)
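Because this is a mode="before" validator, it also runs on values passed directly to the constructor, which makes the two accepted env formats easy to illustrate (the extra origin below is made up):

    # JSON list form and comma-separated form normalise to the same set.
    s = Settings(allowed_origins='["http://localhost", "https://ui.example.com"]')
    # s.allowed_origins == {"http://localhost", "https://ui.example.com"}
    s = Settings(allowed_origins="http://localhost, https://ui.example.com")
    # same result via the comma-separated fallback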
108 # Logging
109 log_level: str = "INFO"
110 log_format: str = "json" # json or text
111 log_file: Optional[Path] = None
113 # Transport
114 transport_type: str = "all" # http, ws, sse, all
115 websocket_ping_interval: int = 30 # seconds
116 sse_retry_timeout: int = 5000 # milliseconds
118 # Federation
119 federation_enabled: bool = True
120 federation_discovery: bool = False
122 # For federation_peers, strip surrounding quotes so the env value is parsed as valid JSON
123 federation_peers: Annotated[List[str], NoDecode] = []
125 # Lock file path used during gateway service initialization
126 lock_file_path: str = "/tmp/gateway_init.done"
128 @field_validator("federation_peers", mode="before")
129 @classmethod
130 def _parse_federation_peers(cls, v):
131 if isinstance(v, str):
132 v = v.strip()
133 if v[:1] in "\"'" and v[-1:] == v[:1]:  [coverage: 133 ↛ 134, condition never true]
134 v = v[1:-1]
135 try:
136 peers = json.loads(v)
137 except json.JSONDecodeError:
138 peers = [s.strip() for s in v.split(",") if s.strip()]
139 return peers
140 return list(v)
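The same pattern as allowed_origins applies here, only producing a list instead of a set (peer URLs below are hypothetical):

    Settings(federation_peers='["https://gw1.example.com", "https://gw2.example.com"]')
    Settings(federation_peers="https://gw1.example.com, https://gw2.example.com")
    # both yield ["https://gw1.example.com", "https://gw2.example.com"]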
142 federation_timeout: int = 30 # seconds
143 federation_sync_interval: int = 300 # seconds
145 # Resources
146 resource_cache_size: int = 1000
147 resource_cache_ttl: int = 3600 # seconds
148 max_resource_size: int = 10 * 1024 * 1024 # 10MB
149 allowed_mime_types: Set[str] = {
150 "text/plain",
151 "text/markdown",
152 "text/html",
153 "application/json",
154 "application/xml",
155 "image/png",
156 "image/jpeg",
157 "image/gif",
158 }
160 # Tools
161 tool_timeout: int = 60 # seconds
162 max_tool_retries: int = 3
163 tool_rate_limit: int = 100 # requests per minute
164 tool_concurrent_limit: int = 10
166 # Prompts
167 prompt_cache_size: int = 100
168 max_prompt_size: int = 100 * 1024 # 100KB
169 prompt_render_timeout: int = 10 # seconds
171 # Health Checks
172 health_check_interval: int = 60 # seconds
173 health_check_timeout: int = 10 # seconds
174 unhealthy_threshold: int = 10
176 filelock_path: str = "tmp/gateway_service_leader.lock"
178 # Default Roots
179 default_roots: List[str] = []
181 # Database
182 db_pool_size: int = 200
183 db_max_overflow: int = 10
184 db_pool_timeout: int = 30
185 db_pool_recycle: int = 3600
187 # Cache
188 cache_type: str = "database" # memory or redis or database
189 redis_url: Optional[str] = "redis://localhost:6379/0"
190 cache_prefix: str = "mcpgw:"
191 session_ttl: int = 3600
192 message_ttl: int = 600
194 # streamable http transport
195 use_stateful_sessions: bool = False # Set to False to use stateless sessions without event store
196 json_response_enabled: bool = True # Enable JSON responses instead of SSE streams
198 # Development
199 dev_mode: bool = False
200 reload: bool = False
201 debug: bool = False
203 model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore")
205 @property
206 def api_key(self) -> str:
207 """Generate API key from auth credentials.
209 Returns:
210 str: API key string in the format "username:password".
211 """
212 return f"{self.basic_auth_user}:{self.basic_auth_password}"
214 @property
215 def supports_http(self) -> bool:
216 """Check if HTTP transport is enabled.
218 Returns:
219 bool: True if HTTP transport is enabled, False otherwise.
220 """
221 return self.transport_type in ["http", "all"]
223 @property
224 def supports_websocket(self) -> bool:
225 """Check if WebSocket transport is enabled.
227 Returns:
228 bool: True if WebSocket transport is enabled, False otherwise.
229 """
230 return self.transport_type in ["ws", "all"]
232 @property
233 def supports_sse(self) -> bool:
234 """Check if SSE transport is enabled.
236 Returns:
237 bool: True if SSE transport is enabled, False otherwise.
238 """
239 return self.transport_type in ["sse", "all"]
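With the default transport_type of "all" every transport reports as supported; narrowing the setting flips the other flags off, for example:

    cfg = Settings(transport_type="sse")
    assert cfg.supports_sse               # SSE-only configuration
    assert not cfg.supports_http          # HTTP and WebSocket report False
    assert not cfg.supports_websocket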
241 @property
242 def database_settings(self) -> dict:
243 """Get SQLAlchemy database settings.
245 Returns:
246 dict: Dictionary containing SQLAlchemy database configuration options.
247 """
248 return {
249 "pool_size": self.db_pool_size,
250 "max_overflow": self.db_max_overflow,
251 "pool_timeout": self.db_pool_timeout,
252 "pool_recycle": self.db_pool_recycle,
253 "connect_args": {"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
254 }
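A minimal sketch of how this dictionary would typically be splatted into SQLAlchemy's create_engine; it assumes SQLAlchemy 2.x, where file-based SQLite uses a QueuePool so the pool_* options apply, and the gateway's actual engine setup may differ:

    from sqlalchemy import create_engine

    cfg = get_settings()
    # pool_size / max_overflow / pool_timeout / pool_recycle map onto create_engine
    # keyword arguments; connect_args carries the SQLite check_same_thread flag.
    engine = create_engine(cfg.database_url, **cfg.database_settings)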
256 @property
257 def cors_settings(self) -> dict:
258 """Get CORS settings.
260 Returns:
261 dict: Dictionary containing CORS configuration options.
262 """
263 return (
264 {
265 "allow_origins": list(self.allowed_origins),
266 "allow_credentials": True,
267 "allow_methods": ["*"],
268 "allow_headers": ["*"],
269 }
270 if self.cors_enabled
271 else {}
272 )
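The shape of this dict matches Starlette's CORSMiddleware keyword arguments; a sketch of the presumed wiring (the cors_enabled flag checked above is assumed to be defined alongside the other security settings):

    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware

    app = FastAPI()
    cfg = get_settings()
    if cfg.cors_settings:                                  # empty dict when CORS is disabled
        app.add_middleware(CORSMiddleware, **cfg.cors_settings)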
274 def validate_transport(self) -> None:
275 """Validate transport configuration.
277 Raises:
278 ValueError: If the transport type is not one of the valid options.
279 """
280 valid_types = {"http", "ws", "sse", "all"}
281 if self.transport_type not in valid_types:  [coverage: 281 ↛ 282, condition never true]
282 raise ValueError(f"Invalid transport type. Must be one of: {valid_types}")
284 def validate_database(self) -> None:
285 """Validate database configuration."""
286 if self.database_url.startswith("sqlite"):  [coverage: 286 ↛ exit, condition always true]
287 db_path = Path(self.database_url.replace("sqlite:///", ""))
288 db_dir = db_path.parent
289 if not db_dir.exists():  [coverage: 289 ↛ 290, condition never true]
290 db_dir.mkdir(parents=True)
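For example, pointing DATABASE_URL at a nested SQLite path (the path below is made up) causes the missing parent directory to be created before the engine first opens the file:

    cfg = Settings(database_url="sqlite:///./data/gateway/mcp.db")
    cfg.validate_database()   # creates ./data/gateway if it does not exist yet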
293def extract_using_jq(data, jq_filter=""):
294 """
295 Extracts data from a given input (string, dict, or list) using a jq filter string.
297 Args:
298 data (str, dict, list): The input JSON data. Can be a string, dict, or list.
299 jq_filter (str): The jq filter string to extract the desired data.
301 Returns:
302 The result of applying the jq filter to the input data.
303 """
304 if jq_filter == "": 304 ↛ 306line 304 didn't jump to line 306 because the condition on line 304 was always true
305 return data
306 if isinstance(data, str):
307 # If the input is a string, parse it as JSON
308 try:
309 data = json.loads(data)
310 except json.JSONDecodeError:
311 return ["Invalid JSON string provided."]
313 elif not isinstance(data, (dict, list)):
314 # If the input is not a string, dict, or list, raise an error
315 return ["Input data must be a JSON string, dictionary, or list."]
317 # Apply the jq filter to the data
318 try:
319 # Pylint can't introspect C-extension modules, so it doesn't know that jq really does export an all() function.
320 # pylint: disable=c-extension-no-member
321 result = jq.all(jq_filter, data) # Use `jq.all` to get all matches (returns a list)
322 if result == [None]:
323 result = "Error applying jsonpath filter"
324 except Exception as e:
325 message = "Error applying jsonpath filter: " + str(e)
326 return message
328 return result
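A quick usage sketch with made-up payload data; note that jq.all always returns a list of matches, and an empty filter short-circuits to the unmodified input:

    payload = '{"tools": [{"name": "echo"}, {"name": "search"}]}'
    extract_using_jq(payload, ".tools[].name")   # -> ["echo", "search"]
    extract_using_jq(payload)                    # -> payload, returned unchanged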
331def jsonpath_modifier(data: Any, jsonpath: str = "$[*]", mappings: Optional[Dict[str, str]] = None) -> Union[List, Dict]:
332 """
333 Applies the given JSONPath expression and mappings to the data.
334 Returns only the data the caller requests, selected dynamically at runtime.
336 Args:
337 data: The JSON data to query.
338 jsonpath: The JSONPath expression to apply.
339 mappings: Optional dictionary of mappings where keys are new field names
340 and values are JSONPath expressions.
342 Returns:
343 Union[List, Dict]: A list (or mapped list) or a Dict of extracted data.
345 Raises:
346 HTTPException: If there's an error parsing or executing the JSONPath expressions.
347 """
348 if not jsonpath:
349 jsonpath = "$[*]"
351 try:
352 main_expr: JSONPath = parse(jsonpath)
353 except Exception as e:
354 raise HTTPException(status_code=400, detail=f"Invalid main JSONPath expression: {e}")
356 try:
357 main_matches = main_expr.find(data)
358 except Exception as e:
359 raise HTTPException(status_code=400, detail=f"Error executing main JSONPath: {e}")
361 results = [match.value for match in main_matches]
363 if mappings:
364 mapped_results = []
365 for item in results:
366 mapped_item = {}
367 for new_key, mapping_expr_str in mappings.items():
368 try:
369 mapping_expr = parse(mapping_expr_str)
370 except Exception as e:
371 raise HTTPException(status_code=400, detail=f"Invalid mapping JSONPath for key '{new_key}': {e}")
372 try:
373 mapping_matches = mapping_expr.find(item)
374 except Exception as e:
375 raise HTTPException(status_code=400, detail=f"Error executing mapping JSONPath for key '{new_key}': {e}")
376 if not mapping_matches:
377 mapped_item[new_key] = None
378 elif len(mapping_matches) == 1:
379 mapped_item[new_key] = mapping_matches[0].value
380 else:
381 mapped_item[new_key] = [m.value for m in mapping_matches]
382 mapped_results.append(mapped_item)
383 results = mapped_results
385 if len(results) == 1 and isinstance(results[0], dict):
386 return results[0]
387 return results
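For example, with a hypothetical list of server records, a mapping renames and flattens the fields; because two items match, a list is returned rather than the single-dict shortcut taken on line 386:

    servers = [
        {"name": "time", "meta": {"version": "1.2"}},
        {"name": "search", "meta": {"version": "0.9"}},
    ]
    jsonpath_modifier(servers, "$[*]", {"tool": "$.name", "version": "$.meta.version"})
    # -> [{"tool": "time", "version": "1.2"}, {"tool": "search", "version": "0.9"}]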
390@lru_cache()
391def get_settings() -> Settings:
392 """Get cached settings instance.
394 Returns:
395 Settings: A cached instance of the Settings class.
396 """
397 # Instantiate a fresh Pydantic Settings object,
398 # loading from env vars or .env exactly once.
399 cfg = Settings()
400 # Validate that transport_type is correct; will
401 # raise if mis-configured.
402 cfg.validate_transport()
403 # Ensure sqlite DB directories exist if needed.
404 cfg.validate_database()
405 # Return the one-and-only Settings instance (cached).
406 return cfg
409# Create settings instance
410settings = get_settings()
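Because of the lru_cache wrapper, every caller shares the same Settings instance; in tests the cache can be cleared after patching the environment so fresh values are picked up. A sketch using pytest's monkeypatch fixture:

    def test_custom_port(monkeypatch):
        monkeypatch.setenv("PORT", "9000")
        get_settings.cache_clear()           # drop the cached instance
        assert get_settings().port == 9000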