Source code for MED3pa.models.regression_metrics

"""
The ``regression_metrics.py`` module defines the ``RegressionEvaluationMetrics`` class, 
that contains various regression metrics that can be used to assess the model's performance. 
"""
from typing import List, Optional

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

from .abstract_metrics import EvaluationMetric


class RegressionEvaluationMetrics(EvaluationMetric):
    """
    A class to compute various regression evaluation metrics.

    Each metric is a static method taking ``y_true``, ``y_pred`` and an optional
    ``sample_weight``; it returns ``None`` when either ``y_true`` or ``y_pred`` is
    empty. Metrics can also be looked up by short name via :meth:`get_metric`.
    """

    # NOTE: the metric methods below share their names with the sklearn functions
    # imported at module level. Inside a method body, a bare name such as
    # ``mean_squared_error`` resolves to the module-level (sklearn) function, not to
    # the class attribute, because class scope is not an enclosing scope for methods.

    @staticmethod
    def _is_empty(y_true, y_pred) -> bool:
        """
        Check whether either input contains no elements.

        Uses ``np.size`` rather than the ``.size`` attribute so that plain lists and
        other array-likes are accepted, not only ``np.ndarray``.

        Args:
            y_true: True values (array-like).
            y_pred: Predicted values (array-like).

        Returns:
            bool: True if ``y_true`` or ``y_pred`` has zero elements.
        """
        return np.size(y_true) == 0 or np.size(y_pred) == 0

    @staticmethod
    def mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray,
                           sample_weight: Optional[np.ndarray] = None) -> Optional[float]:
        """
        Calculate the Mean Squared Error (MSE).

        Args:
            y_true (np.ndarray): True values.
            y_pred (np.ndarray): Predicted values.
            sample_weight (np.ndarray, optional): Sample weights.

        Returns:
            Optional[float]: Mean Squared Error, or ``None`` if ``y_true`` or ``y_pred`` is empty.
        """
        if RegressionEvaluationMetrics._is_empty(y_true, y_pred):
            return None
        # Module-level sklearn function (see class-level note on name shadowing).
        return mean_squared_error(y_true, y_pred, sample_weight=sample_weight)

    @staticmethod
    def root_mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray,
                                sample_weight: Optional[np.ndarray] = None) -> Optional[float]:
        """
        Calculate the Root Mean Squared Error (RMSE).

        Computed as the square root of sklearn's (optionally weighted) MSE.

        Args:
            y_true (np.ndarray): True values.
            y_pred (np.ndarray): Predicted values.
            sample_weight (np.ndarray, optional): Sample weights.

        Returns:
            Optional[float]: Root Mean Squared Error, or ``None`` if ``y_true`` or ``y_pred`` is empty.
        """
        if RegressionEvaluationMetrics._is_empty(y_true, y_pred):
            return None
        return np.sqrt(mean_squared_error(y_true, y_pred, sample_weight=sample_weight))

    @staticmethod
    def mean_absolute_error(y_true: np.ndarray, y_pred: np.ndarray,
                            sample_weight: Optional[np.ndarray] = None) -> Optional[float]:
        """
        Calculate the Mean Absolute Error (MAE).

        Args:
            y_true (np.ndarray): True values.
            y_pred (np.ndarray): Predicted values.
            sample_weight (np.ndarray, optional): Sample weights.

        Returns:
            Optional[float]: Mean Absolute Error, or ``None`` if ``y_true`` or ``y_pred`` is empty.
        """
        if RegressionEvaluationMetrics._is_empty(y_true, y_pred):
            return None
        return mean_absolute_error(y_true, y_pred, sample_weight=sample_weight)

    @staticmethod
    def r2_score(y_true: np.ndarray, y_pred: np.ndarray,
                 sample_weight: Optional[np.ndarray] = None) -> Optional[float]:
        """
        Calculate the R-squared (R2) score.

        Args:
            y_true (np.ndarray): True values.
            y_pred (np.ndarray): Predicted values.
            sample_weight (np.ndarray, optional): Sample weights.

        Returns:
            Optional[float]: R-squared score, or ``None`` if ``y_true`` or ``y_pred`` is empty.
        """
        if RegressionEvaluationMetrics._is_empty(y_true, y_pred):
            return None
        return r2_score(y_true, y_pred, sample_weight=sample_weight)

    @classmethod
    def get_metric(cls, metric_name: str = ''):
        """
        Get the metric function based on the metric name.

        Args:
            metric_name (str, optional): The short name of the metric
                ('MSE', 'RMSE', 'MAE' or 'R2'). Defaults to ``''``, which requests
                the list of supported metric names instead of a function.

        Returns:
            function | list | None: The function corresponding to ``metric_name``;
            the list of supported metric names if ``metric_name`` is ``''``; or
            ``None`` if the name is not recognized.
        """
        metrics_mappings = {
            'MSE': cls.mean_squared_error,
            'RMSE': cls.root_mean_squared_error,
            'MAE': cls.mean_absolute_error,
            'R2': cls.r2_score,
        }
        if metric_name == '':
            return list(metrics_mappings.keys())
        return metrics_mappings.get(metric_name)

    @classmethod
    def supported_metrics(cls) -> List[str]:
        """
        Get a list of supported regression metrics.

        Returns:
            list: A list of supported regression metric names.
        """
        # Previously called ``cls.get_metric()`` with no argument, which raised a
        # TypeError; an empty name explicitly requests the list of metric names.
        return cls.get_metric('')