mgplot.bar_plot

bar_plot.py: This module contains functions to create bar plots using Matplotlib. Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.

  1"""
  2bar_plot.py
  3This module contains functions to create bar plots using Matplotlib.
  4Note: bar plots in Matplotlib are not the same as bar charts in other
  5libraries. Bar plots are used to represent categorical data with
  6rectangular bars. As a result, bar plots and line plots typically
  7cannot be plotted on the same axes.
  8"""
  9
 10# --- imports
 11from typing import Any, Final, Unpack, NotRequired
 12from collections.abc import Sequence
 13
 14import numpy as np
 15from pandas import Series, DataFrame, Period
 16import matplotlib.pyplot as plt
 17from matplotlib.pyplot import Axes
 18import matplotlib.patheffects as pe
 19
 20
 21from mgplot.settings import DataT, get_setting
 22from mgplot.utilities import (
 23    apply_defaults,
 24    get_color_list,
 25    get_axes,
 26    constrain_data,
 27    default_rounding,
 28)
 29from mgplot.keyword_checking import validate_kwargs, report_kwargs, BaseKwargs
 30from mgplot.axis_utils import set_labels, map_periodindex, is_categorical
 31
 32
# --- constants
# Caller name passed to report_kwargs / validate_kwargs by bar_plot.
ME: Final[str] = "bar_plot"
 35
 36
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional; bar_plot supplies the defaults noted below.
    """

    # --- options for the entire bar plot
    # presumably consumed by get_axes (reaches it via the remaining kwargs) — confirm
    ax: NotRequired[Axes | None]
    # default False; True draws stacked bars, False draws grouped bars
    stacked: NotRequired[bool]
    # default 10; passed to set_labels when the index was a PeriodIndex
    max_ticks: NotRequired[int]
    # NOTE(review): presumably consumed by constrain_data to trim the data — confirm
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    # per-series colours; default from get_color_list
    color: NotRequired[str | Sequence[str]]
    # per-series legend flag; default True when there is more than one column.
    # Unlabelled series get a label starting with "_", which Matplotlib omits from legends.
    label_series: NotRequired[bool | Sequence[bool]]
    # per-series bar width; default from get_setting("bar_width")
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    # default False; when True, bar values are written on the chart
    annotate: NotRequired[bool]
    # annotation font size; default "small"
    fontsize: NotRequired[int | float | str]
    # annotation font; default "Helvetica"
    fontname: NotRequired[str]
    # decimal places for annotations, passed to default_rounding
    rounding: NotRequired[int]
    # annotation text rotation; default 0
    rotation: NotRequired[int | float]
    # annotation text colour; default "black" when above, else "white"
    annotate_color: NotRequired[str]
    # default False; True places annotations beyond the end of each bar
    above: NotRequired[bool]
 57
 58
 59# --- functions
 60def annotate_bars(
 61    series: Series,
 62    offset: float,
 63    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
 64    axes: Axes,
 65    **anno_kwargs,
 66) -> None:
 67    """Bar plot annotations.
 68
 69    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 70    """
 71
 72    # --- only annotate in limited circumstances
 73    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 74        return
 75    max_annotations = 30
 76    if len(series) > max_annotations:
 77        return
 78
 79    # --- internal logic check
 80    if len(base) != len(series):
 81        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 82        return
 83
 84    # --- assemble the annotation parameters
 85    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
 86    annotate_style = {
 87        "fontsize": anno_kwargs.get("fontsize"),
 88        "fontname": anno_kwargs.get("fontname"),
 89        "color": anno_kwargs.get("color"),
 90        "rotation": anno_kwargs.get("rotation"),
 91    }
 92    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding", None))
 93    adjustment = (series.max() - series.min()) * 0.02
 94    zero_correction = series.index.min()
 95
 96    # --- annotate each bar
 97    for index, value in zip(series.index.astype(int), series):  # mypy syntactic sugar
 98        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
 99        if above:
100            position += value
101        text = axes.text(
102            x=index + offset,
103            y=position,
104            s=f"{value:.{rounding}f}",
105            ha="center",
106            va="bottom" if value >= 0 else "top",
107            **annotate_style,
108        )
109        if not above and "foreground" in anno_kwargs:
110            # apply a stroke-effect to within bar annotations
111            # to make them more readable with very small bars.
112            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
113
114
def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a grouped bar plot: the columns of df are drawn side by side at
    each x position.

    Parameters
    - axes: the axes to draw on.
    - df: DataFrame - one column per series; offsets are added to the
      index to give each bar's x position.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series sequences "width", "color" and "label_series",
      indexed by column position.

    A column that is entirely missing is not drawn, but keeps its slot in
    the group. A width outside [0, 1] is replaced by 0.8. The group spans
    that width, centred on the x position, and is shared equally by all
    columns.
    """
    column_total = len(df.columns)

    for slot, name in enumerate(df.columns):
        values = df[name]
        if values.isnull().all():
            continue

        group_width = kwargs["width"][slot]
        if group_width < 0 or group_width > 1:
            group_width = 0.8
        bar_width = group_width / column_total

        # far-left + margin + half a bar = centre of the first bar in the group
        first_centre = -0.5 + ((1 - group_width) / 2.0) + (bar_width / 2.0)
        shift = first_centre + (slot * bar_width)

        bar_colour = kwargs["color"][slot]
        axes.bar(
            x=values.index + shift,
            height=values,
            color=bar_colour,
            width=bar_width,
            label=name if kwargs["label_series"][slot] else f"_{name}_",
        )
        annotate_bars(
            series=values,
            offset=shift,
            base=np.zeros(len(values)),
            axes=axes,
            foreground=bar_colour,
            **anno_args,
        )
149
150
def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a stacked bar plot: each column of df is drawn on top of the
    columns before it. Positive and negative values are stacked
    separately, upward and downward from zero.

    Parameters
    - axes: the axes to draw on.
    - df: DataFrame - one column per series; the index gives the x positions.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series sequences "width", "color" and "label_series",
      indexed by column position.
    """
    bar_count = len(df)  # number of rows, i.e. bars per series
    pos_top = np.zeros(shape=bar_count, dtype=np.float64)
    neg_bottom = np.zeros(shape=bar_count, dtype=np.float64)

    for slot, name in enumerate(df.columns):
        values = df[name]
        # each bar starts where the stack of the same sign currently ends
        bottoms = np.where(values >= 0, pos_top, neg_bottom)
        bar_colour = kwargs["color"][slot]
        axes.bar(
            x=values.index,
            height=values,
            bottom=bottoms,
            color=bar_colour,
            width=kwargs["width"][slot],
            label=name if kwargs["label_series"][slot] else f"_{name}_",
        )
        annotate_bars(
            series=values,
            offset=0,
            base=bottoms,
            axes=axes,
            foreground=bar_colour,
            **anno_args,
        )
        # grow both stacks; NaN compares False both ways and adds nothing
        pos_top += np.where(values >= 0, values, 0)
        neg_bottom += np.where(values < 0, values, 0)
185
186
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """
    Create a bar plot from the given data.

    By default the columns of a DataFrame are drawn as grouped bars, side
    by side. With stacked=True the columns are stacked on top of each
    other, with positive values above zero and negative values below zero.

    Parameters
    - data: DataFrame | Series - The data to plot. A Series is converted
      to a single-column DataFrame.
    - **kwargs: BarKwargs - Additional keyword arguments for customization.
      (see BarKwargs for details)

    Note: This function does not assume the data is a timeseries with a
    PeriodIndex. A warning is printed when is_categorical rejects the
    index. When map_periodindex maps the index, the x-axis labels are set
    with set_labels; otherwise the x tick labels are rotated 90 degrees.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print("Warning: bar_plot is not designed for incomplete or non-categorical data indexes.")
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- options for the chart as a whole
    chart_defaults: dict[str, Any] = {
        "stacked": False,
        "max_ticks": 10,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    # --- options for each series of bars (resolved by apply_defaults)
    bar_defaults: dict[str, Any] = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": (item_count > 1),
    }

    # --- options for the bar annotations
    above = kwargs_d.get("above", False)
    anno_args = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        # NOTE(review): True is handed to default_rounding although BarKwargs
        # types rounding as int; confirm default_rounding treats it as "auto".
        "rounding": kwargs_d.get("rounding", True),
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**remaining_kwargs)
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # Rotate the labels on *these* axes. plt.xticks acts on pyplot's
        # "current" axes, which need not be the axes drawn on when the
        # caller supplies ax=...; tick_params also covers ticks created later.
        axes.tick_params(axis="x", labelrotation=90)

    return axes
ME: Final[str] = 'bar_plot'
class BarKwargs(mgplot.keyword_checking.BaseKwargs):
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional; bar_plot supplies the defaults noted below.
    """

    # --- options for the entire bar plot
    # presumably consumed by get_axes (reaches it via the remaining kwargs) — confirm
    ax: NotRequired[Axes | None]
    # default False; True draws stacked bars, False draws grouped bars
    stacked: NotRequired[bool]
    # default 10; passed to set_labels when the index was a PeriodIndex
    max_ticks: NotRequired[int]
    # NOTE(review): presumably consumed by constrain_data to trim the data — confirm
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    # per-series colours; default from get_color_list
    color: NotRequired[str | Sequence[str]]
    # per-series legend flag; default True when there is more than one column.
    # Unlabelled series get a label starting with "_", which Matplotlib omits from legends.
    label_series: NotRequired[bool | Sequence[bool]]
    # per-series bar width; default from get_setting("bar_width")
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    # default False; when True, bar values are written on the chart
    annotate: NotRequired[bool]
    # annotation font size; default "small"
    fontsize: NotRequired[int | float | str]
    # annotation font; default "Helvetica"
    fontname: NotRequired[str]
    # decimal places for annotations, passed to default_rounding
    rounding: NotRequired[int]
    # annotation text rotation; default 0
    rotation: NotRequired[int | float]
    # annotation text colour; default "black" when above, else "white"
    annotate_color: NotRequired[str]
    # default False; True places annotations beyond the end of each bar
    above: NotRequired[bool]

Keyword arguments for the bar_plot function.

ax: NotRequired[matplotlib.axes._axes.Axes | None]
stacked: NotRequired[bool]
max_ticks: NotRequired[int]
plot_from: NotRequired[int | pandas._libs.tslibs.period.Period]
color: NotRequired[str | Sequence[str]]
label_series: NotRequired[bool | Sequence[bool]]
width: NotRequired[float | int | Sequence[float | int]]
annotate: NotRequired[bool]
fontsize: NotRequired[int | float | str]
fontname: NotRequired[str]
rounding: NotRequired[int]
rotation: NotRequired[int | float]
annotate_color: NotRequired[str]
above: NotRequired[bool]
def annotate_bars( series: pandas.core.series.Series, offset: float, base: numpy.ndarray[tuple[int, ...], numpy.dtype[typing.Any]], axes: matplotlib.axes._axes.Axes, **anno_kwargs) -> None:
 61def annotate_bars(
 62    series: Series,
 63    offset: float,
 64    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
 65    axes: Axes,
 66    **anno_kwargs,
 67) -> None:
 68    """Bar plot annotations.
 69
 70    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 71    """
 72
 73    # --- only annotate in limited circumstances
 74    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 75        return
 76    max_annotations = 30
 77    if len(series) > max_annotations:
 78        return
 79
 80    # --- internal logic check
 81    if len(base) != len(series):
 82        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 83        return
 84
 85    # --- assemble the annotation parameters
 86    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
 87    annotate_style = {
 88        "fontsize": anno_kwargs.get("fontsize"),
 89        "fontname": anno_kwargs.get("fontname"),
 90        "color": anno_kwargs.get("color"),
 91        "rotation": anno_kwargs.get("rotation"),
 92    }
 93    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding", None))
 94    adjustment = (series.max() - series.min()) * 0.02
 95    zero_correction = series.index.min()
 96
 97    # --- annotate each bar
 98    for index, value in zip(series.index.astype(int), series):  # mypy syntactic sugar
 99        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
100        if above:
101            position += value
102        text = axes.text(
103            x=index + offset,
104            y=position,
105            s=f"{value:.{rounding}f}",
106            ha="center",
107            va="bottom" if value >= 0 else "top",
108            **annotate_style,
109        )
110        if not above and "foreground" in anno_kwargs:
111            # apply a stroke-effect to within bar annotations
112            # to make them more readable with very small bars.
113            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])

Bar plot annotations.

Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.

def grouped(axes, df: pandas.core.frame.DataFrame, anno_args, **kwargs) -> None:
def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a grouped bar plot: the columns of df are drawn side by side at
    each x position.

    Parameters
    - axes: the axes to draw on.
    - df: DataFrame - one column per series; offsets are added to the
      index to give each bar's x position.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series sequences "width", "color" and "label_series",
      indexed by column position.

    A column that is entirely missing is not drawn, but keeps its slot in
    the group. A width outside [0, 1] is replaced by 0.8. The group spans
    that width, centred on the x position, and is shared equally by all
    columns.
    """
    column_total = len(df.columns)

    for slot, name in enumerate(df.columns):
        values = df[name]
        if values.isnull().all():
            continue

        group_width = kwargs["width"][slot]
        if group_width < 0 or group_width > 1:
            group_width = 0.8
        bar_width = group_width / column_total

        # far-left + margin + half a bar = centre of the first bar in the group
        first_centre = -0.5 + ((1 - group_width) / 2.0) + (bar_width / 2.0)
        shift = first_centre + (slot * bar_width)

        bar_colour = kwargs["color"][slot]
        axes.bar(
            x=values.index + shift,
            height=values,
            color=bar_colour,
            width=bar_width,
            label=name if kwargs["label_series"][slot] else f"_{name}_",
        )
        annotate_bars(
            series=values,
            offset=shift,
            base=np.zeros(len(values)),
            axes=axes,
            foreground=bar_colour,
            **anno_args,
        )

Plot a grouped bar plot.

def stacked(axes, df: pandas.core.frame.DataFrame, anno_args, **kwargs) -> None:
def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a stacked bar plot: each column of df is drawn on top of the
    columns before it. Positive and negative values are stacked
    separately, upward and downward from zero.

    Parameters
    - axes: the axes to draw on.
    - df: DataFrame - one column per series; the index gives the x positions.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series sequences "width", "color" and "label_series",
      indexed by column position.
    """
    bar_count = len(df)  # number of rows, i.e. bars per series
    pos_top = np.zeros(shape=bar_count, dtype=np.float64)
    neg_bottom = np.zeros(shape=bar_count, dtype=np.float64)

    for slot, name in enumerate(df.columns):
        values = df[name]
        # each bar starts where the stack of the same sign currently ends
        bottoms = np.where(values >= 0, pos_top, neg_bottom)
        bar_colour = kwargs["color"][slot]
        axes.bar(
            x=values.index,
            height=values,
            bottom=bottoms,
            color=bar_colour,
            width=kwargs["width"][slot],
            label=name if kwargs["label_series"][slot] else f"_{name}_",
        )
        annotate_bars(
            series=values,
            offset=0,
            base=bottoms,
            axes=axes,
            foreground=bar_colour,
            **anno_args,
        )
        # grow both stacks; NaN compares False both ways and adds nothing
        pos_top += np.where(values >= 0, values, 0)
        neg_bottom += np.where(values < 0, values, 0)

Plot a stacked bar plot.

def bar_plot( data: ~DataT, **kwargs: Unpack[BarKwargs]) -> matplotlib.axes._axes.Axes:
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """
    Create a bar plot from the given data.

    By default the columns of a DataFrame are drawn as grouped bars, side
    by side. With stacked=True the columns are stacked on top of each
    other, with positive values above zero and negative values below zero.

    Parameters
    - data: DataFrame | Series - The data to plot. A Series is converted
      to a single-column DataFrame.
    - **kwargs: BarKwargs - Additional keyword arguments for customization.
      (see BarKwargs for details)

    Note: This function does not assume the data is a timeseries with a
    PeriodIndex. A warning is printed when is_categorical rejects the
    index. When map_periodindex maps the index, the x-axis labels are set
    with set_labels; otherwise the x tick labels are rotated 90 degrees.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print("Warning: bar_plot is not designed for incomplete or non-categorical data indexes.")
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- options for the chart as a whole
    chart_defaults: dict[str, Any] = {
        "stacked": False,
        "max_ticks": 10,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    # --- options for each series of bars (resolved by apply_defaults)
    bar_defaults: dict[str, Any] = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": (item_count > 1),
    }

    # --- options for the bar annotations
    above = kwargs_d.get("above", False)
    anno_args = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        # NOTE(review): True is handed to default_rounding although BarKwargs
        # types rounding as int; confirm default_rounding treats it as "auto".
        "rounding": kwargs_d.get("rounding", True),
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**remaining_kwargs)
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # Rotate the labels on *these* axes. plt.xticks acts on pyplot's
        # "current" axes, which need not be the axes drawn on when the
        # caller supplies ax=...; tick_params also covers ticks created later.
        axes.tick_params(axis="x", labelrotation=90)

    return axes

Create a bar plot from the given data. By default the columns of a DataFrame are drawn as grouped bars, side by side; with stacked=True, the columns are stacked on top of each other, with positive values above zero and negative values below zero.

Parameters

  • data: DataFrame | Series - The data to plot. A Series is treated as a single-column DataFrame.
  • **kwargs: BarKwargs - Additional keyword arguments for customization. (see BarKwargs for details)

Note: This function does not assume the data is a timeseries with a PeriodIndex.

Returns

  • axes: Axes - The axes for the plot.