mgplot.bar_plot

Create bar plots using Matplotlib.

Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.

  1"""Create bar plots using Matplotlib.
  2
  3Note: bar plots in Matplotlib are not the same as bar charts in other
  4libraries. Bar plots are used to represent categorical data with
  5rectangular bars. As a result, bar plots and line plots typically
  6cannot be plotted on the same axes.
  7"""
  8
  9from collections.abc import Sequence
 10from typing import Any, Final, NotRequired, TypedDict, Unpack
 11
 12import matplotlib.patheffects as pe
 13import matplotlib.pyplot as plt
 14import numpy as np
 15from matplotlib.axes import Axes
 16from pandas import DataFrame, Period, Series
 17
 18from mgplot.axis_utils import map_periodindex, set_labels
 19from mgplot.keyword_checking import BaseKwargs, report_kwargs, validate_kwargs
 20from mgplot.settings import DataT, get_setting
 21from mgplot.utilities import (
 22    apply_defaults,
 23    constrain_data,
 24    default_rounding,
 25    get_axes,
 26    get_color_list,
 27)
 28
 29# --- constants
# name passed as ``caller`` to report_kwargs() and validate_kwargs()
ME: Final[str] = "bar_plot"
# annotate_bars() draws nothing when a series has more rows than this
MAX_ANNOTATIONS: Final[int] = 30
# fraction of the series' (max - min) range used to nudge annotation text
ADJUSTMENT_FACTOR: Final[float] = 0.02
# grouped() replaces any width outside [MIN_BAR_WIDTH, MAX_BAR_WIDTH]
# with DEFAULT_GROUPED_WIDTH
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8
# distance from an x position to the far-left edge of its group in grouped()
DEFAULT_BAR_OFFSET: Final[float] = 0.5
# default for the "max_ticks" keyword argument of bar_plot()
DEFAULT_MAX_TICKS: Final[int] = 10
 38
 39
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional. The defaults noted below are the ones applied
    inside bar_plot() when the key is absent.
    """

    # --- options for the entire bar plot
    # NOTE(review): "ax" is not read directly in bar_plot(); presumably it is
    # consumed by get_axes() via the remaining kwargs - confirm.
    ax: NotRequired[Axes | None]
    # True -> stacked(); otherwise grouped(). Default: False.
    stacked: NotRequired[bool]
    # handed to set_labels() when the data had a PeriodIndex.
    # Default: DEFAULT_MAX_TICKS.
    max_ticks: NotRequired[int]
    # NOTE(review): not read directly in bar_plot(); presumably handled by
    # constrain_data() - confirm.
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    # Default: get_color_list(<number of columns>).
    color: NotRequired[str | Sequence[str]]
    # when False for a column, its label becomes f"_{col}_".
    # Default: True only when there is more than one column.
    label_series: NotRequired[bool | Sequence[bool]]
    # Default: get_setting("bar_width").
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    annotate: NotRequired[bool]  # Default: False
    fontsize: NotRequired[int | float | str]  # Default: "small"
    fontname: NotRequired[str]  # Default: "Helvetica"
    rounding: NotRequired[int]  # Default: True (passed on to default_rounding)
    rotation: NotRequired[int | float]  # Default: 0
    annotate_color: NotRequired[str]  # Default: "black" if above else "white"
    above: NotRequired[bool]  # Default: False
 60
 61
 62# --- functions
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars.

    All keys are optional (total=False). annotate_bars() draws nothing
    unless "annotate" is present and truthy.
    """

    annotate: bool  # switch: annotations are only drawn when this is truthy
    fontsize: int | float | str  # forwarded to Axes.text as fontsize
    fontname: str  # forwarded to Axes.text as fontname
    color: str  # forwarded to Axes.text as the text color
    rotation: int | float  # forwarded to Axes.text as rotation
    foreground: str  # used for stroke effect on text
    above: bool  # if truthy, the bar's value is added to the text's y position
    rounding: bool | int  # if True, uses default rounding; if int, uses that value
 74
 75
 76def annotate_bars(
 77    series: Series,
 78    offset: float,
 79    base: np.ndarray,
 80    axes: Axes,
 81    **anno_kwargs: Unpack[AnnoKwargs],
 82) -> None:
 83    """Bar plot annotations.
 84
 85    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 86    """
 87    # --- only annotate in limited circumstances
 88    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 89        return
 90    max_annotations = MAX_ANNOTATIONS
 91    if len(series) > max_annotations:
 92        return
 93
 94    # --- internal logic check
 95    if len(base) != len(series):
 96        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 97        return
 98
 99    # --- assemble the annotation parameters
100    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
101    annotate_style: dict[str, Any] = {
102        "fontsize": anno_kwargs.get("fontsize"),
103        "fontname": anno_kwargs.get("fontname"),
104        "color": anno_kwargs.get("color"),
105        "rotation": anno_kwargs.get("rotation"),
106    }
107    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
108    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
109    zero_correction = series.index.min()
110
111    # --- annotate each bar
112    for index, value in zip(series.index.astype(int), series, strict=True):
113        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
114        if above:
115            position += value
116        text = axes.text(
117            x=index + offset,
118            y=position,
119            s=f"{value:.{rounding}f}",
120            ha="center",
121            va="bottom" if value >= 0 else "top",
122            **annotate_style,
123        )
124        if not above and "foreground" in anno_kwargs:
125            # apply a stroke-effect to within bar annotations
126            # to make them more readable with very small bars.
127            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
128
129
class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped.

    Each value is a sequence that grouped() indexes by column position,
    so it needs one entry per DataFrame column.
    """

    color: Sequence[str]  # bar color for each column
    width: Sequence[float | int]  # total group width requested for each column
    label_series: Sequence[bool]  # False -> the column is labelled f"_{col}_"
136
137
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of df is drawn as its own set of bars, placed side by side
    within the unit-wide slot around every x position.

    Args:
        axes: the Axes to draw on.
        df: the data; one set of bars per column. Columns that are entirely
            NaN are skipped, but still count towards the number of slots.
        anno_args: annotation settings passed on to annotate_bars(). This
            dict is not modified.
        **kwargs: per-column "color", "width" and "label_series" sequences
            (see GroupedKwargs).
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing "foreground" into the
        # caller's dict
        column_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **column_anno,
        )
169
170
class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked.

    Each value is a sequence that stacked() indexes by column position,
    so it needs one entry per DataFrame column.
    """

    color: Sequence[str]  # bar color for each column
    width: Sequence[float | int]  # bar width for each column, passed to Axes.bar
    label_series: Sequence[bool]  # False -> the column is labelled f"_{col}_"
177
178
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in order. Two running totals are kept per row: one
    for values >= 0 (stacking upwards from zero) and one for values < 0
    (stacking downwards from zero).

    Args:
        axes: the Axes to draw on.
        df: the data; one layer of bars per column.
        anno_args: annotation settings passed on to annotate_bars(). This
            dict is not modified.
        **kwargs: per-column "color", "width" and "label_series" sequences
            (see StackedKwargs).
    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts from the running total that matches its sign
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing "foreground" into the
        # caller's dict
        column_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **column_anno,
        )
        # NaN fails both comparisons, so it adds nothing to either total
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
206
207
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns are drawn side by side as grouped bars. With
    stacked=True each column is stacked on top of the previous ones,
    with positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
            A Series is plotted as a single-column DataFrame.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
        "label_series": item_count > 1,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above a bar sits on the plot background, text within a bar
        # sits on the bar colour, hence the different default colours
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # rotate the labels on the axes we drew on; plt.xticks() would act on
        # pyplot's current axes, which need not be this one
        axes.tick_params(axis="x", labelrotation=90)

    return axes
# name passed as ``caller`` to report_kwargs() and validate_kwargs()
ME: Final[str] = 'bar_plot'
# annotate_bars() draws nothing when a series has more rows than this
MAX_ANNOTATIONS: Final[int] = 30
# fraction of the series' (max - min) range used to nudge annotation text
ADJUSTMENT_FACTOR: Final[float] = 0.02
# grouped() replaces any width outside [MIN_BAR_WIDTH, MAX_BAR_WIDTH]
# with DEFAULT_GROUPED_WIDTH
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8
# distance from an x position to the far-left edge of its group in grouped()
DEFAULT_BAR_OFFSET: Final[float] = 0.5
# default for the "max_ticks" keyword argument of bar_plot()
DEFAULT_MAX_TICKS: Final[int] = 10
class BarKwargs(mgplot.keyword_checking.BaseKwargs):
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional. The defaults noted below are the ones applied
    inside bar_plot() when the key is absent.
    """

    # --- options for the entire bar plot
    # NOTE(review): "ax" is not read directly in bar_plot(); presumably it is
    # consumed by get_axes() via the remaining kwargs - confirm.
    ax: NotRequired[Axes | None]
    # True -> stacked(); otherwise grouped(). Default: False.
    stacked: NotRequired[bool]
    # handed to set_labels() when the data had a PeriodIndex.
    # Default: DEFAULT_MAX_TICKS.
    max_ticks: NotRequired[int]
    # NOTE(review): not read directly in bar_plot(); presumably handled by
    # constrain_data() - confirm.
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    # Default: get_color_list(<number of columns>).
    color: NotRequired[str | Sequence[str]]
    # when False for a column, its label becomes f"_{col}_".
    # Default: True only when there is more than one column.
    label_series: NotRequired[bool | Sequence[bool]]
    # Default: get_setting("bar_width").
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    annotate: NotRequired[bool]  # Default: False
    fontsize: NotRequired[int | float | str]  # Default: "small"
    fontname: NotRequired[str]  # Default: "Helvetica"
    rounding: NotRequired[int]  # Default: True (passed on to default_rounding)
    rotation: NotRequired[int | float]  # Default: 0
    annotate_color: NotRequired[str]  # Default: "black" if above else "white"
    above: NotRequired[bool]  # Default: False

Keyword arguments for the bar_plot function.

ax: NotRequired[matplotlib.axes._axes.Axes | None]
stacked: NotRequired[bool]
max_ticks: NotRequired[int]
plot_from: NotRequired[int | pandas._libs.tslibs.period.Period]
color: NotRequired[str | Sequence[str]]
label_series: NotRequired[bool | Sequence[bool]]
width: NotRequired[float | int | Sequence[float | int]]
annotate: NotRequired[bool]
fontsize: NotRequired[int | float | str]
fontname: NotRequired[str]
rounding: NotRequired[int]
rotation: NotRequired[int | float]
annotate_color: NotRequired[str]
above: NotRequired[bool]
class AnnoKwargs(typing.TypedDict):
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars.

    All keys are optional (total=False). annotate_bars() draws nothing
    unless "annotate" is present and truthy.
    """

    annotate: bool  # switch: annotations are only drawn when this is truthy
    fontsize: int | float | str  # forwarded to Axes.text as fontsize
    fontname: str  # forwarded to Axes.text as fontname
    color: str  # forwarded to Axes.text as the text color
    rotation: int | float  # forwarded to Axes.text as rotation
    foreground: str  # used for stroke effect on text
    above: bool  # if truthy, the bar's value is added to the text's y position
    rounding: bool | int  # if True, uses default rounding; if int, uses that value

TypedDict for the kwargs used in annotate_bars.

annotate: bool
fontsize: int | float | str
fontname: str
color: str
rotation: int | float
foreground: str
above: bool
rounding: bool | int
def annotate_bars( series: pandas.core.series.Series, offset: float, base: numpy.ndarray, axes: matplotlib.axes._axes.Axes, **anno_kwargs: Unpack[AnnoKwargs]) -> None:
 77def annotate_bars(
 78    series: Series,
 79    offset: float,
 80    base: np.ndarray,
 81    axes: Axes,
 82    **anno_kwargs: Unpack[AnnoKwargs],
 83) -> None:
 84    """Bar plot annotations.
 85
 86    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 87    """
 88    # --- only annotate in limited circumstances
 89    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 90        return
 91    max_annotations = MAX_ANNOTATIONS
 92    if len(series) > max_annotations:
 93        return
 94
 95    # --- internal logic check
 96    if len(base) != len(series):
 97        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 98        return
 99
100    # --- assemble the annotation parameters
101    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
102    annotate_style: dict[str, Any] = {
103        "fontsize": anno_kwargs.get("fontsize"),
104        "fontname": anno_kwargs.get("fontname"),
105        "color": anno_kwargs.get("color"),
106        "rotation": anno_kwargs.get("rotation"),
107    }
108    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
109    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
110    zero_correction = series.index.min()
111
112    # --- annotate each bar
113    for index, value in zip(series.index.astype(int), series, strict=True):
114        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
115        if above:
116            position += value
117        text = axes.text(
118            x=index + offset,
119            y=position,
120            s=f"{value:.{rounding}f}",
121            ha="center",
122            va="bottom" if value >= 0 else "top",
123            **annotate_style,
124        )
125        if not above and "foreground" in anno_kwargs:
126            # apply a stroke-effect to within bar annotations
127            # to make them more readable with very small bars.
128            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])

Bar plot annotations.

Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.

class GroupedKwargs(typing.TypedDict):
class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped.

    Each value is a sequence that grouped() indexes by column position,
    so it needs one entry per DataFrame column.
    """

    color: Sequence[str]  # bar color for each column
    width: Sequence[float | int]  # total group width requested for each column
    label_series: Sequence[bool]  # False -> the column is labelled f"_{col}_"

TypedDict for the kwargs used in grouped.

color: Sequence[str]
width: Sequence[float | int]
label_series: Sequence[bool]
def grouped( axes: matplotlib.axes._axes.Axes, df: pandas.core.frame.DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of df is drawn as its own set of bars, placed side by side
    within the unit-wide slot around every x position.

    Args:
        axes: the Axes to draw on.
        df: the data; one set of bars per column. Columns that are entirely
            NaN are skipped, but still count towards the number of slots.
        anno_args: annotation settings passed on to annotate_bars(). This
            dict is not modified.
        **kwargs: per-column "color", "width" and "label_series" sequences
            (see GroupedKwargs).
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing "foreground" into the
        # caller's dict
        column_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **column_anno,
        )

Plot a grouped bar plot.

class StackedKwargs(typing.TypedDict):
class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked.

    Each value is a sequence that stacked() indexes by column position,
    so it needs one entry per DataFrame column.
    """

    color: Sequence[str]  # bar color for each column
    width: Sequence[float | int]  # bar width for each column, passed to Axes.bar
    label_series: Sequence[bool]  # False -> the column is labelled f"_{col}_"

TypedDict for the kwargs used in stacked.

color: Sequence[str]
width: Sequence[float | int]
label_series: Sequence[bool]
def stacked( axes: matplotlib.axes._axes.Axes, df: pandas.core.frame.DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in order. Two running totals are kept per row: one
    for values >= 0 (stacking upwards from zero) and one for values < 0
    (stacking downwards from zero).

    Args:
        axes: the Axes to draw on.
        df: the data; one layer of bars per column.
        anno_args: annotation settings passed on to annotate_bars(). This
            dict is not modified.
        **kwargs: per-column "color", "width" and "label_series" sequences
            (see StackedKwargs).
    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts from the running total that matches its sign
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing "foreground" into the
        # caller's dict
        column_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **column_anno,
        )
        # NaN fails both comparisons, so it adds nothing to either total
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)

Plot a stacked bar plot.

def bar_plot( data: ~DataT, **kwargs: Unpack[BarKwargs]) -> matplotlib.axes._axes.Axes:
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns are drawn side by side as grouped bars. With
    stacked=True each column is stacked on top of the previous ones,
    with positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
            A Series is plotted as a single-column DataFrame.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
        "label_series": item_count > 1,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above a bar sits on the plot background, text within a bar
        # sits on the bar colour, hence the different default colours
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # rotate the labels on the axes we drew on; plt.xticks() would act on
        # pyplot's current axes, which need not be this one
        axes.tick_params(axis="x", labelrotation=90)

    return axes

Create a bar plot from the given data.

By default, the columns of the DataFrame are drawn side by side as grouped bars. With stacked=True, each column is stacked on top of the previous ones, with positive values above zero and negative values below zero.

Args:
- data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
- **kwargs: BarKwargs - Additional keyword arguments for customization (see BarKwargs for details).

Note: This function does not assume all data is timeseries with a PeriodIndex.

Returns: axes: Axes - The axes for the plot.