Module statkit.decision

Evaluate models using decision curve analysis.

Expand source code
"""Evaluate models using decision curve analysis."""
from typing import Literal, Optional, Union

from matplotlib import pyplot as plt
from numpy import array, divide, linspace, r_, zeros_like
from numpy.typing import NDArray
from pandas import Series
from sklearn.utils import column_or_1d


def _binary_classification_thresholds(y_true, y_proba, thresholds):
    """Count false and true positives when `y_proba` is dichotomised at each threshold.

    A sample is predicted positive when its probability is greater than or equal
    to the threshold.

    Args:
        y_true: Ground truth labels; only entries equal to 1 (positive) or 0
            (negative) contribute to the counts.
        y_proba: Predicted probability of the positive class.
        thresholds: Iterable of probability thresholds.

    Returns:
        Tuple `(fps, tps)` of arrays holding one count per threshold.
    """
    labels = column_or_1d(y_true)
    scores = column_or_1d(y_proba)
    is_positive = labels == 1
    is_negative = labels == 0

    # One boolean "predicted positive" mask per threshold.
    flagged = [scores >= cutoff for cutoff in thresholds]
    fps = array([(mask & is_negative).sum() for mask in flagged])
    tps = array([(mask & is_positive).sum() for mask in flagged])
    return fps, tps


def net_benefit(
    y_true: Union[Series, NDArray],
    y_pred: Union[Series, NDArray],
    n_thresholds: int = 100,
    action: bool = True,
):
    """Net benefit of taking an action using a model's predictions.

    The probability thresholds are `n_thresholds` evenly spaced points on [0, 1].
    At threshold 1 the harm-to-benefit ratio t / (1 - t) is undefined; it is set
    to 0 there, so the last point only counts true positives.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class). Both
            labels must be present.
        y_pred: Probability of positive class label.
        n_thresholds: Number of x coordinates (the probability thresholds).
        action: When `True` (`False`), estimate net benefit of taking (not taking) an
            action/intervention/treatment. When `False`, labels and probabilities
            are inverted (`1 - y_true`, `1 - y_pred`), so the thresholds apply to
            the predicted probability of the negative class.

    Returns:
        thresholds: Probability thresholds used to dichotomise the (positive class,
            or negative class when `action=False`) predicted probability.
        benefit: The net benefit corresponding to the thresholds.

    Raises:
        ValueError: If `y_true` does not consist of exactly the labels 0 and 1.

    References:
        [1]: Rousson-Zumbrunn. "Decision curve analysis revisited: overall net
        benefit, relationships to ROC curve analysis, and application to case-control
        studies." BMC medical informatics and decision making 11.1 (2011): 1–9.
    """
    # Normalise to 1-D arrays so lists, pandas Series and (n, 1) column vectors
    # are all accepted.
    y_true = column_or_1d(y_true)
    y_pred = column_or_1d(y_pred)

    # Validate the raw label values. Casting to int first would let labels such
    # as 0.5 pass (truncated to 0) and then be silently ignored by the `== 0` /
    # `== 1` masks when counting positives, while still counting towards N.
    if set(y_true.tolist()) != {0, 1}:
        raise ValueError(
            "Decision curve analysis only supports binary classification (with labels 1 and 0)."
        )

    thresholds = linspace(0, 1, num=n_thresholds)

    if action:
        fps, tps = _binary_classification_thresholds(y_true, y_pred, thresholds)
    else:
        # Invert 0<-->1 so that true positives are true negatives, and false
        # positives are false negatives.
        fps, tps = _binary_classification_thresholds(1 - y_true, 1 - y_pred, thresholds)

    n_samples = len(y_true)

    # Harm-to-benefit ratio t / (1 - t); left at 0 where t == 1 to avoid a
    # division by zero.
    loss_over_profit = divide(
        thresholds, 1 - thresholds, where=thresholds < 1, out=zeros_like(thresholds)
    )
    benefit = tps / n_samples - fps / n_samples * loss_over_profit
    return thresholds, benefit


def net_benefit_oracle(y_true, action: bool = True) -> float:
    """Return the net benefit of an omniscient strategy (a hypothetical perfect predictor).

    Args:
        y_true: Binary ground truth labels (1: positive, 0: negative class).
        action: When `True`, net benefit of taking action (the prevalence, i.e.
            the mean of `y_true`); when `False`, of not taking action (one minus
            the prevalence).
    """
    prevalence = y_true.mean()
    return prevalence if action else 1 - prevalence


def net_benefit_action(y_true, threshold, action: bool = True):
    """Net benefit of always (or, with `action=False`, never) taking an action.

    Args:
        y_true: Binary ground truth labels (1: positive, 0: negative class).
        threshold: Probability threshold p_t. This is a scalar, or a numpy array
            for element-wise evaluation.
        action: When `True`, return the net benefit of always doing the
            action/intervention/treatment,
            `prevalence - (1 - prevalence) * p_t / (1 - p_t)`.
            When `False`, invert the positive label in `y_true`, giving the net
            benefit of never acting, `(1 - prevalence) - prevalence * (1 - p_t) / p_t`.

    Note:
        The expressions divide by `1 - threshold` (action) or `threshold`
        (no action), so those endpoints produce a division by zero.
    """
    prevalence = y_true.mean()
    if not action:
        return 1 - prevalence - prevalence * (1 - threshold) / threshold
    return prevalence - (1 - prevalence) * threshold / (1 - threshold)


def overall_net_benefit(y_true, y_pred, n_thresholds: int = 100):
    """Net benefit combining both taking and not-taking action.

    Following Rousson-Zumbrunn (2011), the overall net benefit at threshold p_t
    is the sum of the net benefit of the treated and of the untreated, both
    evaluated at the *same* p_t:

        NB_treated(p_t)   = TP/N - FP/N * p_t / (1 - p_t)
        NB_untreated(p_t) = TN/N - FN/N * (1 - p_t) / p_t

    A sample is treated when `y_pred >= p_t` and untreated otherwise.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Probability of positive class label.
        n_thresholds: Number of probability thresholds, evenly spaced on [0, 1].

    Returns:
        thresholds: The probability thresholds p_t.
        benefit: The overall net benefit at each threshold.
    """
    y_true = column_or_1d(y_true)
    y_pred = column_or_1d(y_pred)
    thresholds, benefit_action = net_benefit(
        y_true, y_pred, n_thresholds, action=True
    )

    # The untreated at p_t are exactly the samples *not* treated at p_t, so
    # their true/false negatives follow from the treated counts.
    # `net_benefit(..., action=False)` cannot be used here: it thresholds the
    # inverted probabilities, so its i-th entry belongs to p_t = 1 - thresholds[i].
    fps, tps = _binary_classification_thresholds(y_true, y_pred, thresholds)
    n_samples = len(y_true)
    tns = (y_true == 0).sum() - fps
    fns = (y_true == 1).sum() - tps

    # Benefit-to-harm ratio (1 - p_t) / p_t. At p_t == 0 every sample is treated
    # (fns == 0), so the ratio is left at 0 instead of dividing by zero.
    profit_over_loss = divide(
        1 - thresholds, thresholds, where=thresholds > 0, out=zeros_like(thresholds)
    )
    benefit_no_action = tns / n_samples - fns / n_samples * profit_over_loss
    return thresholds, benefit_action + benefit_no_action


class NetBenefitDisplay:
    """Net benefit decision curve analysis visualisation.

    Args:
        threshold_probability: Probability to dichotomise the predicted probability
            of the model.
        net_benefit: Net benefit of taking an action as a function of
            `threshold_probability`.
        oracle: The (constant) net benefit of a perfect predictor. When `None`, the
            oracle reference line is not drawn.
        estimator_name: Legend label of the model's net benefit curve.

    Attributes:
        ax_: Axes the curves were drawn on; set by `plot`.
        figure_: Figure containing `ax_`; set by `plot`.
    """

    def __init__(
        self,
        threshold_probability,
        net_benefit,
        oracle: Optional[float] = None,
        estimator_name: Optional[str] = None,
    ):
        self.threshold_probability = threshold_probability
        self.net_benefit = net_benefit
        self.estimator_name = estimator_name
        self.oracle = oracle

    def plot(self, show_references: bool = True, ax=None):
        """Plot the net benefit curve.

        Args:
            show_references: Show the no action/treatment/intervention reference
                curve, and the oracle curve when `oracle` is set.
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            The display itself, with `ax_` and `figure_` set.
        """
        if ax is None:
            fig, ax = plt.subplots()
        else:
            # Reuse the figure owning the caller's axes; `fig` must be bound on
            # this path too, or the assignment to `figure_` below fails.
            fig = ax.figure
        self.ax_ = ax
        self.figure_ = fig

        ax.plot(
            self.threshold_probability, self.net_benefit, "-", label=self.estimator_name
        )
        if show_references:
            # Reference for the "No action" strategy, drawn at zero net benefit.
            ax.plot([0, 1], [0, 0], "-.", label="No action")
            if self.oracle is not None:
                ax.plot([0, 1], [self.oracle, self.oracle], "--", label="Oracle")

        ax.set(xlabel="Threshold probability", ylabel="Net benefit")
        ax.set_ylim([0, 1])
        ax.legend(loc="upper right", frameon=False)
        return self

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        benefit: Literal["action", "noop", "overall"] = "action",
        name: Optional[str] = None,
        n_thresholds: int = 100,
        ax=None,
    ):
        """Make a net benefit plot from true labels and predicted probabilities.

        Args:
            y_true: Binary ground truth label (1: positive, 0: negative class).
            y_pred: Predicted probability of the positive class.
            benefit: Type of net benefit curve. `"action"`: net benefit of
                treatment/intervention/action; `"noop"`: net benefit of no
                treatment/intervention/action; `"overall"`: overall net benefit (see
                `overall_net_benefit`).
            name: Legend label of the model's curve.
            n_thresholds: Number of probability thresholds (x coordinates).
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            The plotted `NetBenefitDisplay`.

        Raises:
            ValueError: If `benefit` is not one of `"action"`, `"noop"`, `"overall"`.

        Example:
            ```python
            from matplotlib import pyplot as plt
            from sklearn.datasets import make_blobs
            from sklearn.linear_model import LogisticRegression
            from sklearn.model_selection import train_test_split
            from statkit.decision import NetBenefitDisplay

            X, y = make_blobs(n_features=2, centers=2, cluster_std=0.5, random_state=5)
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
            clf = LogisticRegression().fit(X_train, y_train)
            y_pred = clf.predict_proba(X_test)[:, 1]

            NetBenefitDisplay.from_predictions(y_test, y_pred, name='model')
            plt.show()
            ```
        """
        # Fail fast on unknown curve types instead of hitting an unbound local.
        if benefit not in ("action", "noop", "overall"):
            raise ValueError(
                f"Unknown benefit {benefit!r}; expected 'action', 'noop' or 'overall'."
            )

        if benefit == "action":
            oracle = net_benefit_oracle(y_true)
            thresholds, curve = net_benefit(y_true, y_pred, n_thresholds)
        elif benefit == "noop":
            oracle = net_benefit_oracle(y_true, action=False)
            thresholds, curve = net_benefit(
                y_true, y_pred, n_thresholds, action=False
            )
        else:  # "overall": a perfect predictor attains an overall net benefit of 1.
            oracle = 1.0
            thresholds, curve = overall_net_benefit(y_true, y_pred, n_thresholds)

        return cls(
            thresholds,
            curve,
            oracle,
            estimator_name=name,
        ).plot(ax=ax)

Functions

def net_benefit(y_true: Union[pandas.core.series.Series, numpy.ndarray[Any, numpy.dtype[+ScalarType]]], y_pred: Union[pandas.core.series.Series, numpy.ndarray[Any, numpy.dtype[+ScalarType]]], n_thresholds: int = 100, action: bool = True)

Net benefit of taking an action using a model's predictions.

Args

y_true
Binary ground truth label (1: positive, 0: negative class).
y_pred
Probability of positive class label.
n_thresholds
Number of x coordinates (the probability thresholds).
action
When True (False), estimate net benefit of taking (not taking) an action/intervention/treatment.

Returns

thresholds
Probability threshold for predicting the positive class.
benefit
The net benefit corresponding to the thresholds.

References

[1]: Rousson-Zumbrunn. "Decision curve analysis revisited: overall net benefit, relationships to ROC curve analysis, and application to case-control studies." BMC medical informatics and decision making 11.1 (2011): 1–9.

Expand source code
def net_benefit(
    y_true: Union[Series, NDArray],
    y_pred: Union[Series, NDArray],
    n_thresholds: int = 100,
    action: bool = True,
):
    """Net benefit of taking an action using a model's predictions.

    The probability thresholds are `n_thresholds` evenly spaced points on [0, 1].
    At threshold 1 the harm-to-benefit ratio t / (1 - t) is undefined; it is set
    to 0 there, so the last point only counts true positives.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class). Both
            labels must be present.
        y_pred: Probability of positive class label.
        n_thresholds: Number of x coordinates (the probability thresholds).
        action: When `True` (`False`), estimate net benefit of taking (not taking) an
            action/intervention/treatment. When `False`, labels and probabilities
            are inverted (`1 - y_true`, `1 - y_pred`), so the thresholds apply to
            the predicted probability of the negative class.

    Returns:
        thresholds: Probability thresholds used to dichotomise the (positive class,
            or negative class when `action=False`) predicted probability.
        benefit: The net benefit corresponding to the thresholds.

    Raises:
        ValueError: If `y_true` does not consist of exactly the labels 0 and 1.

    References:
        [1]: Rousson-Zumbrunn. "Decision curve analysis revisited: overall net
        benefit, relationships to ROC curve analysis, and application to case-control
        studies." BMC medical informatics and decision making 11.1 (2011): 1–9.
    """
    # Normalise to 1-D arrays so lists, pandas Series and (n, 1) column vectors
    # are all accepted.
    y_true = column_or_1d(y_true)
    y_pred = column_or_1d(y_pred)

    # Validate the raw label values. Casting to int first would let labels such
    # as 0.5 pass (truncated to 0) and then be silently ignored by the `== 0` /
    # `== 1` masks when counting positives, while still counting towards N.
    if set(y_true.tolist()) != {0, 1}:
        raise ValueError(
            "Decision curve analysis only supports binary classification (with labels 1 and 0)."
        )

    thresholds = linspace(0, 1, num=n_thresholds)

    if action:
        fps, tps = _binary_classification_thresholds(y_true, y_pred, thresholds)
    else:
        # Invert 0<-->1 so that true positives are true negatives, and false
        # positives are false negatives.
        fps, tps = _binary_classification_thresholds(1 - y_true, 1 - y_pred, thresholds)

    n_samples = len(y_true)

    # Harm-to-benefit ratio t / (1 - t); left at 0 where t == 1 to avoid a
    # division by zero.
    loss_over_profit = divide(
        thresholds, 1 - thresholds, where=thresholds < 1, out=zeros_like(thresholds)
    )
    benefit = tps / n_samples - fps / n_samples * loss_over_profit
    return thresholds, benefit
def net_benefit_action(y_true, threshold, action: bool = True)

Net benefit of always doing an action/intervention/treatment.

Args

action
When False, invert positive label in y_true.
Expand source code
def net_benefit_action(y_true, threshold, action: bool = True):
    """Net benefit of always (or, with `action=False`, never) taking an action.

    Args:
        y_true: Binary ground truth labels (1: positive, 0: negative class).
        threshold: Probability threshold p_t. This is a scalar, or a numpy array
            for element-wise evaluation.
        action: When `True`, return the net benefit of always doing the
            action/intervention/treatment,
            `prevalence - (1 - prevalence) * p_t / (1 - p_t)`.
            When `False`, invert the positive label in `y_true`, giving the net
            benefit of never acting, `(1 - prevalence) - prevalence * (1 - p_t) / p_t`.

    Note:
        The expressions divide by `1 - threshold` (action) or `threshold`
        (no action), so those endpoints produce a division by zero.
    """
    prevalence = y_true.mean()
    if not action:
        return 1 - prevalence - prevalence * (1 - threshold) / threshold
    return prevalence - (1 - prevalence) * threshold / (1 - threshold)
def net_benefit_oracle(y_true, action: bool = True) ‑> float

Net benefit of omniscient strategy, i.e., a hypothetical perfect predictor.

Expand source code
def net_benefit_oracle(y_true, action: bool = True) -> float:
    """Return the net benefit of an omniscient strategy (a hypothetical perfect predictor).

    Args:
        y_true: Binary ground truth labels (1: positive, 0: negative class).
        action: When `True`, net benefit of taking action (the prevalence, i.e.
            the mean of `y_true`); when `False`, of not taking action (one minus
            the prevalence).
    """
    prevalence = y_true.mean()
    return prevalence if action else 1 - prevalence
def overall_net_benefit(y_true, y_pred, n_thresholds: int = 100)

Net benefit combining both taking and not-taking action.

Expand source code
def overall_net_benefit(y_true, y_pred, n_thresholds: int = 100):
    """Net benefit combining both taking and not-taking action.

    Following Rousson-Zumbrunn (2011), the overall net benefit at threshold p_t
    is the sum of the net benefit of the treated and of the untreated, both
    evaluated at the *same* p_t:

        NB_treated(p_t)   = TP/N - FP/N * p_t / (1 - p_t)
        NB_untreated(p_t) = TN/N - FN/N * (1 - p_t) / p_t

    A sample is treated when `y_pred >= p_t` and untreated otherwise.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Probability of positive class label.
        n_thresholds: Number of probability thresholds, evenly spaced on [0, 1].

    Returns:
        thresholds: The probability thresholds p_t.
        benefit: The overall net benefit at each threshold.
    """
    y_true = column_or_1d(y_true)
    y_pred = column_or_1d(y_pred)
    thresholds, benefit_action = net_benefit(
        y_true, y_pred, n_thresholds, action=True
    )

    # The untreated at p_t are exactly the samples *not* treated at p_t, so
    # their true/false negatives follow from the treated counts.
    # `net_benefit(..., action=False)` cannot be used here: it thresholds the
    # inverted probabilities, so its i-th entry belongs to p_t = 1 - thresholds[i].
    fps, tps = _binary_classification_thresholds(y_true, y_pred, thresholds)
    n_samples = len(y_true)
    tns = (y_true == 0).sum() - fps
    fns = (y_true == 1).sum() - tps

    # Benefit-to-harm ratio (1 - p_t) / p_t. At p_t == 0 every sample is treated
    # (fns == 0), so the ratio is left at 0 instead of dividing by zero.
    profit_over_loss = divide(
        1 - thresholds, thresholds, where=thresholds > 0, out=zeros_like(thresholds)
    )
    benefit_no_action = tns / n_samples - fns / n_samples * profit_over_loss
    return thresholds, benefit_action + benefit_no_action

Classes

class NetBenefitDisplay (threshold_probability, net_benefit, oracle: Optional[float] = None, estimator_name: Optional[str] = None)

Net benefit decision curve analysis visualisation.

Args

threshold_probability
Probability to dichotomise the predicted probability of the model.
net_benefit
Net benefit of taking an action as a function of threshold_probability.
oracle
The (constant) net benefit of a perfect predictor.
Expand source code
class NetBenefitDisplay:
    """Net benefit decision curve analysis visualisation.

    Args:
        threshold_probability: Probability to dichotomise the predicted probability
            of the model.
        net_benefit: Net benefit of taking an action as a function of
            `threshold_probability`.
        oracle: The (constant) net benefit of a perfect predictor. When `None`, the
            oracle reference line is not drawn.
        estimator_name: Legend label of the model's net benefit curve.

    Attributes:
        ax_: Axes the curves were drawn on; set by `plot`.
        figure_: Figure containing `ax_`; set by `plot`.
    """

    def __init__(
        self,
        threshold_probability,
        net_benefit,
        oracle: Optional[float] = None,
        estimator_name: Optional[str] = None,
    ):
        self.threshold_probability = threshold_probability
        self.net_benefit = net_benefit
        self.estimator_name = estimator_name
        self.oracle = oracle

    def plot(self, show_references: bool = True, ax=None):
        """Plot the net benefit curve.

        Args:
            show_references: Show the no action/treatment/intervention reference
                curve, and the oracle curve when `oracle` is set.
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            The display itself, with `ax_` and `figure_` set.
        """
        if ax is None:
            fig, ax = plt.subplots()
        else:
            # Reuse the figure owning the caller's axes; `fig` must be bound on
            # this path too, or the assignment to `figure_` below fails.
            fig = ax.figure
        self.ax_ = ax
        self.figure_ = fig

        ax.plot(
            self.threshold_probability, self.net_benefit, "-", label=self.estimator_name
        )
        if show_references:
            # Reference for the "No action" strategy, drawn at zero net benefit.
            ax.plot([0, 1], [0, 0], "-.", label="No action")
            if self.oracle is not None:
                ax.plot([0, 1], [self.oracle, self.oracle], "--", label="Oracle")

        ax.set(xlabel="Threshold probability", ylabel="Net benefit")
        ax.set_ylim([0, 1])
        ax.legend(loc="upper right", frameon=False)
        return self

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        benefit: Literal["action", "noop", "overall"] = "action",
        name: Optional[str] = None,
        n_thresholds: int = 100,
        ax=None,
    ):
        """Make a net benefit plot from true labels and predicted probabilities.

        Args:
            y_true: Binary ground truth label (1: positive, 0: negative class).
            y_pred: Predicted probability of the positive class.
            benefit: Type of net benefit curve. `"action"`: net benefit of
                treatment/intervention/action; `"noop"`: net benefit of no
                treatment/intervention/action; `"overall"`: overall net benefit (see
                `overall_net_benefit`).
            name: Legend label of the model's curve.
            n_thresholds: Number of probability thresholds (x coordinates).
            ax: Optional axes object to plot on. If `None`, a new figure and axes is
                created.

        Returns:
            The plotted `NetBenefitDisplay`.

        Raises:
            ValueError: If `benefit` is not one of `"action"`, `"noop"`, `"overall"`.

        Example:
            ```python
            from matplotlib import pyplot as plt
            from sklearn.datasets import make_blobs
            from sklearn.linear_model import LogisticRegression
            from sklearn.model_selection import train_test_split
            from statkit.decision import NetBenefitDisplay

            X, y = make_blobs(n_features=2, centers=2, cluster_std=0.5, random_state=5)
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
            clf = LogisticRegression().fit(X_train, y_train)
            y_pred = clf.predict_proba(X_test)[:, 1]

            NetBenefitDisplay.from_predictions(y_test, y_pred, name='model')
            plt.show()
            ```
        """
        # Fail fast on unknown curve types instead of hitting an unbound local.
        if benefit not in ("action", "noop", "overall"):
            raise ValueError(
                f"Unknown benefit {benefit!r}; expected 'action', 'noop' or 'overall'."
            )

        if benefit == "action":
            oracle = net_benefit_oracle(y_true)
            thresholds, curve = net_benefit(y_true, y_pred, n_thresholds)
        elif benefit == "noop":
            oracle = net_benefit_oracle(y_true, action=False)
            thresholds, curve = net_benefit(
                y_true, y_pred, n_thresholds, action=False
            )
        else:  # "overall": a perfect predictor attains an overall net benefit of 1.
            oracle = 1.0
            thresholds, curve = overall_net_benefit(y_true, y_pred, n_thresholds)

        return cls(
            thresholds,
            curve,
            oracle,
            estimator_name=name,
        ).plot(ax=ax)

Static methods

def from_predictions(y_true, y_pred, benefit: Literal['action', 'noop', 'overall'] = 'action', name: Optional[str] = None, n_thresholds: int = 100, ax=None)

Make a net benefit plot from true and predicted labels.

Args

y_true
Binary ground truth label (1: positive, 0: negative class).
y_pred
Predicted probability of the positive class.
benefit
Type of net benefit curve. "action": net benefit of treatment/intervention/action; "noop": net benefit of no treatment/intervention/action; "overall": overall net benefit (see overall_net_benefit()).
ax
Optional axes object to plot on. If None, a new figure and axes is created.

Example

from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from statkit.decision import NetBenefitDisplay

X, y = make_blobs(n_features=2, centers=2, cluster_std=0.5, random_state=5)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
clf = LogisticRegression().fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)[:, 1]

NetBenefitDisplay.from_predictions(y_test, y_pred, name='model')
plt.show()
Expand source code
@classmethod
def from_predictions(
    cls,
    y_true,
    y_pred,
    benefit: Literal["action", "noop", "overall"] = "action",
    name: Optional[str] = None,
    n_thresholds: int = 100,
    ax=None,
):
    """Make a net benefit plot from true labels and predicted probabilities.

    Args:
        y_true: Binary ground truth label (1: positive, 0: negative class).
        y_pred: Predicted probability of the positive class.
        benefit: Type of net benefit curve. `"action"`: net benefit of
            treatment/intervention/action; `"noop"`: net benefit of no
            treatment/intervention/action; `"overall"`: overall net benefit (see
            `overall_net_benefit`).
        name: Legend label of the model's curve.
        n_thresholds: Number of probability thresholds (x coordinates).
        ax: Optional axes object to plot on. If `None`, a new figure and axes is
            created.

    Returns:
        The plotted display instance.

    Raises:
        ValueError: If `benefit` is not one of `"action"`, `"noop"`, `"overall"`.

    Example:
        ```python
        from matplotlib import pyplot as plt
        from sklearn.datasets import make_blobs
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import train_test_split
        from statkit.decision import NetBenefitDisplay

        X, y = make_blobs(n_features=2, centers=2, cluster_std=0.5, random_state=5)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
        clf = LogisticRegression().fit(X_train, y_train)
        y_pred = clf.predict_proba(X_test)[:, 1]

        NetBenefitDisplay.from_predictions(y_test, y_pred, name='model')
        plt.show()
        ```
    """
    # Fail fast on unknown curve types instead of hitting an unbound local.
    if benefit not in ("action", "noop", "overall"):
        raise ValueError(
            f"Unknown benefit {benefit!r}; expected 'action', 'noop' or 'overall'."
        )

    if benefit == "action":
        oracle = net_benefit_oracle(y_true)
        thresholds, curve = net_benefit(y_true, y_pred, n_thresholds)
    elif benefit == "noop":
        oracle = net_benefit_oracle(y_true, action=False)
        thresholds, curve = net_benefit(
            y_true, y_pred, n_thresholds, action=False
        )
    else:  # "overall": a perfect predictor attains an overall net benefit of 1.
        oracle = 1.0
        thresholds, curve = overall_net_benefit(y_true, y_pred, n_thresholds)

    return cls(
        thresholds,
        curve,
        oracle,
        estimator_name=name,
    ).plot(ax=ax)

Methods

def plot(self, show_references: bool = True, ax=None)

Args

show_references
Show oracle (requires prevalence) and no action/treatment/intervention reference curves.
ax
Optional axes object to plot on. If None, a new figure and axes is created.
Expand source code
def plot(self, show_references: bool = True, ax=None):
    """Plot the net benefit curve.

    Args:
        show_references: Show the no action/treatment/intervention reference
            curve, and the oracle curve when `oracle` is set.
        ax: Optional axes object to plot on. If `None`, a new figure and axes is
            created.

    Returns:
        The display itself, with `ax_` and `figure_` set.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        # Reuse the figure owning the caller's axes; `fig` must be bound on this
        # path too, or the assignment to `figure_` below fails.
        fig = ax.figure
    self.ax_ = ax
    self.figure_ = fig

    ax.plot(
        self.threshold_probability, self.net_benefit, "-", label=self.estimator_name
    )
    if show_references:
        # Reference for the "No action" strategy, drawn at zero net benefit.
        ax.plot([0, 1], [0, 0], "-.", label="No action")
        if self.oracle is not None:
            ax.plot([0, 1], [self.oracle, self.oracle], "--", label="Oracle")

    ax.set(xlabel="Threshold probability", ylabel="Net benefit")
    ax.set_ylim([0, 1])
    ax.legend(loc="upper right", frameon=False)
    return self