mgplot.bar_plot
bar_plot.py This module contains functions to create bar plots using Matplotlib. Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.
"""
bar_plot.py
This module contains functions to create bar plots using Matplotlib.
Note: bar plots in Matplotlib are not the same as bar charts in other
libraries. Bar plots are used to represent categorical data with
rectangular bars. As a result, bar plots and line plots typically
cannot be plotted on the same axes.
"""

# --- imports
from typing import Any, Final, Unpack, NotRequired
from collections.abc import Sequence

import numpy as np
from pandas import Series, DataFrame, Period
import matplotlib.pyplot as plt
from matplotlib.pyplot import Axes
import matplotlib.patheffects as pe


from mgplot.settings import DataT, get_setting
from mgplot.utilities import (
    apply_defaults,
    get_color_list,
    get_axes,
    constrain_data,
    default_rounding,
)
from mgplot.keyword_checking import validate_kwargs, report_kwargs, BaseKwargs
from mgplot.axis_utils import set_labels, map_periodindex, is_categorical


# --- constants
ME: Final[str] = "bar_plot"


class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional. bar_plot() hands this class to
    validate_kwargs() as the schema for its keyword arguments.
    """

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # True: stacked bars; otherwise grouped bars
    max_ticks: NotRequired[int]  # passed to set_labels() for PeriodIndex data
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    annotate: NotRequired[bool]
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    rounding: NotRequired[int]
    rotation: NotRequired[int | float]
    annotate_color: NotRequired[str]
    above: NotRequired[bool]


# --- functions
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
    axes: Axes,
    **anno_kwargs,
) -> None:
    """Write the value of each bar as a text label on the plot.

    Parameters
    - series: Series - the bar heights. The index supplies the x position
      of each bar and must be convertible to int.
    - offset: float - horizontal shift added to every x position.
    - base: ndarray - the bottom of each bar, aligned by position with series.
    - axes: Axes - the axes on which the text is drawn.
    - **anno_kwargs - annotation options, see the note below.

    Note: nothing is drawn unless "annotate" is present and truthy.
    "fontsize", "fontname", "color", and "rotation" are passed to the text
    (as None when absent). "rounding", "above" and "foreground" are optional.

    Returns None. Nothing is drawn for more than 30 bars, or when base and
    series differ in length (a warning is printed in the latter case).
    """

    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    max_annotations = 30
    if len(series) > max_annotations:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
    annotate_style = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding", None))
    # nudge the text off the bar edge by 2% of the range of the series
    adjustment = (series.max() - series.min()) * 0.02

    # --- annotate each bar
    x_positions = series.index.astype(int)  # mypy syntactic sugar
    for slot, (index, value, has_value) in enumerate(zip(x_positions, series, series.notna())):
        if not has_value:
            continue  # there is no bar to label for a missing value
        # base is aligned with series by position, so look it up by position
        # rather than by index label (labels need not be contiguous).
        position = base[slot] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])


def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    plot a grouped bar plot

    Each column of df gets its own slot within a group centred on the
    index value. Columns that are entirely null are not drawn, but
    their slot is still reserved.
    """

    n_columns = len(df.columns)

    for slot, name in enumerate(df.columns):
        series = df[name]
        if series.isna().all():
            continue

        # --- the width of the whole group, falling back to 0.8 if out of range
        group_width = kwargs["width"][slot]
        if group_width < 0 or group_width > 1:
            group_width = 0.8
        bar_width = group_width / n_columns

        # --- far-left + margin + halfway through one grouped column
        margin = (1 - group_width) / 2.0
        half_bar = bar_width / 2.0
        shift = (-0.5 + margin + half_bar) + (slot * bar_width)

        colour = kwargs["color"][slot]
        x_positions = series.index + shift
        # labels starting with an underscore are the "unlabelled" form
        legend_label = name if kwargs["label_series"][slot] else f"_{name}_"
        axes.bar(
            x=x_positions,
            height=series,
            color=colour,
            width=bar_width,
            label=legend_label,
        )
        annotate_bars(
            series=series,
            offset=shift,
            base=np.zeros(len(series)),
            axes=axes,
            foreground=colour,
            **anno_args,
        )


def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    plot a stacked bar plot

    Non-negative values stack upwards from zero and negative values
    stack downwards from zero, each with its own running total.
    """

    row_count = len(df)
    positive_top: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    negative_bottom: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )

    for slot, name in enumerate(df.columns):
        series = df[name]
        non_negative = series >= 0
        # each bar starts where the previous bars of the same sign ended
        bottoms = np.where(non_negative, positive_top, negative_bottom)

        colour = kwargs["color"][slot]
        bar_width = kwargs["width"][slot]
        legend_label = name if kwargs["label_series"][slot] else f"_{name}_"
        axes.bar(
            x=series.index,
            height=series,
            bottom=bottoms,
            color=colour,
            width=bar_width,
            label=legend_label,
        )
        annotate_bars(
            series=series,
            offset=0,
            base=bottoms,
            axes=axes,
            foreground=colour,
            **anno_args,
        )

        # --- advance the running totals (missing values add nothing)
        positive_top += np.where(non_negative, series, 0)
        negative_bottom += np.where(series < 0, series, 0)


def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """
    Create a bar plot from the given data. By default the columns are
    drawn as grouped (side-by-side) bars. With stacked=True the columns
    are stacked on top of each other, with positive values above
    zero and negative values below zero.

    Parameters
    - data: DataT - The data to plot. Can be a DataFrame or a Series.
    - **kwargs: BarKwargs - Additional keyword arguments for customization.
      (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print("Warning: bar_plot is not designed for incomplete or non-categorical data indexes.")
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        "stacked": False,
        "max_ticks": 10,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults: dict[str, Any] = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": (item_count > 1),
    }
    above = kwargs_d.get("above", False)
    anno_args = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above a bar sits on the background; text within a bar sits on the bar
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**remaining_kwargs)
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # rotate the labels on the axes we drew on, which need not be
        # the current pyplot axes when the caller supplied ax=...
        plt.setp(axes.get_xticklabels(), rotation=90)

    return axes
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional (NotRequired). bar_plot() passes this class to
    validate_kwargs() as the schema for its keyword arguments. The
    defaults mentioned below are the ones bar_plot() applies.
    """

    # --- options for the entire bar plot
    # ax: not read directly by bar_plot; presumably consumed by get_axes() - confirm there.
    ax: NotRequired[Axes | None]
    # stacked: True draws stacked bars, otherwise grouped bars (default False).
    stacked: NotRequired[bool]
    # max_ticks: passed to set_labels() when the data had a PeriodIndex (default 10).
    max_ticks: NotRequired[int]
    # plot_from: not read directly by bar_plot; presumably consumed by constrain_data() - confirm there.
    plot_from: NotRequired[int | Period]
    # --- options for each bar ...
    # color: bar colour(s), one per column (default get_color_list(item_count)).
    color: NotRequired[str | Sequence[str]]
    # label_series: whether a column gets its name as the bar label
    # (default True only when there is more than one column).
    label_series: NotRequired[bool | Sequence[bool]]
    # width: bar width(s) (default get_setting("bar_width")). For grouped bars
    # a value outside 0..1 is replaced by 0.8.
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    # annotate: write each bar's value on the plot (default False).
    annotate: NotRequired[bool]
    # fontsize / fontname / rotation: passed to the annotation text
    # (defaults "small", "Helvetica", 0).
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    # rounding: handed to default_rounding(); the result is used as the
    # number of decimal places in the annotation text.
    rounding: NotRequired[int]
    rotation: NotRequired[int | float]
    # annotate_color: text colour (default "black" if above else "white").
    annotate_color: NotRequired[str]
    # above: True adds the bar's value to the text position, so the text
    # sits at the far end of the bar (default False).
    above: NotRequired[bool]
Keyword arguments for the bar_plot function.
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
    axes: Axes,
    **anno_kwargs,
) -> None:
    """Write the value of each bar as a text label on the plot.

    Parameters
    - series: Series - the bar heights. The index supplies the x position
      of each bar and must be convertible to int.
    - offset: float - horizontal shift added to every x position.
    - base: ndarray - the bottom of each bar, aligned by position with series.
    - axes: Axes - the axes on which the text is drawn.
    - **anno_kwargs - annotation options, see the note below.

    Note: nothing is drawn unless "annotate" is present and truthy.
    "fontsize", "fontname", "color", and "rotation" are passed to the text
    (as None when absent). "rounding", "above" and "foreground" are optional.

    Returns None. Nothing is drawn for more than 30 bars, or when base and
    series differ in length (a warning is printed in the latter case).
    """

    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    max_annotations = 30
    if len(series) > max_annotations:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
    annotate_style = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding", None))
    # nudge the text off the bar edge by 2% of the range of the series
    adjustment = (series.max() - series.min()) * 0.02

    # --- annotate each bar
    x_positions = series.index.astype(int)  # mypy syntactic sugar
    for slot, (index, value, has_value) in enumerate(zip(x_positions, series, series.notna())):
        if not has_value:
            continue  # there is no bar to label for a missing value
        # base is aligned with series by position, so look it up by position
        # rather than by index label (labels need not be contiguous).
        position = base[slot] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
Bar plot annotations.
Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs; "above", "rounding", and "foreground" are optional.
def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    plot a grouped bar plot

    Each column of df gets its own slot within a group centred on the
    index value. Columns that are entirely null are not drawn, but
    their slot is still reserved.

    Parameters
    - axes - the matplotlib axes to draw on.
    - df: DataFrame - one column per bar series.
    - anno_args - mapping of annotation options, forwarded to annotate_bars().
    - **kwargs - must provide per-column sequences under the keys
      "width", "color" and "label_series".
    """

    n_columns = len(df.columns)

    for slot, name in enumerate(df.columns):
        series = df[name]
        if series.isna().all():
            continue

        # --- the width of the whole group, falling back to 0.8 if out of range
        group_width = kwargs["width"][slot]
        if group_width < 0 or group_width > 1:
            group_width = 0.8
        bar_width = group_width / n_columns

        # --- far-left + margin + halfway through one grouped column
        margin = (1 - group_width) / 2.0
        half_bar = bar_width / 2.0
        shift = (-0.5 + margin + half_bar) + (slot * bar_width)

        colour = kwargs["color"][slot]
        x_positions = series.index + shift
        # labels starting with an underscore are the "unlabelled" form
        legend_label = name if kwargs["label_series"][slot] else f"_{name}_"
        axes.bar(
            x=x_positions,
            height=series,
            color=colour,
            width=bar_width,
            label=legend_label,
        )
        annotate_bars(
            series=series,
            offset=shift,
            base=np.zeros(len(series)),
            axes=axes,
            foreground=colour,
            **anno_args,
        )
plot a grouped bar plot
def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    plot a stacked bar plot

    Non-negative values stack upwards from zero and negative values
    stack downwards from zero, each with its own running total.

    Parameters
    - axes - the matplotlib axes to draw on.
    - df: DataFrame - one column per layer of the stack.
    - anno_args - mapping of annotation options, forwarded to annotate_bars().
    - **kwargs - must provide per-column sequences under the keys
      "color", "width" and "label_series".
    """

    row_count = len(df)
    positive_top: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    negative_bottom: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )

    for slot, name in enumerate(df.columns):
        series = df[name]
        non_negative = series >= 0
        # each bar starts where the previous bars of the same sign ended
        bottoms = np.where(non_negative, positive_top, negative_bottom)

        colour = kwargs["color"][slot]
        bar_width = kwargs["width"][slot]
        legend_label = name if kwargs["label_series"][slot] else f"_{name}_"
        axes.bar(
            x=series.index,
            height=series,
            bottom=bottoms,
            color=colour,
            width=bar_width,
            label=legend_label,
        )
        annotate_bars(
            series=series,
            offset=0,
            base=bottoms,
            axes=axes,
            foreground=colour,
            **anno_args,
        )

        # --- advance the running totals (missing values add nothing)
        positive_top += np.where(non_negative, series, 0)
        negative_bottom += np.where(series < 0, series, 0)
plot a stacked bar plot
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """
    Create a bar plot from the given data. By default the columns are
    drawn as grouped (side-by-side) bars. With stacked=True the columns
    are stacked on top of each other, with positive values above
    zero and negative values below zero.

    Parameters
    - data: DataT - The data to plot. Can be a DataFrame or a Series.
    - **kwargs: BarKwargs - Additional keyword arguments for customization.
      (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print("Warning: bar_plot is not designed for incomplete or non-categorical data indexes.")
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        "stacked": False,
        "max_ticks": 10,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults: dict[str, Any] = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": (item_count > 1),
    }
    above = kwargs_d.get("above", False)
    anno_args = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above a bar sits on the background; text within a bar sits on the bar
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**remaining_kwargs)
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    else:
        # rotate the labels on the axes we drew on, which need not be
        # the current pyplot axes when the caller supplied ax=...
        plt.setp(axes.get_xticklabels(), rotation=90)

    return axes
Create a bar plot from the given data. By default the columns are drawn as grouped (side-by-side) bars. With stacked=True, each column in the DataFrame will be stacked on top of each other, with positive values above zero and negative values below zero.
Parameters
- data: DataT - The data to plot. Can be a DataFrame or a Series.
- **kwargs: BarKwargs - Additional keyword arguments for customization. (see BarKwargs for details)
Note: This function does not assume all data is timeseries with a PeriodIndex.
Returns
- axes: Axes - The axes for the plot.