mgplot.bar_plot
bar_plot.py This module contains functions to create bar plots using Matplotlib. Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.
"""
bar_plot.py
This module contains functions to create bar plots using Matplotlib.
Note: bar plots in Matplotlib are not the same as bar charts in other
libraries. Bar plots are used to represent categorical data with
rectangular bars. As a result, bar plots and line plots typically
cannot be plotted on the same axes.
"""

# --- imports
from typing import Any, Final
from collections.abc import Sequence

import numpy as np
from pandas import Series, DataFrame, PeriodIndex
import matplotlib.pyplot as plt
from matplotlib.pyplot import Axes
import matplotlib.patheffects as pe


from mgplot.settings import DataT, get_setting
from mgplot.utilities import (
    apply_defaults,
    get_color_list,
    get_axes,
    constrain_data,
    default_rounding,
)
from mgplot.kw_type_checking import (
    ExpectedTypeDict,
    validate_expected,
    report_kwargs,
    validate_kwargs,
)
from mgplot.axis_utils import set_labels, map_periodindex, is_categorical
from mgplot.keyword_names import (
    AX,
    STACKED,
    ROTATION,
    MAX_TICKS,
    PLOT_FROM,
    COLOR,
    LABEL_SERIES,
    WIDTH,
    ANNOTATE,
    FONTSIZE,
    FONTNAME,
    ROUNDING,
    ANNOTATE_COLOR,
    ABOVE,
)


# --- constants

BAR_KW_TYPES: Final[ExpectedTypeDict] = {
    # --- options for the entire bar plot
    AX: (Axes, type(None)),  # axes to plot on, or None for new axes
    STACKED: bool,  # if True, the bars will be stacked. If False, they will be grouped.
    MAX_TICKS: int,
    PLOT_FROM: (int, PeriodIndex, type(None)),
    # --- options for each bar ...
    COLOR: (str, Sequence, (str,)),
    LABEL_SERIES: (bool, Sequence, (bool,)),
    WIDTH: (float, int),
    # - options for bar annotations
    ANNOTATE: (type(None), bool),  # None, True
    FONTSIZE: (int, float, str),
    FONTNAME: str,
    ROUNDING: int,
    ROTATION: (int, float),  # rotation of annotations in degrees
    ANNOTATE_COLOR: (str, type(None)),  # color of annotations
    ABOVE: bool,  # if True, annotations are above the bar
}
validate_expected(BAR_KW_TYPES, "bar_plot")


# --- functions
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
    axes: Axes,
    **anno_kwargs,
) -> None:
    """Annotate the bars of one plotted series with their values.

    Parameters
    - series: Series - the plotted values; the index is converted with
      astype(int) and used as the bar x-positions.
    - offset: float - added to each x-position (the bar's offset within a group).
    - base: np.ndarray - the bottom of each bar, one entry per value in series.
    - axes: Axes - the axes on which the annotation text is drawn.
    - **anno_kwargs: "annotate", "fontsize", "fontname", "color" and "rotation"
      are expected; "rounding", "above" and "foreground" are optional.

    Nothing is drawn when "annotate" is missing or falsy, when the series has
    more than 30 values, or when base and series differ in length (a warning
    is printed in that last case). Missing (NaN) values are not annotated.
    """

    # --- only annotate in limited circumstances
    if not anno_kwargs.get(ANNOTATE):
        return
    max_annotations = 30
    if len(series) > max_annotations:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(
            f"Warning: base array length {len(base)} does not match series length {len(series)}."
        )
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get(ABOVE, False)  # None is also False-ish
    annotate_style = {
        FONTSIZE: anno_kwargs.get(FONTSIZE),
        FONTNAME: anno_kwargs.get(FONTNAME),
        COLOR: anno_kwargs.get(COLOR),
        ROTATION: anno_kwargs.get(ROTATION),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get(ROUNDING, None))
    # small vertical gap (2% of the data range) between the bar edge and its label
    adjustment = (series.max() - series.min()) * 0.02
    # base is indexed positionally, while the series index may not start at zero
    zero_correction = series.index.min()
    missing = series.isna().to_numpy()

    # --- annotate each bar
    for index, value, is_missing in zip(series.index.astype(int), series, missing):
        if is_missing:
            # nothing to show; formatting NaN would draw a literal "nan" label
            continue
        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects(
                [pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))]
            )


def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a grouped bar plot: each column's bars sit side by side within
    each index position.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - the data; each column is one series of bars.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series bar options; WIDTH, COLOR and LABEL_SERIES are
      indexed by column position.
    """

    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isnull().all():
            continue  # nothing to draw for an all-missing column
        width = kwargs[WIDTH][i]
        if width < 0 or width > 1:
            width = 0.8  # out-of-range widths fall back to 0.8
        # the group width is shared equally between the series
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -0.5 + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            foreground=foreground,
            **anno_args,
        )


def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a stacked bar plot: positive values are stacked upwards from zero
    and negative values downwards from zero.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - the data; each column is one layer of the stack.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series bar options; WIDTH, COLOR and LABEL_SERIES are
      indexed by column position.
    """

    row_count = len(df)
    # running tops of the positive stack and bottoms of the negative stack
    base_plus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    base_minus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    for i, col in enumerate(df.columns):
        series = df[col]
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs[WIDTH][i],
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
def bar_plot(
    data: DataT,
    **kwargs,
) -> Axes:
    """
    Create a bar plot from the given data.

    Each column of the data is plotted as a separate series of bars. By
    default the series are grouped (side by side); with stacked=True they
    are stacked on top of each other, with positive values above zero and
    negative values below zero.

    Parameters
    - data: DataFrame | Series - The data to plot. A Series is treated as a
      single-column DataFrame.
    - **kwargs: dict Additional keyword arguments for customization.
        # --- options for the entire bar plot
        ax: Axes - axes to plot on, or None for new axes
        stacked: bool - if True, the bars will be stacked. If False (default),
            they will be grouped.
        max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex
            only; default 10)
        plot_from: int | PeriodIndex - if provided, the plot will start from this index.
        # --- options for each bar ...
        color: str | list[str] - the color of the bars (or separate colors for each series)
        label_series: bool | list[bool] - if True, the series will be labeled in the legend
        width: float | list[float] - the width of the bars
        # - options for bar annotations
        annotate: bool - If True, then annotate the bars with their values (default False).
        fontsize: int | float | str - font size of the annotations (default "small")
        fontname: str - font name of the annotations (default "Helvetica")
        rounding: int - number of decimal places to round to
        annotate_color: str - color of annotations (default "black" when above,
            otherwise "white")
        rotation: int | float - rotation of annotations in degrees (default 0)
        above: bool - if True, annotations are above the bar, else within the bar

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    me = "bar_plot"
    report_kwargs(called_from=me, **kwargs)
    kwargs = validate_kwargs(BAR_KW_TYPES, me, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print(
            "Warning: bar_plot is not designed for incomplete or non-categorical data indexes."
        )
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        STACKED: False,
        MAX_TICKS: 10,
    }
    chart_args = {k: kwargs.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults: dict[str, Any] = {
        COLOR: get_color_list(item_count),
        WIDTH: get_setting("bar_width"),
        LABEL_SERIES: (item_count > 1),
    }
    above = kwargs.get(ABOVE, False)
    anno_args = {
        ANNOTATE: kwargs.get(ANNOTATE, False),
        FONTSIZE: kwargs.get(FONTSIZE, "small"),
        FONTNAME: kwargs.get(FONTNAME, "Helvetica"),
        ROTATION: kwargs.get(ROTATION, 0),
        # NOTE(review): default True is passed on to default_rounding();
        # confirm that is the intended "automatic rounding" sentinel.
        ROUNDING: kwargs.get(ROUNDING, True),
        # labels inside a (coloured) bar default to white, above it to black
        COLOR: kwargs.get(ANNOTATE_COLOR, "black" if above else "white"),
        ABOVE: above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs)

    # --- plot the data
    axes, _rkwargs = get_axes(**remaining_kwargs)
    if chart_args[STACKED]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete PeriodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args[MAX_TICKS])
    else:
        # rotate the labels of *this* axes: plt.xticks() would act on pyplot's
        # current axes, which need not be `axes` (e.g. when ax= is supplied).
        axes.tick_params(axis="x", labelrotation=90)

    return axes
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
    axes: Axes,
    **anno_kwargs,
) -> None:
    """Annotate the bars of one plotted series with their values.

    Parameters
    - series: Series - the plotted values; the index is converted with
      astype(int) and used as the bar x-positions.
    - offset: float - added to each x-position (the bar's offset within a group).
    - base: np.ndarray - the bottom of each bar, one entry per value in series.
    - axes: Axes - the axes on which the annotation text is drawn.
    - **anno_kwargs: "annotate", "fontsize", "fontname", "color" and "rotation"
      are expected; "rounding", "above" and "foreground" are optional.

    Nothing is drawn when "annotate" is missing or falsy, when the series has
    more than 30 values, or when base and series differ in length (a warning
    is printed in that last case). Missing (NaN) values are not annotated.
    """

    # --- only annotate in limited circumstances
    if not anno_kwargs.get(ANNOTATE):
        return
    max_annotations = 30
    if len(series) > max_annotations:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(
            f"Warning: base array length {len(base)} does not match series length {len(series)}."
        )
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get(ABOVE, False)  # None is also False-ish
    annotate_style = {
        FONTSIZE: anno_kwargs.get(FONTSIZE),
        FONTNAME: anno_kwargs.get(FONTNAME),
        COLOR: anno_kwargs.get(COLOR),
        ROTATION: anno_kwargs.get(ROTATION),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get(ROUNDING, None))
    # small vertical gap (2% of the data range) between the bar edge and its label
    adjustment = (series.max() - series.min()) * 0.02
    # base is indexed positionally, while the series index may not start at zero
    zero_correction = series.index.min()
    missing = series.isna().to_numpy()

    # --- annotate each bar
    for index, value, is_missing in zip(series.index.astype(int), series, missing):
        if is_missing:
            # nothing to show; formatting NaN would draw a literal "nan" label
            continue
        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects(
                [pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))]
            )
Annotate the bars of a bar plot with their values.
Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a grouped bar plot: each column's bars sit side by side within
    each index position.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - the data; each column is one series of bars.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series bar options; WIDTH, COLOR and LABEL_SERIES are
      indexed by column position.
    """

    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isnull().all():
            continue  # nothing to draw for an all-missing column
        width = kwargs[WIDTH][i]
        if width < 0 or width > 1:
            width = 0.8  # out-of-range widths fall back to 0.8
        # the group width is shared equally between the series
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -0.5 + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
Plot a grouped bar plot, with each column's bars placed side by side.
def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a stacked bar plot: positive values are stacked upwards from zero
    and negative values downwards from zero.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - the data; each column is one layer of the stack.
    - anno_args: dict - annotation options forwarded to annotate_bars.
    - **kwargs: per-series bar options; WIDTH, COLOR and LABEL_SERIES are
      indexed by column position.
    """

    row_count = len(df)
    # running tops of the positive stack and bottoms of the negative stack
    base_plus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    base_minus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    for i, col in enumerate(df.columns):
        series = df[col]
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs[WIDTH][i],
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
Plot a stacked bar plot, with positive values stacked above zero and negative values stacked below zero.
def bar_plot(
    data: DataT,
    **kwargs,
) -> Axes:
    """
    Create a bar plot from the given data.

    Each column of the data is plotted as a separate series of bars. By
    default the series are grouped (side by side); with stacked=True they
    are stacked on top of each other, with positive values above zero and
    negative values below zero.

    Parameters
    - data: DataFrame | Series - The data to plot. A Series is treated as a
      single-column DataFrame.
    - **kwargs: dict Additional keyword arguments for customization.
        # --- options for the entire bar plot
        ax: Axes - axes to plot on, or None for new axes
        stacked: bool - if True, the bars will be stacked. If False (default),
            they will be grouped.
        max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex
            only; default 10)
        plot_from: int | PeriodIndex - if provided, the plot will start from this index.
        # --- options for each bar ...
        color: str | list[str] - the color of the bars (or separate colors for each series)
        label_series: bool | list[bool] - if True, the series will be labeled in the legend
        width: float | list[float] - the width of the bars
        # - options for bar annotations
        annotate: bool - If True, then annotate the bars with their values (default False).
        fontsize: int | float | str - font size of the annotations (default "small")
        fontname: str - font name of the annotations (default "Helvetica")
        rounding: int - number of decimal places to round to
        annotate_color: str - color of annotations (default "black" when above,
            otherwise "white")
        rotation: int | float - rotation of annotations in degrees (default 0)
        above: bool - if True, annotations are above the bar, else within the bar

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    me = "bar_plot"
    report_kwargs(called_from=me, **kwargs)
    kwargs = validate_kwargs(BAR_KW_TYPES, me, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print(
            "Warning: bar_plot is not designed for incomplete or non-categorical data indexes."
        )
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        STACKED: False,
        MAX_TICKS: 10,
    }
    chart_args = {k: kwargs.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults: dict[str, Any] = {
        COLOR: get_color_list(item_count),
        WIDTH: get_setting("bar_width"),
        LABEL_SERIES: (item_count > 1),
    }
    above = kwargs.get(ABOVE, False)
    anno_args = {
        ANNOTATE: kwargs.get(ANNOTATE, False),
        FONTSIZE: kwargs.get(FONTSIZE, "small"),
        FONTNAME: kwargs.get(FONTNAME, "Helvetica"),
        ROTATION: kwargs.get(ROTATION, 0),
        # NOTE(review): default True is passed on to default_rounding();
        # confirm that is the intended "automatic rounding" sentinel.
        ROUNDING: kwargs.get(ROUNDING, True),
        # labels inside a (coloured) bar default to white, above it to black
        COLOR: kwargs.get(ANNOTATE_COLOR, "black" if above else "white"),
        ABOVE: above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs)

    # --- plot the data
    axes, _rkwargs = get_axes(**remaining_kwargs)
    if chart_args[STACKED]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete PeriodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args[MAX_TICKS])
    else:
        # rotate the labels of *this* axes: plt.xticks() would act on pyplot's
        # current axes, which need not be `axes` (e.g. when ax= is supplied).
        axes.tick_params(axis="x", labelrotation=90)

    return axes
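Create a bar plot from the given data. Each column of the data is plotted as a separate series of bars. By default the series are grouped side by side; with stacked=True they are stacked on top of each other, with positive values above zero and negative values below zero.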
Parameters
- data: DataFrame | Series - The data to plot. A Series is treated as a single-column DataFrame.
- **kwargs: dict Additional keyword arguments for customization.
--- options for the entire bar plot
ax: Axes - axes to plot on, or None for new axes stacked: bool - if True, the bars will be stacked. If False, they will be grouped. max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only) plot_from: int | PeriodIndex - if provided, the plot will start from this index.
--- options for each bar ...
color: str | list[str] - the color of the bars (or separate colors for each series) label_series: bool | list[bool] - if True, the series will be labeled in the legend width: float | list[float] - the width of the bars
- options for bar annotations
annotate: bool - If True, then annotate the bars with their values. fontsize: int | float | str - font size of the annotations fontname: str - font name of the annotations rounding: int - number of decimal places to round to annotate_color: str - color of annotations rotation: int | float - rotation of annotations in degrees above: bool - if True, annotations are above the bar, else within the bar
Note: This function does not assume all data is timeseries with a PeriodIndex.
Returns
- axes: Axes - The axes for the plot.