Coverage for mcpgateway/config.py: 62%

189 statements  

coverage.py v7.9.1, created at 2025-06-22 12:53 +0100

  1  # -*- coding: utf-8 -*-
  2  """MCP Gateway Configuration.
  3
  4  Copyright 2025
  5  SPDX-License-Identifier: Apache-2.0
  6  Authors: Mihai Criveti
  7
  8  This module defines configuration settings for the MCP Gateway using Pydantic.
  9  It loads configuration from environment variables with sensible defaults.
 10
 11  Environment variables:
 12  - APP_NAME: Gateway name (default: "MCP_Gateway")
 13  - HOST: Host to bind to (default: "127.0.0.1")
 14  - PORT: Port to listen on (default: 4444)
 15  - DATABASE_URL: SQLite database URL (default: "sqlite:///./mcp.db")
 16  - BASIC_AUTH_USER: Admin username (default: "admin")
 17  - BASIC_AUTH_PASSWORD: Admin password (default: "changeme")
 18  - LOG_LEVEL: Logging level (default: "INFO")
 19  - SKIP_SSL_VERIFY: Disable SSL verification (default: False)
 20  - AUTH_REQUIRED: Require authentication (default: True)
 21  - TRANSPORT_TYPE: Transport mechanisms (default: "all")
 22  - FEDERATION_ENABLED: Enable gateway federation (default: True)
 23  - FEDERATION_DISCOVERY: Enable auto-discovery (default: False)
 24  - FEDERATION_PEERS: List of peer gateway URLs (default: [])
 25  - RESOURCE_CACHE_SIZE: Max cached resources (default: 1000)
 26  - RESOURCE_CACHE_TTL: Cache TTL in seconds (default: 3600)
 27  - TOOL_TIMEOUT: Tool invocation timeout (default: 60)
 28  - PROMPT_CACHE_SIZE: Max cached prompts (default: 100)
 29  - HEALTH_CHECK_INTERVAL: Gateway health check interval (default: 60)
 30  """

 31
 32  import json
 33  from functools import lru_cache
 34  from importlib.resources import files
 35  from pathlib import Path
 36  from typing import Annotated, Any, Dict, List, Optional, Set, Union
 37
 38  import jq
 39  from fastapi import HTTPException
 40  from jsonpath_ng.ext import parse
 41  from jsonpath_ng.jsonpath import JSONPath
 42  from pydantic import Field, field_validator
 43  from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
 44
 45

 46  class Settings(BaseSettings):
 47      """MCP Gateway configuration settings."""
 48
 49      # Basic Settings
 50      app_name: str = Field("MCP_Gateway", env="APP_NAME")
 51      host: str = Field("127.0.0.1", env="HOST")
 52      port: int = Field(4444, env="PORT")
 53      database_url: str = "sqlite:///./mcp.db"
 54      templates_dir: Path = Path("mcpgateway/templates")
 55      # Absolute paths resolved at import-time (still override-able via env vars)
 56      templates_dir: Path = Field(
 57          default=files("mcpgateway") / "templates",
 58          env="TEMPLATES_DIR",
 59      )
 60      static_dir: Path = Field(
 61          default=files("mcpgateway") / "static",
 62          env="STATIC_DIR",
 63      )
 64      app_root_path: str = ""
 65
 66      # Protocol
 67      protocol_version: str = "2025-03-26"
 68
 69      # Authentication
 70      basic_auth_user: str = "admin"
 71      basic_auth_password: str = "changeme"
 72      jwt_secret_key: str = "my-test-key"
 73      jwt_algorithm: str = "HS256"
 74      auth_required: bool = True
 75      token_expiry: int = 10080  # minutes
 76
 77      # Encryption key phrase for auth storage
 78      auth_encryption_secret: str = "my-test-salt"
 79
 80      # UI/Admin Feature Flags
 81      mcpgateway_ui_enabled: bool = True
 82      mcpgateway_admin_api_enabled: bool = True
 83
 84      # Security
 85      skip_ssl_verify: bool = False
 86
 87      # For allowed_origins, strip '' to ensure we're passing on valid JSON via env
 88      # Tell pydantic *not* to touch this env var – our validator will.
 89      allowed_origins: Annotated[Set[str], NoDecode] = {
 90          "http://localhost",
 91          "http://localhost:4444",
 92      }
 93
 94      @field_validator("allowed_origins", mode="before")
 95      @classmethod
 96      def _parse_allowed_origins(cls, v):
 97          if isinstance(v, str):
 98              v = v.strip()
 99              if v[:1] in "\"'" and v[-1:] == v[:1]:  # strip 1 outer quote pair    [99 ↛ 100: condition never true]
100                  v = v[1:-1]
101              try:
102                  parsed = set(json.loads(v))
103              except json.JSONDecodeError:
104                  parsed = {s.strip() for s in v.split(",") if s.strip()}
105              return parsed
106          return set(v)

107
108      # Logging
109      log_level: str = "INFO"
110      log_format: str = "json"  # json or text
111      log_file: Optional[Path] = None
112
113      # Transport
114      transport_type: str = "all"  # http, ws, sse, all
115      websocket_ping_interval: int = 30  # seconds
116      sse_retry_timeout: int = 5000  # milliseconds
117
118      # Federation
119      federation_enabled: bool = True
120      federation_discovery: bool = False
121
122      # For federation_peers strip out quotes to ensure we're passing valid JSON via env
123      federation_peers: Annotated[List[str], NoDecode] = []
124
125      # Lock file path for initializing gateway service initialize
126      lock_file_path: str = "/tmp/gateway_init.done"
127
128      @field_validator("federation_peers", mode="before")
129      @classmethod
130      def _parse_federation_peers(cls, v):
131          if isinstance(v, str):
132              v = v.strip()
133              if v[:1] in "\"'" and v[-1:] == v[:1]:    [133 ↛ 134: condition never true]
134                  v = v[1:-1]
135              try:
136                  peers = json.loads(v)
137              except json.JSONDecodeError:
138                  peers = [s.strip() for s in v.split(",") if s.strip()]
139              return peers
140          return list(v)

141
142      federation_timeout: int = 30  # seconds
143      federation_sync_interval: int = 300  # seconds
144
145      # Resources
146      resource_cache_size: int = 1000
147      resource_cache_ttl: int = 3600  # seconds
148      max_resource_size: int = 10 * 1024 * 1024  # 10MB
149      allowed_mime_types: Set[str] = {
150          "text/plain",
151          "text/markdown",
152          "text/html",
153          "application/json",
154          "application/xml",
155          "image/png",
156          "image/jpeg",
157          "image/gif",
158      }
159
160      # Tools
161      tool_timeout: int = 60  # seconds
162      max_tool_retries: int = 3
163      tool_rate_limit: int = 100  # requests per minute
164      tool_concurrent_limit: int = 10
165
166      # Prompts
167      prompt_cache_size: int = 100
168      max_prompt_size: int = 100 * 1024  # 100KB
169      prompt_render_timeout: int = 10  # seconds
170
171      # Health Checks
172      health_check_interval: int = 60  # seconds
173      health_check_timeout: int = 10  # seconds
174      unhealthy_threshold: int = 10
175
176      filelock_path: str = "tmp/gateway_service_leader.lock"
177
178      # Default Roots
179      default_roots: List[str] = []
180
181      # Database
182      db_pool_size: int = 200
183      db_max_overflow: int = 10
184      db_pool_timeout: int = 30
185      db_pool_recycle: int = 3600
186
187      # Cache
188      cache_type: str = "database"  # memory or redis or database
189      redis_url: Optional[str] = "redis://localhost:6379/0"
190      cache_prefix: str = "mcpgw:"
191      session_ttl: int = 3600
192      message_ttl: int = 600
193
194      # streamable http transport
195      use_stateful_sessions: bool = False  # Set to False to use stateless sessions without event store
196      json_response_enabled: bool = True  # Enable JSON responses instead of SSE streams
197
198      # Development
199      dev_mode: bool = False
200      reload: bool = False
201      debug: bool = False
202
203      model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore")

204
205      @property
206      def api_key(self) -> str:
207          """Generate API key from auth credentials.
208
209          Returns:
210              str: API key string in the format "username:password".
211          """
212          return f"{self.basic_auth_user}:{self.basic_auth_password}"
213
214      @property
215      def supports_http(self) -> bool:
216          """Check if HTTP transport is enabled.
217
218          Returns:
219              bool: True if HTTP transport is enabled, False otherwise.
220          """
221          return self.transport_type in ["http", "all"]
222
223      @property
224      def supports_websocket(self) -> bool:
225          """Check if WebSocket transport is enabled.
226
227          Returns:
228              bool: True if WebSocket transport is enabled, False otherwise.
229          """
230          return self.transport_type in ["ws", "all"]
231
232      @property
233      def supports_sse(self) -> bool:
234          """Check if SSE transport is enabled.
235
236          Returns:
237              bool: True if SSE transport is enabled, False otherwise.
238          """
239          return self.transport_type in ["sse", "all"]
240
241      @property
242      def database_settings(self) -> dict:
243          """Get SQLAlchemy database settings.
244
245          Returns:
246              dict: Dictionary containing SQLAlchemy database configuration options.
247          """
248          return {
249              "pool_size": self.db_pool_size,
250              "max_overflow": self.db_max_overflow,
251              "pool_timeout": self.db_pool_timeout,
252              "pool_recycle": self.db_pool_recycle,
253              "connect_args": {"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
254          }
255
256      @property
257      def cors_settings(self) -> dict:
258          """Get CORS settings.
259
260          Returns:
261              dict: Dictionary containing CORS configuration options.
262          """
263          return (
264              {
265                  "allow_origins": list(self.allowed_origins),
266                  "allow_credentials": True,
267                  "allow_methods": ["*"],
268                  "allow_headers": ["*"],
269              }
270              if self.cors_enabled
271              else {}
272          )

273
274      def validate_transport(self) -> None:
275          """Validate transport configuration.
276
277          Raises:
278              ValueError: If the transport type is not one of the valid options.
279          """
280          valid_types = {"http", "ws", "sse", "all"}
281          if self.transport_type not in valid_types:    [281 ↛ 282: condition never true]
282              raise ValueError(f"Invalid transport type. Must be one of: {valid_types}")
283
284      def validate_database(self) -> None:
285          """Validate database configuration."""
286          if self.database_url.startswith("sqlite"):    [286 ↛ exit: condition always true]
287              db_path = Path(self.database_url.replace("sqlite:///", ""))
288              db_dir = db_path.parent
289              if not db_dir.exists():    [289 ↛ 290: condition never true]
290                  db_dir.mkdir(parents=True)
291
292
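Because both allowed_origins and federation_peers are marked NoDecode, pydantic-settings hands the raw environment string to the mode="before" validators above, which accept either a JSON array or a plain comma-separated list (optionally wrapped in one outer pair of quotes). A minimal sketch of that parsing, using made-up origin and peer URLs and assuming the env var names simply follow the field names (case_sensitive=False):

    import os

    from mcpgateway.config import Settings

    # JSON form: decoded with json.loads by _parse_allowed_origins
    os.environ["ALLOWED_ORIGINS"] = '["https://gateway.example.com", "http://localhost:4444"]'
    # Comma-separated form: split on "," by _parse_federation_peers
    os.environ["FEDERATION_PEERS"] = "https://peer-a.example.com, https://peer-b.example.com"

    cfg = Settings()
    print(cfg.allowed_origins)   # {"https://gateway.example.com", "http://localhost:4444"} (a set; order may vary)
    print(cfg.federation_peers)  # ["https://peer-a.example.com", "https://peer-b.example.com"]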

293  def extract_using_jq(data, jq_filter=""):
294      """
295      Extracts data from a given input (string, dict, or list) using a jq filter string.
296
297      Args:
298          data (str, dict, list): The input JSON data. Can be a string, dict, or list.
299          jq_filter (str): The jq filter string to extract the desired data.
300
301      Returns:
302          The result of applying the jq filter to the input data.
303      """
304      if jq_filter == "":    [304 ↛ 306: condition always true]
305          return data
306      if isinstance(data, str):
307          # If the input is a string, parse it as JSON
308          try:
309              data = json.loads(data)
310          except json.JSONDecodeError:
311              return ["Invalid JSON string provided."]
312
313      elif not isinstance(data, (dict, list)):
314          # If the input is not a string, dict, or list, raise an error
315          return ["Input data must be a JSON string, dictionary, or list."]
316
317      # Apply the jq filter to the data
318      try:
319          # Pylint can't introspect C-extension modules, so it doesn't know that jq really does export an all() function.
320          # pylint: disable=c-extension-no-member
321          result = jq.all(jq_filter, data)  # Use `jq.all` to get all matches (returns a list)
322          if result == [None]:
323              result = "Error applying jsonpath filter"
324      except Exception as e:
325          message = "Error applying jsonpath filter: " + str(e)
326          return message
327
328      return result
329
330
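A small usage sketch of the behavior implemented above, with a toy payload (jq.all returns a list of matches, so even a single match comes back wrapped in a list):

    from mcpgateway.config import extract_using_jq

    payload = '{"user": {"id": 7, "name": "Ada"}}'

    extract_using_jq(payload, ".user.name")  # -> ["Ada"]
    extract_using_jq(payload)                # -> the input unchanged, since the filter is empty
    extract_using_jq("not json", ".user")    # -> ["Invalid JSON string provided."]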

331  def jsonpath_modifier(data: Any, jsonpath: str = "$[*]", mappings: Optional[Dict[str, str]] = None) -> Union[List, Dict]:
332      """
333      Applies the given JSONPath expression and mappings to the data.
334      Only return data that is required by the user dynamically.
335
336      Args:
337          data: The JSON data to query.
338          jsonpath: The JSONPath expression to apply.
339          mappings: Optional dictionary of mappings where keys are new field names
340              and values are JSONPath expressions.
341
342      Returns:
343          Union[List, Dict]: A list (or mapped list) or a Dict of extracted data.
344
345      Raises:
346          HTTPException: If there's an error parsing or executing the JSONPath expressions.
347      """
348      if not jsonpath:
349          jsonpath = "$[*]"
350
351      try:
352          main_expr: JSONPath = parse(jsonpath)
353      except Exception as e:
354          raise HTTPException(status_code=400, detail=f"Invalid main JSONPath expression: {e}")
355
356      try:
357          main_matches = main_expr.find(data)
358      except Exception as e:
359          raise HTTPException(status_code=400, detail=f"Error executing main JSONPath: {e}")
360
361      results = [match.value for match in main_matches]
362
363      if mappings:
364          mapped_results = []
365          for item in results:
366              mapped_item = {}
367              for new_key, mapping_expr_str in mappings.items():
368                  try:
369                      mapping_expr = parse(mapping_expr_str)
370                  except Exception as e:
371                      raise HTTPException(status_code=400, detail=f"Invalid mapping JSONPath for key '{new_key}': {e}")
372                  try:
373                      mapping_matches = mapping_expr.find(item)
374                  except Exception as e:
375                      raise HTTPException(status_code=400, detail=f"Error executing mapping JSONPath for key '{new_key}': {e}")
376                  if not mapping_matches:
377                      mapped_item[new_key] = None
378                  elif len(mapping_matches) == 1:
379                      mapped_item[new_key] = mapping_matches[0].value
380                  else:
381                      mapped_item[new_key] = [m.value for m in mapping_matches]
382              mapped_results.append(mapped_item)
383          results = mapped_results
384
385      if len(results) == 1 and isinstance(results[0], dict):
386          return results[0]
387      return results
388
389
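A usage sketch for jsonpath_modifier with a made-up tool list, showing how mappings rename fields and how single versus multiple mapping matches are handled by the branches above:

    from mcpgateway.config import jsonpath_modifier

    tools = [
        {"id": 1, "name": "search", "meta": {"tags": ["web"]}},
        {"id": 2, "name": "fetch", "meta": {"tags": ["http", "io"]}},
    ]

    # Select every element, then extract/rename fields via mapping expressions
    jsonpath_modifier(tools, "$[*]", {"tool_id": "$.id", "tags": "$.meta.tags[*]"})
    # -> [{"tool_id": 1, "tags": "web"}, {"tool_id": 2, "tags": ["http", "io"]}]
    # A result containing exactly one dict would be returned unwrapped (lines 385-386).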

390  @lru_cache()
391  def get_settings() -> Settings:
392      """Get cached settings instance.
393
394      Returns:
395          Settings: A cached instance of the Settings class.
396      """
397      # Instantiate a fresh Pydantic Settings object,
398      # loading from env vars or .env exactly once.
399      cfg = Settings()
400      # Validate that transport_type is correct; will
401      # raise if mis-configured.
402      cfg.validate_transport()
403      # Ensure sqlite DB directories exist if needed.
404      cfg.validate_database()
405      # Return the one-and-only Settings instance (cached).
406      return cfg
407
408
409  # Create settings instance
410  settings = get_settings()
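Consumers can import either the module-level settings instance or call get_settings(); thanks to lru_cache both paths see the same object, so the environment and .env file are read only once per process. A sketch of how the derived properties might be consumed downstream (the SQLAlchemy wiring is an assumption here and depends on the installed version accepting these pool arguments for the configured backend):

    from sqlalchemy import create_engine

    from mcpgateway.config import get_settings, settings

    assert get_settings() is settings  # same cached instance

    if settings.supports_sse:
        ...  # mount the SSE transport

    # database_settings is shaped as create_engine keyword arguments
    engine = create_engine(settings.database_url, **settings.database_settings)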