mgplot.bar_plot

bar_plot.py This module contains functions to create bar plots using Matplotlib. Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.

  1"""
  2bar_plot.py
  3This module contains functions to create bar plots using Matplotlib.
  4Note: bar plots in Matplotlib are not the same as bar charts in other
  5libraries. Bar plots are used to represent categorical data with
  6rectangular bars. As a result, bar plots and line plots typically
  7cannot be plotted on the same axes.
  8"""
  9
 10# --- imports
 11from typing import Any, Final
 12from collections.abc import Sequence
 13
 14import numpy as np
 15from pandas import Series, DataFrame, PeriodIndex
 16import matplotlib.pyplot as plt
 17from matplotlib.pyplot import Axes
 18import matplotlib.patheffects as pe
 19
 20
 21from mgplot.settings import DataT, get_setting
 22from mgplot.utilities import (
 23    apply_defaults,
 24    get_color_list,
 25    get_axes,
 26    constrain_data,
 27    default_rounding,
 28)
 29from mgplot.kw_type_checking import (
 30    ExpectedTypeDict,
 31    validate_expected,
 32    report_kwargs,
 33    validate_kwargs,
 34)
 35from mgplot.axis_utils import set_labels, map_periodindex, is_categorical
 36from mgplot.keyword_names import (
 37    AX,
 38    STACKED,
 39    ROTATION,
 40    MAX_TICKS,
 41    PLOT_FROM,
 42    COLOR,
 43    LABEL_SERIES,
 44    WIDTH,
 45    ANNOTATE,
 46    FONTSIZE,
 47    FONTNAME,
 48    ROUNDING,
 49    ANNOTATE_COLOR,
 50    ABOVE,
 51)
 52
 53
 54# --- constants
 55
# Keyword arguments accepted by bar_plot(), mapped to the type(s) each may take.
# bar_plot() passes this table to validate_kwargs() for every call.
BAR_KW_TYPES: Final[ExpectedTypeDict] = {
    # --- options for the entire bar plot
    AX: (Axes, type(None)),  # axes to plot on, or None for new axes
    STACKED: bool,  # if True, the bars will be stacked. If False, they will be grouped.
    MAX_TICKS: int,  # maximum number of x-axis ticks (used for PeriodIndex data)
    PLOT_FROM: (int, PeriodIndex, type(None)),  # where the plot should start from
    # --- options for each bar ...
    # NOTE(review): the nested tuples below presumably describe the element
    # type of the Sequence (a str, or a Sequence of str) - confirm against
    # the ExpectedTypeDict conventions in mgplot.kw_type_checking.
    COLOR: (str, Sequence, (str,)),
    LABEL_SERIES: (bool, Sequence, (bool,)),
    WIDTH: (float, int),
    # - options for bar annotations
    ANNOTATE: (type(None), bool),  # None, True
    FONTSIZE: (int, float, str),
    FONTNAME: (str),  # the parentheses do not make a tuple; this is plain str
    ROUNDING: int,  # number of decimal places shown in annotations
    ROTATION: (int, float),  # rotation of annotations in degrees
    ANNOTATE_COLOR: (str, type(None)),  # color of annotations
    ABOVE: bool,  # if True, annotations are above the bar
}
# check the table itself at import time
validate_expected(BAR_KW_TYPES, "bar_plot")
 76
 77
 78# --- functions
 79def annotate_bars(
 80    series: Series,
 81    offset: float,
 82    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
 83    axes: Axes,
 84    **anno_kwargs,
 85) -> None:
 86    """Bar plot annotations.
 87    
 88    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 89    """
 90
 91    # --- only annotate in limited circumstances
 92    if ANNOTATE not in anno_kwargs or not anno_kwargs[ANNOTATE]:
 93        return
 94    max_annotations = 30
 95    if len(series) > max_annotations:
 96        return
 97
 98    # --- internal logic check
 99    if len(base) != len(series):
100        print(
101            f"Warning: base array length {len(base)} does not match series length {len(series)}."
102        )
103        return
104
105    # --- assemble the annotation parameters
106    above: Final[bool | None] = anno_kwargs.get(ABOVE, False)  # None is also False-ish
107    annotate_style = {
108        FONTSIZE: anno_kwargs.get(FONTSIZE),
109        FONTNAME: anno_kwargs.get(FONTNAME),
110        COLOR: anno_kwargs.get(COLOR),
111        ROTATION: anno_kwargs.get(ROTATION),
112    }
113    rounding = default_rounding(series=series, provided=anno_kwargs.get(ROUNDING, None))
114    adjustment = (series.max() - series.min()) * 0.02
115    zero_correction = series.index.min()
116
117    # --- annotate each bar
118    for index, value in zip(series.index.astype(int), series):  # mypy syntactic sugar
119        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
120        if above:
121            position += value
122        text = axes.text(
123            x=index + offset,
124            y=position,
125            s=f"{value:.{rounding}f}",
126            ha="center",
127            va="bottom" if value >= 0 else "top",
128            **annotate_style,
129        )
130        if not above and "foreground" in anno_kwargs:
131            # apply a stroke-effect to within bar annotations
132            # to make them more readable with very small bars.
133            text.set_path_effects(
134                [pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))]
135            )
136
137
def grouped(axes: Axes, df: DataFrame, anno_args: dict[str, Any], **kwargs) -> None:
    """
    plot a grouped bar plot

    Parameters
    - axes: Axes - the axes to draw the bars on.
    - df: DataFrame - one column per series; the index values are used
      as x-coordinates, so the index must support adding a float.
    - anno_args: dict - annotation settings, passed on to annotate_bars().
    - **kwargs - per-series sequences, indexed by column position, under
      the WIDTH, COLOR and LABEL_SERIES keywords.

    Each x position owns a slot one unit wide, centred on the index
    value. The series share "width" of that slot equally, with the
    remainder split into equal margins on either side. Columns that are
    entirely null are not drawn, but still keep their place in the group.
    """

    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isnull().all():
            continue
        # total width of the whole group; out-of-range values fall back to 0.8
        width = kwargs[WIDTH][i]
        if width < 0 or width > 1:
            width = 0.8
        adjusted_width = width / series_count  # width of one bar in the group
        # far-left + margin + halfway through one grouped column
        left = -0.5 + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # the "_" prefix is meant to keep the series out of the legend
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),  # grouped bars all start at zero
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
172
173
def stacked(axes: Axes, df: DataFrame, anno_args: dict[str, Any], **kwargs) -> None:
    """
    plot a stacked bar plot

    Parameters
    - axes: Axes - the axes to draw the bars on.
    - df: DataFrame - one column per series; the index values are used
      as x-coordinates.
    - anno_args: dict - annotation settings, passed on to annotate_bars().
    - **kwargs - per-series sequences, indexed by column position, under
      the WIDTH, COLOR and LABEL_SERIES keywords.

    Two running totals are kept for every row: positive values stack
    upwards from zero and negative values stack downwards from zero.
    """

    row_count = len(df)
    base_plus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    base_minus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts at the top of the stack on its own side of zero
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs[WIDTH][i],
            # the "_" prefix is meant to keep the series out of the legend
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
        # grow the stacks; NaN fails both comparisons and so adds nothing
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
208
209
def bar_plot(
    data: DataT,
    **kwargs,
) -> Axes:
    """
    Create a bar plot from the given data. Each column in the DataFrame
    is plotted as a separate series. By default the series are grouped
    side by side; when stacked is True they are stacked on top of each
    other, with positive values above zero and negative values below zero.

    Parameters
    - data: DataT - The data to plot. Can be a DataFrame or a Series.
    - **kwargs: dict Additional keyword arguments for customization.
    # --- options for the entire bar plot
    ax: Axes - axes to plot on, or None for new axes
    stacked: bool - if True, the bars will be stacked. If False, they will be grouped.
    max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only)
    plot_from: int | PeriodIndex - if provided, the plot will start from this index.
    # --- options for each bar ...
    color: str | list[str] - the color of the bars (or separate colors for each series)
    label_series: bool | list[bool] - if True, the series will be labeled in the legend
    width: float | list[float] - the width of the bars
    # - options for bar annotations
    annotate: bool - If True then annotate the bars with their values.
    fontsize: int | float | str - font size of the annotations
    fontname: str - font name of the annotations
    rounding: int - number of decimal places to round to
    annotate_color: str  - color of annotations
    rotation: int | float - rotation of annotations in degrees
    above: bool - if True, annotations are above the bar, else within the bar

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    me = "bar_plot"
    report_kwargs(called_from=me, **kwargs)
    kwargs = validate_kwargs(BAR_KW_TYPES, me, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print(
            "Warning: bar_plot is not designed for incomplete or non-categorical data indexes."
        )
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        STACKED: False,
        MAX_TICKS: 10,
    }
    chart_args = {k: kwargs.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults: dict[str, Any] = {
        COLOR: get_color_list(item_count),
        WIDTH: get_setting("bar_width"),
        LABEL_SERIES: (item_count > 1),
    }
    above = kwargs.get(ABOVE, False)
    anno_args = {
        ANNOTATE: kwargs.get(ANNOTATE, False),
        FONTSIZE: kwargs.get(FONTSIZE, "small"),
        FONTNAME: kwargs.get(FONTNAME, "Helvetica"),
        ROTATION: kwargs.get(ROTATION, 0),
        # NOTE(review): the default is True although rounding is declared
        # as int and annotate_bars() itself falls back to None - confirm
        # that default_rounding() treats True as intended.
        ROUNDING: kwargs.get(ROUNDING, True),
        # text inside a bar defaults to white, text above a bar to black
        COLOR: kwargs.get(ANNOTATE_COLOR, "black" if above else "white"),
        ABOVE: above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs)

    # --- plot the data
    axes, _rkwargs = get_axes(**remaining_kwargs)
    if chart_args[STACKED]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args[MAX_TICKS])
    else:
        # rotate the labels on the axes we drew on; pyplot's "current"
        # axes may be a different one when the caller supplied ax.
        axes.tick_params(axis="x", labelrotation=90)

    return axes
BAR_KW_TYPES: Final[ExpectedTypeDict] = {'ax': (<class 'matplotlib.axes._axes.Axes'>, <class 'NoneType'>), 'stacked': <class 'bool'>, 'max_ticks': <class 'int'>, 'plot_from': (<class 'int'>, <class 'pandas.core.indexes.period.PeriodIndex'>, <class 'NoneType'>), 'color': (<class 'str'>, <class 'collections.abc.Sequence'>, (<class 'str'>,)), 'label_series': (<class 'bool'>, <class 'collections.abc.Sequence'>, (<class 'bool'>,)), 'width': (<class 'float'>, <class 'int'>), 'annotate': (<class 'NoneType'>, <class 'bool'>), 'fontsize': (<class 'int'>, <class 'float'>, <class 'str'>), 'fontname': <class 'str'>, 'rounding': <class 'int'>, 'rotation': (<class 'int'>, <class 'float'>), 'annotate_color': (<class 'str'>, <class 'NoneType'>), 'above': <class 'bool'>}
def annotate_bars( series: pandas.core.series.Series, offset: float, base: numpy.ndarray[tuple[int, ...], numpy.dtype[typing.Any]], axes: matplotlib.axes._axes.Axes, **anno_kwargs) -> None:
 80def annotate_bars(
 81    series: Series,
 82    offset: float,
 83    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
 84    axes: Axes,
 85    **anno_kwargs,
 86) -> None:
 87    """Bar plot annotations.
 88    
 89    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 90    """
 91
 92    # --- only annotate in limited circumstances
 93    if ANNOTATE not in anno_kwargs or not anno_kwargs[ANNOTATE]:
 94        return
 95    max_annotations = 30
 96    if len(series) > max_annotations:
 97        return
 98
 99    # --- internal logic check
100    if len(base) != len(series):
101        print(
102            f"Warning: base array length {len(base)} does not match series length {len(series)}."
103        )
104        return
105
106    # --- assemble the annotation parameters
107    above: Final[bool | None] = anno_kwargs.get(ABOVE, False)  # None is also False-ish
108    annotate_style = {
109        FONTSIZE: anno_kwargs.get(FONTSIZE),
110        FONTNAME: anno_kwargs.get(FONTNAME),
111        COLOR: anno_kwargs.get(COLOR),
112        ROTATION: anno_kwargs.get(ROTATION),
113    }
114    rounding = default_rounding(series=series, provided=anno_kwargs.get(ROUNDING, None))
115    adjustment = (series.max() - series.min()) * 0.02
116    zero_correction = series.index.min()
117
118    # --- annotate each bar
119    for index, value in zip(series.index.astype(int), series):  # mypy syntactic sugar
120        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
121        if above:
122            position += value
123        text = axes.text(
124            x=index + offset,
125            y=position,
126            s=f"{value:.{rounding}f}",
127            ha="center",
128            va="bottom" if value >= 0 else "top",
129            **annotate_style,
130        )
131        if not above and "foreground" in anno_kwargs:
132            # apply a stroke-effect to within bar annotations
133            # to make them more readable with very small bars.
134            text.set_path_effects(
135                [pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))]
136            )

Bar plot annotations.

Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.

def grouped(axes, df: pandas.core.frame.DataFrame, anno_args, **kwargs) -> None:
def grouped(axes: Axes, df: DataFrame, anno_args: dict[str, Any], **kwargs) -> None:
    """
    plot a grouped bar plot

    Parameters
    - axes: Axes - the axes to draw the bars on.
    - df: DataFrame - one column per series; the index values are used
      as x-coordinates, so the index must support adding a float.
    - anno_args: dict - annotation settings, passed on to annotate_bars().
    - **kwargs - per-series sequences, indexed by column position, under
      the WIDTH, COLOR and LABEL_SERIES keywords.

    Each x position owns a slot one unit wide, centred on the index
    value. The series share "width" of that slot equally, with the
    remainder split into equal margins on either side. Columns that are
    entirely null are not drawn, but still keep their place in the group.
    """

    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isnull().all():
            continue
        # total width of the whole group; out-of-range values fall back to 0.8
        width = kwargs[WIDTH][i]
        if width < 0 or width > 1:
            width = 0.8
        adjusted_width = width / series_count  # width of one bar in the group
        # far-left + margin + halfway through one grouped column
        left = -0.5 + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # the "_" prefix is meant to keep the series out of the legend
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),  # grouped bars all start at zero
            axes=axes,
            foreground=foreground,
            **anno_args,
        )

plot a grouped bar plot

def stacked(axes, df: pandas.core.frame.DataFrame, anno_args, **kwargs) -> None:
def stacked(axes: Axes, df: DataFrame, anno_args: dict[str, Any], **kwargs) -> None:
    """
    plot a stacked bar plot

    Parameters
    - axes: Axes - the axes to draw the bars on.
    - df: DataFrame - one column per series; the index values are used
      as x-coordinates.
    - anno_args: dict - annotation settings, passed on to annotate_bars().
    - **kwargs - per-series sequences, indexed by column position, under
      the WIDTH, COLOR and LABEL_SERIES keywords.

    Two running totals are kept for every row: positive values stack
    upwards from zero and negative values stack downwards from zero.
    """

    row_count = len(df)
    base_plus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    base_minus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts at the top of the stack on its own side of zero
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs[WIDTH][i],
            # the "_" prefix is meant to keep the series out of the legend
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
        # grow the stacks; NaN fails both comparisons and so adds nothing
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)

plot a stacked bar plot

def bar_plot(data: ~DataT, **kwargs) -> matplotlib.axes._axes.Axes:
def bar_plot(
    data: DataT,
    **kwargs,
) -> Axes:
    """
    Create a bar plot from the given data. Each column in the DataFrame
    is plotted as a separate series. By default the series are grouped
    side by side; when stacked is True they are stacked on top of each
    other, with positive values above zero and negative values below zero.

    Parameters
    - data: DataT - The data to plot. Can be a DataFrame or a Series.
    - **kwargs: dict Additional keyword arguments for customization.
    # --- options for the entire bar plot
    ax: Axes - axes to plot on, or None for new axes
    stacked: bool - if True, the bars will be stacked. If False, they will be grouped.
    max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only)
    plot_from: int | PeriodIndex - if provided, the plot will start from this index.
    # --- options for each bar ...
    color: str | list[str] - the color of the bars (or separate colors for each series)
    label_series: bool | list[bool] - if True, the series will be labeled in the legend
    width: float | list[float] - the width of the bars
    # - options for bar annotations
    annotate: bool - If True then annotate the bars with their values.
    fontsize: int | float | str - font size of the annotations
    fontname: str - font name of the annotations
    rounding: int - number of decimal places to round to
    annotate_color: str  - color of annotations
    rotation: int | float - rotation of annotations in degrees
    above: bool - if True, annotations are above the bar, else within the bar

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    me = "bar_plot"
    report_kwargs(called_from=me, **kwargs)
    kwargs = validate_kwargs(BAR_KW_TYPES, me, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print(
            "Warning: bar_plot is not designed for incomplete or non-categorical data indexes."
        )
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        STACKED: False,
        MAX_TICKS: 10,
    }
    chart_args = {k: kwargs.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults: dict[str, Any] = {
        COLOR: get_color_list(item_count),
        WIDTH: get_setting("bar_width"),
        LABEL_SERIES: (item_count > 1),
    }
    above = kwargs.get(ABOVE, False)
    anno_args = {
        ANNOTATE: kwargs.get(ANNOTATE, False),
        FONTSIZE: kwargs.get(FONTSIZE, "small"),
        FONTNAME: kwargs.get(FONTNAME, "Helvetica"),
        ROTATION: kwargs.get(ROTATION, 0),
        # NOTE(review): the default is True although rounding is declared
        # as int and annotate_bars() itself falls back to None - confirm
        # that default_rounding() treats True as intended.
        ROUNDING: kwargs.get(ROUNDING, True),
        # text inside a bar defaults to white, text above a bar to black
        COLOR: kwargs.get(ANNOTATE_COLOR, "black" if above else "white"),
        ABOVE: above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs)

    # --- plot the data
    axes, _rkwargs = get_axes(**remaining_kwargs)
    if chart_args[STACKED]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args[MAX_TICKS])
    else:
        # rotate the labels on the axes we drew on; pyplot's "current"
        # axes may be a different one when the caller supplied ax.
        axes.tick_params(axis="x", labelrotation=90)

    return axes

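Create a bar plot from the given data. Each column in the DataFrame is plotted as a separate series. By default the series are grouped side by side; when stacked is True they are stacked on top of each other, with positive values above zero and negative values below zero.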

Parameters

  • data: Series - The data to plot. Can be a DataFrame or a Series.
  • **kwargs: dict Additional keyword arguments for customization.

--- options for the entire bar plot

ax: Axes - axes to plot on, or None for new axes stacked: bool - if True, the bars will be stacked. If False, they will be grouped. max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only) plot_from: int | PeriodIndex - if provided, the plot will start from this index.

--- options for each bar ...

color: str | list[str] - the color of the bars (or separate colors for each series) label_series: bool | list[bool] - if True, the series will be labeled in the legend width: float | list[float] - the width of the bars

- options for bar annotations

annotate: bool - If True then annotate the bars with their values. fontsize: int | float | str - font size of the annotations fontname: str - font name of the annotations rounding: int - number of decimal places to round to annotate_color: str - color of annotations rotation: int | float - rotation of annotations in degrees above: bool - if True, annotations are above the bar, else within the bar

Note: This function does not assume all data is timeseries with a PeriodIndex.

Returns

  • axes: Axes - The axes for the plot.