mgplot.bar_plot
Create bar plots using Matplotlib.
Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.
"""Create bar plots using Matplotlib.

Note: bar plots in Matplotlib are not the same as bar charts in other
libraries. Bar plots are used to represent categorical data with
rectangular bars. As a result, bar plots and line plots typically
cannot be plotted on the same axes.
"""

from collections.abc import Sequence
from typing import Any, Final, NotRequired, TypedDict, Unpack

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from pandas import DataFrame, Period, Series

from mgplot.axis_utils import map_periodindex, set_labels
from mgplot.keyword_checking import BaseKwargs, report_kwargs, validate_kwargs
from mgplot.settings import DataT, get_setting
from mgplot.utilities import (
    apply_defaults,
    constrain_data,
    default_rounding,
    get_axes,
    get_color_list,
)

# --- constants
ME: Final[str] = "bar_plot"
MAX_ANNOTATIONS: Final[int] = 30  # series longer than this are not annotated
ADJUSTMENT_FACTOR: Final[float] = 0.02  # text nudge, as a fraction of the series range
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8  # used when a grouped width is out of range
DEFAULT_BAR_OFFSET: Final[float] = 0.5
DEFAULT_MAX_TICKS: Final[int] = 10


class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function."""

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]
    max_ticks: NotRequired[int]
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    annotate: NotRequired[bool]
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    rounding: NotRequired[int]
    rotation: NotRequired[int | float]
    annotate_color: NotRequired[str]
    above: NotRequired[bool]


# --- functions
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars."""

    annotate: bool
    fontsize: int | float | str
    fontname: str
    color: str
    rotation: int | float
    foreground: str  # used for stroke effect on text
    above: bool
    rounding: bool | int  # if True, uses default rounding; if int, uses that value


def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray,
    axes: Axes,
    **anno_kwargs: Unpack[AnnoKwargs],
) -> None:
    """Write the value of each bar as text on the axes.

    Nothing is drawn unless "annotate" is truthy, and a series with more
    than MAX_ANNOTATIONS values is left un-annotated.

    Args:
        series: Series - the bar heights; its index (cast to int) gives the
            x-position of each bar.
        offset: float - horizontal shift added to each x-position.
        base: np.ndarray - the bottom of each bar, aligned position-by-position
            with series.
        axes: Axes - the axes to write on.
        **anno_kwargs: AnnoKwargs - "annotate", "fontsize", "fontname", "color",
            and "rotation" are expected; "above", "rounding" and "foreground"
            are optional.

    """
    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    if len(series) > MAX_ANNOTATIONS:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
    annotate_style: dict[str, Any] = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
    # a stroke-effect makes within-bar annotations more readable with
    # very small bars; it is loop-invariant, so build it once.
    stroke = (
        pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))
        if not above and "foreground" in anno_kwargs
        else None
    )

    # --- annotate each bar
    # base is paired with the series by position (that is how the callers
    # build it), so an unsorted or gapped index cannot select the wrong base.
    for bottom, index, value in zip(base, series.index.astype(int), series, strict=True):
        if np.isnan(value):
            continue  # no bar was drawn here, so there is nothing to label
        position = bottom + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if stroke is not None:
            text.set_path_effects([stroke])


class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped."""

    color: Sequence[str]
    width: Sequence[float | int]
    label_series: Sequence[bool]


def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    The columns of df are drawn side by side around each index position.
    Columns that are entirely NaN are not drawn, but they still count
    towards the number of slots the group width is divided into.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one bar series per column.
        anno_args: AnnoKwargs - annotation settings passed on to annotate_bars;
            this dict is not modified.
        **kwargs: GroupedKwargs - per-column "color", "width" and "label_series"
            sequences, indexed by column position.

    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )


class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked."""

    color: Sequence[str]
    width: Sequence[float | int]
    label_series: Sequence[bool]


def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Non-negative values are stacked upwards from zero and negative values
    downwards from zero, each on its own running base.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one bar series per column.
        anno_args: AnnoKwargs - annotation settings passed on to annotate_bars;
            this dict is not modified.
        **kwargs: StackedKwargs - per-column "color", "width" and "label_series"
            sequences, indexed by column position.

    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        # NaN fails both comparisons, so it moves neither running base
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)


def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns are drawn as grouped (side-by-side) bars.
    With stacked=True the columns are stacked on top of each other,
    with positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above a bar sits on the background; text within a bar sits on the bar
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, _ = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # rotate the labels of the axes we drew on; plt.xticks() would act on
        # pyplot's current axes, which need not be a caller-supplied ax.
        plt.setp(axes.get_xticklabels(), rotation=90)

    return axes
class BarKwargs(BaseKwargs):
    """Optional keyword arguments understood by bar_plot.

    Every key is NotRequired; bar_plot supplies a default for the
    chart-level, per-bar and annotation options it reads.
    """

    # --- chart-level options
    ax: NotRequired[Axes | None]
    plot_from: NotRequired[int | Period]
    stacked: NotRequired[bool]  # bar_plot draws grouped bars unless this is True
    max_ticks: NotRequired[int]

    # --- per-bar options (a single value, or one value per column)
    width: NotRequired[float | int | Sequence[float | int]]
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]

    # --- annotation options
    annotate: NotRequired[bool]
    above: NotRequired[bool]
    annotate_color: NotRequired[str]
    fontname: NotRequired[str]
    fontsize: NotRequired[int | float | str]
    rotation: NotRequired[int | float]
    rounding: NotRequired[int]
Keyword arguments for the bar_plot function.
64class AnnoKwargs(TypedDict, total=False): 65 """TypedDict for the kwargs used in annotate_bars.""" 66 67 annotate: bool 68 fontsize: int | float | str 69 fontname: str 70 color: str 71 rotation: int | float 72 foreground: str # used for stroke effect on text 73 above: bool 74 rounding: bool | int # if True, uses default rounding; if int, uses that value
TypedDict for the kwargs used in annotate_bars.
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray,
    axes: Axes,
    **anno_kwargs: Unpack[AnnoKwargs],
) -> None:
    """Write the value of each bar as text on the axes.

    Nothing is drawn unless "annotate" is truthy, and a series with more
    than MAX_ANNOTATIONS values is left un-annotated.

    Args:
        series: Series - the bar heights; its index (cast to int) gives the
            x-position of each bar.
        offset: float - horizontal shift added to each x-position.
        base: np.ndarray - the bottom of each bar, aligned position-by-position
            with series.
        axes: Axes - the axes to write on.
        **anno_kwargs: AnnoKwargs - "annotate", "fontsize", "fontname", "color",
            and "rotation" are expected; "above", "rounding" and "foreground"
            are optional.

    """
    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    if len(series) > MAX_ANNOTATIONS:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
    annotate_style: dict[str, Any] = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
    # a stroke-effect makes within-bar annotations more readable with
    # very small bars; it is loop-invariant, so build it once.
    stroke = (
        pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))
        if not above and "foreground" in anno_kwargs
        else None
    )

    # --- annotate each bar
    # base is paired with the series by position (that is how the callers
    # build it), so an unsorted or gapped index cannot select the wrong base.
    for bottom, index, value in zip(base, series.index.astype(int), series, strict=True):
        if np.isnan(value):
            continue  # no bar was drawn here, so there is nothing to label
        position = bottom + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if stroke is not None:
            text.set_path_effects([stroke])
Bar plot annotations.
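Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs. Nothing is drawn unless "annotate" is true, and a series with more than 30 values (MAX_ANNOTATIONS) is not annotated.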
131class GroupedKwargs(TypedDict): 132 """TypedDict for the kwargs used in grouped.""" 133 134 color: Sequence[str] 135 width: Sequence[float | int] 136 label_series: Sequence[bool]
TypedDict for the kwargs used in grouped.
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    The columns of df are drawn side by side around each index position.
    Columns that are entirely NaN are not drawn, but they still count
    towards the number of slots the group width is divided into.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one bar series per column.
        anno_args: AnnoKwargs - annotation settings passed on to annotate_bars;
            this dict is not modified.
        **kwargs: GroupedKwargs - per-column "color", "width" and "label_series"
            sequences, indexed by column position.

    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )
Plot a grouped bar plot.
172class StackedKwargs(TypedDict): 173 """TypedDict for the kwargs used in stacked.""" 174 175 color: Sequence[str] 176 width: Sequence[float | int] 177 label_series: Sequence[bool]
TypedDict for the kwargs used in stacked.
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Non-negative values are stacked upwards from zero and negative values
    downwards from zero, each on its own running base.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one bar series per column.
        anno_args: AnnoKwargs - annotation settings passed on to annotate_bars;
            this dict is not modified.
        **kwargs: StackedKwargs - per-column "color", "width" and "label_series"
            sequences, indexed by column position.

    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        # NaN fails both comparisons, so it moves neither running base
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
Plot a stacked bar plot.
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns are drawn as grouped (side-by-side) bars.
    With stacked=True the columns are stacked on top of each other,
    with positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above a bar sits on the background; text within a bar sits on the bar
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, _ = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # rotate the labels of the axes we drew on; plt.xticks() would act on
        # pyplot's current axes, which need not be a caller-supplied ax.
        plt.setp(axes.get_xticklabels(), rotation=90)

    return axes
Create a bar plot from the given data.
By default, the columns in the DataFrame are drawn as grouped (side-by-side) bars. With stacked=True, the columns are stacked on top of each other, with positive values above zero and negative values below zero.
Args: data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series. **kwargs: BarKwargs - Additional keyword arguments for customization. (see BarKwargs for details)
Note: This function does not assume all data is timeseries with a PeriodIndex.
Returns: axes: Axes - The axes for the plot.