diff --git a/docs/examples/plot_types/12_taylor_diagram.py b/docs/examples/plot_types/12_taylor_diagram.py new file mode 100644 index 000000000..487a58c4c --- /dev/null +++ b/docs/examples/plot_types/12_taylor_diagram.py @@ -0,0 +1,72 @@ +""" +Taylor Diagram +============== + +Taylor diagrams compare model skill with correlation coefficient, standard +deviation, and centered RMS difference in a single polar-style plot. + +Why UltraPlot here? +------------------- +UltraPlot exposes Taylor diagrams as a projection, so you can create them with +``proj="taylor"`` and then use regular axes methods plus convenience methods +for plotting points from correlation and standard-deviation coordinates. + +Key functions: :py:meth:`ultraplot.axes.TaylorAxes.plot_corr`, +:py:meth:`ultraplot.axes.TaylorAxes.scatter_corr`. + +See also +-------- +* :doc:`Geographic and polar axes ` +""" + +import numpy as np + +import ultraplot as uplt + +models = ("Control", "Physics A", "Physics B", "Ensemble") +correlation = np.array([0.73, 0.84, 0.91, 0.96]) +stddev = np.array([0.82, 1.18, 1.05, 0.93]) +colors = ("blue7", "orange7", "green7", "violet7") + +fig, ax = uplt.subplots(proj="taylor", refwidth=4.2) +ax.format( + title="Model skill summary", + xlabel="Standard deviation", + ylabel="", + corrlabel="Correlation", + rlim=(0, 1.5), + rlines=0.25, + corrlines=(1, 0.95, 0.9, 0.8, 0.6, 0.4, 0.2, 0), +) + +# Centered RMS-difference contours around the reference point at (corr=1, std=1). +theta = np.linspace(0, np.pi / 2, 160) +radius = np.linspace(0, 1.5, 160) +theta_grid, radius_grid = np.meshgrid(theta, radius) +rms = np.sqrt(1 + radius_grid**2 - 2 * radius_grid * np.cos(theta_grid)) +contours = ax.contour( + theta_grid, + radius_grid, + rms, + levels=(0.25, 0.5, 0.75, 1.0, 1.25), + cmap="tokyo", + lw=0.9, + ls="--", +) +ax.clabel(contours, levels=(0.5, 1.0), inline=True, fontsize=8, fmt="%.1f") + +ax.plot_corr(1, 1, marker="*", markersize=12, color="red7", label="Reference") +for name, corr, std, color in zip(models, correlation, stddev, colors): + ax.scatter_corr( + corr, + std, + s=75, + color=color, + edgecolor="white", + lw=0.8, + zorder=4, + label=name, + ) + +ax.legend(loc="b", ncols=3, frame=False) +fig.show() diff --git a/docs/projections.py b/docs/projections.py index 24e285c76..14d5c4c00 100644 --- a/docs/projections.py +++ b/docs/projections.py @@ -108,6 +108,71 @@ rlocator=2, ) +# %% [raw] raw_mimetype="text/restructuredtext" +# .. _ug_taylor: +# +# Taylor diagrams +# --------------- +# +# To create Taylor diagrams, pass ``proj='taylor'`` to an axes-creation command. +# Taylor axes are represented with the :class:`~ultraplot.axes.TaylorAxes` +# subclass and are useful for comparing model output against a reference series: +# angular gridlines represent correlation coefficients and radial gridlines +# represent standard deviation. The :meth:`~ultraplot.axes.TaylorAxes.plot_corr` +# and :meth:`~ultraplot.axes.TaylorAxes.scatter_corr` helpers accept correlation +# and standard-deviation coordinates directly, while regular polar plotting +# commands can still be used for annotations like centered RMS-difference +# contours. + +# %% +import numpy as np + +import ultraplot as uplt + +labels = ("Model A", "Model B", "Model C", "Model D") +corr = np.array([0.72, 0.84, 0.91, 0.96]) +std = np.array([0.86, 1.16, 1.04, 0.94]) +colors = ("blue7", "orange7", "green7", "violet7") + +fig, ax = uplt.subplots(proj="taylor", refwidth=4.2) +ax.format( + title="Taylor diagram", + xlabel="Standard deviation", + ylabel="", + rlim=(0, 1.5), + rlines=0.25, + corrlines=(1, 0.95, 0.9, 0.8, 0.6, 0.4, 0.2, 0), +) + +theta = np.linspace(0, np.pi / 2, 121) +radius = np.linspace(0, 1.5, 121) +theta_grid, radius_grid = np.meshgrid(theta, radius) +rms = np.sqrt(1 + radius_grid**2 - 2 * radius_grid * np.cos(theta_grid)) +contours = ax.contour( + theta_grid, + radius_grid, + rms, + levels=(0.25, 0.5, 0.75, 1.0, 1.25), + cmap="tokyo", + lw=0.9, + ls="--", +) +ax.clabel(contours, levels=(0.5, 1.0), inline=True, fontsize=8, fmt="%.1f") + +ax.plot_corr(1, 1, marker="*", markersize=11, color="red7", label="Reference") +for label, c, s, color in zip(labels, corr, std, colors): + ax.scatter_corr( + c, + s, + s=70, + color=color, + edgecolor="white", + lw=0.8, + zorder=4, + label=label, + ) +ax.legend(loc="b", ncols=3, frame=False) + # %% [raw] raw_mimetype="text/restructuredtext" # .. _ug_geo: # diff --git a/ultraplot/__init__.py b/ultraplot/__init__.py index 074c87938..1aa62d0b0 100644 --- a/ultraplot/__init__.py +++ b/ultraplot/__init__.py @@ -26,6 +26,7 @@ from .axes import GeoAxes as GeoAxes from .axes import PlotAxes as PlotAxes from .axes import PolarAxes as PolarAxes + from .axes import TaylorAxes as TaylorAxes from .axes import ThreeAxes as ThreeAxes from .colors import ColormapDatabase as ColormapDatabase from .colors import ColorDatabase as ColorDatabase diff --git a/ultraplot/axes/__init__.py b/ultraplot/axes/__init__.py index 1c8163dcf..37effe7d8 100644 --- a/ultraplot/axes/__init__.py +++ b/ultraplot/axes/__init__.py @@ -17,6 +17,7 @@ from .plot import PlotAxes # noqa: F401 from .polar import PolarAxes from .shared import _SharedAxes # noqa: F401 +from .taylor import TaylorAxes from .three import ThreeAxes # noqa: F401 # Prevent importing module names and set order of appearance for objects @@ -25,6 +26,7 @@ "PlotAxes", "CartesianAxes", "PolarAxes", + "TaylorAxes", "GeoAxes", "ThreeAxes", "ExternalAxesContainer", @@ -34,7 +36,14 @@ # NOTE: We integrate with cartopy and basemap rather than using matplotlib's # native projection system. Therefore axes names are not part of public API. _cls_dict = {} # track valid names -for _cls in (CartesianAxes, PolarAxes, _CartopyAxes, _BasemapAxes, ThreeAxes): +for _cls in ( + CartesianAxes, + PolarAxes, + TaylorAxes, + _CartopyAxes, + _BasemapAxes, + ThreeAxes, +): for _name in (_cls._name, *_cls._name_aliases): with context._state_context(_cls, name="ultraplot_" + _name): mproj.register_projection(_cls) diff --git a/ultraplot/axes/taylor.py b/ultraplot/axes/taylor.py new file mode 100644 index 000000000..02f8c93cb --- /dev/null +++ b/ultraplot/axes/taylor.py @@ -0,0 +1,583 @@ +#!/usr/bin/env python3 +""" +Taylor diagram axes. +""" + +import inspect + +import matplotlib.projections.polar as mpolar +import matplotlib.ticker as mticker +import matplotlib.transforms as mtransforms +import numpy as np + +from ..config import rc +from ..internals import _not_none, _pop_rc, docstring +from .polar import PolarAxes + +__all__ = ["TaylorAxes"] + + +_format_docstring = """ +xlabel, ylabel : str, optional + Labels for the standard-deviation axes. These are drawn as Taylor-specific + text artists while the native polar axis labels are kept hidden. +corrlabel : str, default: 'Correlation' + Label for the correlation-coefficient grid. +thetaunit : {'corr', 'deg', 'rad'}, default: 'corr' + Units used for the angular grid labels. The default labels angular ticks + as correlation coefficients. +quadrant : {1, 2, 3, 4} or str, default: 1 + The quadrant used for the Taylor diagram. Quadrants follow the Cartesian + convention: ``1`` is upper right and ``4`` is lower right. +corrlocator, corrlines, corrticks : float or sequence of float, optional + Correlation coefficients used for the angular gridlines. +labelcolor, labelsize, labelweight : optional + Label text properties. +""" +docstring._snippet_manager["taylor.format"] = _format_docstring + + +class TaylorAxes(PolarAxes): + """ + Axes subclass for Taylor diagrams. + + Important + --------- + This axes subclass can be used by passing ``proj='taylor'`` to + axes-creation commands like `~ultraplot.figure.Figure.add_axes`, + `~ultraplot.figure.Figure.add_subplot`, and + `~ultraplot.figure.Figure.subplots`. + """ + + _name = "taylor" + _name_aliases = () + _default_corrs = np.array((1.0, 0.95, 0.9, 0.8, 0.6, 0.4, 0.2, 0.0)) + _quadrant_aliases = { + "1": 1, + "i": 1, + "ur": 1, + "upper right": 1, + "upright": 1, + "2": 2, + "ii": 2, + "ul": 2, + "upper left": 2, + "upleft": 2, + "3": 3, + "iii": 3, + "ll": 3, + "lower left": 3, + "lowleft": 3, + "4": 4, + "iv": 4, + "lr": 4, + "lower right": 4, + "lowright": 4, + "upside down": 4, + } + + @docstring._snippet_manager + def __init__(self, *args, **kwargs): + """ + Parameters + ---------- + *args + Passed to `matplotlib.axes.Axes`. + %(taylor.format)s + %(polar.format)s + + Other parameters + ---------------- + %(axes.format)s + %(rc.init)s + + See also + -------- + TaylorAxes.format + ultraplot.axes.PolarAxes + """ + self._taylor_corrs = self._default_corrs.copy() + self._taylor_thetaunit = "corr" + self._taylor_quadrant = 1 + self._taylor_labelpad = None + super().__init__(*args, **kwargs) + self._ensure_taylor_artists() + self._apply_taylor_defaults() + + @staticmethod + def correlation_to_angle(correlation): + """ + Convert correlation coefficients to Taylor-diagram polar angles. + """ + correlation = np.asarray(correlation) + return np.arccos(np.clip(correlation, -1, 1)) + + @classmethod + def _parse_quadrant(cls, quadrant): + """ + Normalize Taylor quadrant input. + """ + if quadrant is None: + return None + if isinstance(quadrant, str): + key = quadrant.strip().lower().replace("-", " ") + key = " ".join(key.split()) + quadrant = cls._quadrant_aliases.get(key) + if quadrant in (1, 2, 3, 4): + return int(quadrant) + raise ValueError( + "Invalid Taylor quadrant={!r}. Expected 1, 2, 3, 4, or an alias like " + "'upper right' or 'lower right'.".format(quadrant) + ) + + @staticmethod + def _quadrant_bounds(quadrant): + """ + Return theta bounds in degrees for a Taylor quadrant. + """ + return { + 1: (0, 90), + 2: (90, 180), + 3: (180, 270), + 4: (0, -90), + }[quadrant] + + def _correlation_to_theta(self, correlation): + """ + Convert correlation coefficients to displayed polar angles. + """ + angle = self.correlation_to_angle(correlation) + quadrant = self._taylor_quadrant + if quadrant == 1: + return angle + if quadrant == 2: + return np.pi / 2 + angle + if quadrant == 3: + return np.pi + angle + return -angle + + def plot_corr(self, correlation, stddev, *args, **kwargs): + """ + Plot values specified as correlation coefficient and standard deviation. + """ + return self.plot( + self._correlation_to_theta(correlation), stddev, *args, **kwargs + ) + + def scatter_corr(self, correlation, stddev, *args, **kwargs): + """ + Scatter values specified as correlation coefficient and standard deviation. + """ + return mpolar.PolarAxes.scatter( + self, self._correlation_to_theta(correlation), stddev, *args, **kwargs + ) + + def get_tightbbox(self, renderer, *args, **kwargs): + """ + Return a stable tight bbox before the first draw. + + Matplotlib's polar radial axis can report a spurious far-left bbox for + Taylor's quarter-sector view before the first draw. This feeds back into + UltraPlot's reference-width autosizing and creates excessive left margin. + """ + self._update_taylor_std_ticklabels() + bbox = super().get_tightbbox(renderer, *args, **kwargs.copy()) + axis_bbox = self.yaxis.get_tightbbox(renderer) + window = self.get_window_extent(renderer) + bogus_radial_bbox = ( + bbox is not None + and axis_bbox is not None + and axis_bbox.x0 < window.x0 - 0.25 * window.width + and axis_bbox.width > 0.5 * window.width + ) + if not bogus_radial_bbox: + return bbox + + visible = self.yaxis.get_visible() + try: + self.yaxis.set_visible(False) + bbox_no_yaxis = super().get_tightbbox(renderer, *args, **kwargs.copy()) + finally: + self.yaxis.set_visible(visible) + if bbox_no_yaxis is None: + return bbox + bbox = mtransforms.Bbox.from_extents( + bbox_no_yaxis.x0, + min(bbox.y0, bbox_no_yaxis.y0), + max(bbox_no_yaxis.x1, window.x1), + max(bbox.y1, bbox_no_yaxis.y1), + ) + self._tight_bbox = bbox + return bbox + + def set_xlabel(self, xlabel, fontdict=None, labelpad=None, **kwargs): + """ + Set the Taylor x label while keeping the native polar label hidden. + """ + text = super().set_xlabel( + xlabel, fontdict=fontdict, labelpad=labelpad, **kwargs + ) + self._ensure_taylor_artists() + self.xaxis.label.set_visible(False) + self._taylor_xlabel_artist.set_text(xlabel) + if labelpad is not None: + self._update_taylor_label_positions(labelpad) + if fontdict: + self._taylor_xlabel_artist.update(fontdict) + kwargs.pop("loc", None) + self._taylor_xlabel_artist.update(kwargs) + return text + + def set_ylabel(self, ylabel, fontdict=None, labelpad=None, **kwargs): + """ + Set the Taylor y label while keeping the native polar label hidden. + """ + text = super().set_ylabel( + ylabel, fontdict=fontdict, labelpad=labelpad, **kwargs + ) + self._ensure_taylor_artists() + self.yaxis.label.set_visible(False) + self._taylor_ylabel_artist.set_text(ylabel) + if labelpad is not None: + self._update_taylor_label_positions(labelpad) + if fontdict: + self._taylor_ylabel_artist.update(fontdict) + kwargs.pop("loc", None) + self._taylor_ylabel_artist.update(kwargs) + return text + + def _apply_taylor_defaults(self): + """ + Apply the fixed quarter-polar Taylor diagram defaults. + """ + thetamin, thetamax = self._quadrant_bounds(self._taylor_quadrant) + self.set_thetamin(thetamin) + self.set_thetamax(thetamax) + self.set_theta_zero_location("E") + self.set_theta_direction(1) + self.set_rorigin(0) + self.set_rlabel_position({1: 135, 2: 45, 3: 315, 4: 225}[self._taylor_quadrant]) + self.spines["polar"].set_visible(True) + self.xaxis.label.set_visible(False) + self.yaxis.label.set_visible(False) + self._update_taylor_ticks() + + def _ensure_taylor_artists(self): + """ + Create Taylor-specific label artists on demand. + """ + if hasattr(self, "_taylor_xlabel_artist"): + return + + kw = { + "transform": self.transAxes, + "clip_on": False, + "zorder": 3.5, + } + self._taylor_xlabel_artist = self.text( + 0.5, -0.12, "", ha="center", va="top", **kw + ) + self._taylor_ylabel_artist = self.text( + -0.12, + 0.5, + "", + ha="center", + va="bottom", + rotation=90, + rotation_mode="anchor", + **kw, + ) + self._taylor_corrlabel_artist = self.text( + 0.72, + 0.72, + "Correlation", + ha="center", + va="bottom", + rotation=-45, + rotation_mode="anchor", + **kw, + ) + self._taylor_yticklabel_artists = [] + + def _format_correlation(self, value): + """ + Format one angular tick according to the active Taylor theta unit. + """ + if self._taylor_thetaunit == "corr": + return f"{value:.2f}" + angle = np.arccos(np.clip(value, -1, 1)) + if self._taylor_thetaunit == "deg": + return f"{np.rad2deg(angle):g}\N{DEGREE SIGN}" + if self._taylor_thetaunit == "rad": + return f"{angle:g}" + raise ValueError( + "Invalid thetaunit={!r}. Expected 'corr', 'deg', or 'rad'.".format( + self._taylor_thetaunit + ) + ) + + def _update_taylor_label_positions(self, labelpad=None): + """ + Update fixed Taylor label locations. + """ + if labelpad is not None: + self._taylor_labelpad = labelpad + pad = _not_none(self._taylor_labelpad, rc["axes.labelpad"]) + try: + pad = float(pad) + except (TypeError, ValueError): + pad = float(rc["axes.labelpad"]) + offset = 0.09 + 0.004 * pad + quadrant = self._taylor_quadrant + + x_top = quadrant in (2, 3) + y_right = quadrant in (3, 4) + self._taylor_xlabel_artist.set_position((0.5, 1 + offset if x_top else -offset)) + self._taylor_xlabel_artist.set_verticalalignment("bottom" if x_top else "top") + self._taylor_ylabel_artist.set_position( + (1 + offset if y_right else -offset, 0.5) + ) + self._taylor_ylabel_artist.set_horizontalalignment( + "left" if y_right else "center" + ) + self._taylor_ylabel_artist.set_verticalalignment( + "center" if y_right else "bottom" + ) + self._taylor_ylabel_artist.set_rotation(270 if y_right else 90) + + corr_positions = { + 1: (np.deg2rad(45), -45), + 2: (np.deg2rad(135), 45), + 3: (np.deg2rad(225), -45), + 4: (np.deg2rad(-45), 45), + } + theta, rotation = corr_positions[quadrant] + _, rmax = self.get_ylim() + radius = rmax + 0.22 * abs(rmax) + self._taylor_corrlabel_artist.set_transform(self.transData) + self._taylor_corrlabel_artist.set_position((theta, radius)) + self._taylor_corrlabel_artist.set_rotation(rotation) + self._taylor_corrlabel_artist.set_horizontalalignment("center") + self._taylor_corrlabel_artist.set_verticalalignment("bottom") + + def _update_taylor_labels( + self, + *, + xlabel=None, + ylabel=None, + corrlabel=None, + labelpad=None, + labelcolor=None, + labelsize=None, + labelweight=None, + xlabel_kw=None, + ylabel_kw=None, + corrlabel_kw=None, + ): + """ + Update Taylor-specific axis labels. + """ + self._ensure_taylor_artists() + xlabel_kw = xlabel_kw or {} + ylabel_kw = ylabel_kw or {} + corrlabel_kw = corrlabel_kw or {} + props = rc._get_label_props( + color=labelcolor, + size=labelsize, + weight=labelweight, + labelpad=labelpad, + ) + labelpad = props.pop("labelpad", None) + self._update_taylor_label_positions(labelpad) + + if xlabel is not None: + self.xaxis.set_label_text(xlabel) + self.xaxis.label.set_visible(False) + self._taylor_xlabel_artist.set_text(xlabel) + if ylabel is not None: + self.yaxis.set_label_text(ylabel) + self.yaxis.label.set_visible(False) + self._taylor_ylabel_artist.set_text(ylabel) + if corrlabel is not None: + self._taylor_corrlabel_artist.set_text(corrlabel) + + for artist, kw in ( + (self._taylor_xlabel_artist, xlabel_kw), + (self._taylor_ylabel_artist, ylabel_kw), + (self._taylor_corrlabel_artist, corrlabel_kw), + ): + artist.update(props) + artist.update(kw) + + def _update_taylor_ticks(self, corrs=None): + """ + Update angular grid labels from correlation coefficients. + """ + if corrs is not None: + corrs = np.asarray(corrs, dtype=float) + if corrs.ndim == 0: + step = float(corrs) + if step <= 0: + raise ValueError("Taylor correlation tick step must be positive.") + corrs = np.arange(1, -0.5 * step, -step) + corrs = np.clip(corrs, 0, 1) + self._taylor_corrs = corrs + corrs = np.asarray(self._taylor_corrs, dtype=float) + if np.any((corrs < -1) | (corrs > 1)): + raise ValueError("Taylor correlation ticks must be between -1 and 1.") + angles = self._correlation_to_theta(corrs) + labels = [self._format_correlation(corr) for corr in corrs] + self.xaxis.set_major_locator(mticker.FixedLocator(angles)) + self.xaxis.set_major_formatter(mticker.FixedFormatter(labels)) + + def _update_taylor_std_ticklabels(self): + """ + Duplicate radial tick labels onto the vertical standard-deviation axis. + """ + if not hasattr(self, "_taylor_yticklabel_artists"): + return + rmin, rmax = self.get_ylim() + if not np.isfinite(rmin) or not np.isfinite(rmax) or np.isclose(rmin, rmax): + return + + ticks = np.asarray(self.get_yticks(), dtype=float) + mask = (ticks >= min(rmin, rmax)) & (ticks <= max(rmin, rmax)) + mask &= ~np.isclose(ticks, rmin) + ticks = ticks[mask] + formatter = self.yaxis.get_major_formatter() + try: + labels = formatter.format_ticks(ticks) + except Exception: + labels = [formatter(tick, index) for index, tick in enumerate(ticks)] + + quadrant = self._taylor_quadrant + if quadrant in (1, 2): + theta = np.pi / 2 + else: + theta = -np.pi / 2 + ha = "right" if quadrant in (1, 3) else "left" + dx = -3 if ha == "right" else 3 + transform = self.transData + mtransforms.ScaledTranslation( + dx / 72, 0, self.figure.dpi_scale_trans + ) + + for index, (tick, label) in enumerate(zip(ticks, labels)): + if index >= len(self._taylor_yticklabel_artists): + artist = self.text( + theta, + tick, + "", + transform=transform, + ha=ha, + va="center", + clip_on=False, + zorder=3.5, + ) + self._taylor_yticklabel_artists.append(artist) + artist = self._taylor_yticklabel_artists[index] + artist.set_text(label) + artist.set_position((theta, tick)) + artist.set_transform(transform) + artist.set_horizontalalignment(ha) + artist.set_verticalalignment("center") + artist.set_visible(bool(label)) + for artist in self._taylor_yticklabel_artists[len(ticks) :]: + artist.set_visible(False) + + def draw(self, renderer=None, *args, **kwargs): + """ + Draw after refreshing Taylor-specific standard-deviation tick labels. + """ + self._update_taylor_std_ticklabels() + super().draw(renderer, *args, **kwargs) + + @docstring._snippet_manager + def format( + self, + *, + xlabel=None, + ylabel=None, + corrlabel=None, + thetaunit=None, + quadrant=None, + corrlocator=None, + corrlines=None, + corrticks=None, + xlabel_kw=None, + ylabel_kw=None, + corrlabel_kw=None, + labelpad=None, + labelcolor=None, + labelsize=None, + labelweight=None, + **kwargs, + ): + """ + Modify Taylor diagram labels, correlation gridlines, and polar settings. + + Parameters + ---------- + %(taylor.format)s + + Other parameters + ---------------- + %(polar.format)s + %(axes.format)s + %(figure.format)s + %(rc.format)s + + See also + -------- + ultraplot.axes.PolarAxes.format + ultraplot.axes.Axes.format + """ + rc_kw, rc_mode = _pop_rc(kwargs) + with rc.context(rc_kw, mode=rc_mode): + self._ensure_taylor_artists() + quadrant = self._parse_quadrant(quadrant) + if quadrant is not None: + self._taylor_quadrant = quadrant + self._apply_taylor_defaults() + if thetaunit is not None: + thetaunit = thetaunit.lower() + if thetaunit not in ("corr", "deg", "rad"): + raise ValueError( + "Invalid thetaunit={!r}. Expected 'corr', 'deg', or 'rad'.".format( + thetaunit + ) + ) + self._taylor_thetaunit = thetaunit + corrs = _not_none( + corrlocator=corrlocator, corrlines=corrlines, corrticks=corrticks + ) + self._update_taylor_ticks(corrs) + self._update_taylor_labels( + xlabel=xlabel, + ylabel=ylabel, + corrlabel=corrlabel, + labelpad=labelpad, + labelcolor=labelcolor, + labelsize=labelsize, + labelweight=labelweight, + xlabel_kw=xlabel_kw, + ylabel_kw=ylabel_kw, + corrlabel_kw=corrlabel_kw, + ) + + super().format( + rc_kw=rc_kw, + rc_mode=rc_mode, + labelpad=labelpad, + labelcolor=labelcolor, + labelsize=labelsize, + labelweight=labelweight, + **kwargs, + ) + self.xaxis.label.set_visible(False) + self.yaxis.label.set_visible(False) + self._update_taylor_label_positions() + self._update_taylor_std_ticklabels() + + +TaylorAxes._format_signatures[TaylorAxes] = inspect.signature(TaylorAxes.format) +TaylorAxes.format = docstring._obfuscate_kwargs(TaylorAxes.format) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 028ed3d13..5493755ab 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1544,6 +1544,16 @@ def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys): subplot_types = set() unsupported_found = False sides = ("top", "bottom") if axis == "x" else ("left", "right") + main_axes = [axi for axi in group_axes if not getattr(axi, "_panel_side", None)] + if len(main_axes) < 2: + supported = all( + isinstance( + axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) + ) + for axi in main_axes + ) + if not supported: + return {}, True for axi in group_axes: # Only main axes "vote" diff --git a/ultraplot/tests/test_projections.py b/ultraplot/tests/test_projections.py index b32d1e289..1c5a5c54a 100644 --- a/ultraplot/tests/test_projections.py +++ b/ultraplot/tests/test_projections.py @@ -7,6 +7,7 @@ import cartopy.crs as ccrs import matplotlib.pyplot as plt +import matplotlib.ticker as mticker import numpy as np import pytest @@ -157,6 +158,184 @@ def test_polar_projections(): return fig +def test_taylor_projection_labels_and_defaults(): + fig, axs = uplt.subplots(proj="taylor") + ax = axs[0] + + assert ax._name == "taylor" + ax.format(xlabel="STD X", ylabel="STD Y") + fig.canvas.draw() + + assert ax.get_xlabel() == "STD X" + assert ax.get_ylabel() == "STD Y" + assert not ax.xaxis.label.get_visible() + assert not ax.yaxis.label.get_visible() + assert ax._taylor_xlabel_artist.get_text() == "STD X" + assert ax._taylor_ylabel_artist.get_text() == "STD Y" + assert ax._taylor_corrlabel_artist.get_text() == "Correlation" + assert np.allclose(np.rad2deg(ax.get_xlim()), (0.0, 90.0)) + assert ax.get_rlabel_position() == pytest.approx(135.0) + assert [label.get_text() for label in ax.get_xticklabels()] == [ + "1.00", + "0.95", + "0.90", + "0.80", + "0.60", + "0.40", + "0.20", + "0.00", + ] + + +def test_taylor_projection_thetaunit_deg(): + fig, axs = uplt.subplots(proj="taylor") + ax = axs[0] + ax.format(thetaunit="deg") + fig.canvas.draw() + + labels = [label.get_text() for label in ax.get_xticklabels() if label.get_text()] + assert labels + assert any("°" in label for label in labels) + + +def test_taylor_projection_quadrants_and_corr_helpers(): + fig, axs = uplt.subplots(ncols=3, proj="taylor") + for ax, quadrant, expected_xlim, expected_rlabel, expected_theta in zip( + axs, + ("upper-left", 3, "upside down"), + ((90.0, 180.0), (180.0, 270.0), (0.0, -90.0)), + (45.0, 315.0, 225.0), + (np.pi / 2 + np.pi / 3, np.pi + np.pi / 3, -np.pi / 3), + ): + ax.format( + quadrant=quadrant, + corrlabel="rho", + thetaunit="rad", + corrticks=[1.0, 0.5, 0.0], + labelpad=8, + labelcolor="red", + labelsize=9, + labelweight="bold", + xlabel_kw={"color": "blue"}, + ylabel_kw={"color": "green"}, + corrlabel_kw={"color": "purple"}, + ) + line = ax.plot_corr([0.5], [1.2], marker="o")[0] + points = ax.scatter_corr([0.5], [1.2]) + + assert np.allclose(np.rad2deg(ax.get_xlim()), expected_xlim) + assert ax.get_rlabel_position() == pytest.approx(expected_rlabel) + assert line.get_xdata()[0] == pytest.approx(expected_theta) + assert points.get_offsets()[0, 0] == pytest.approx(expected_theta) + assert ax._taylor_corrlabel_artist.get_text() == "rho" + assert ax._taylor_xlabel_artist.get_color() == "blue" + assert ax._taylor_ylabel_artist.get_color() == "green" + assert ax._taylor_corrlabel_artist.get_color() == "purple" + assert [label.get_text() for label in ax.get_xticklabels()] == [ + "0", + "1.0472", + "1.5708", + ] + + +def test_taylor_projection_setters_and_scalar_corrticks(): + fig, axs = uplt.subplots(proj="taylor") + ax = axs[0] + + returned_xlabel = ax.set_xlabel( + "direct x", fontdict={"size": 11}, labelpad=6, loc="right", color="red" + ) + returned_ylabel = ax.set_ylabel( + "direct y", fontdict={"size": 12}, labelpad=7, loc="top", color="blue" + ) + ax.format(corrlocator=0.5) + fig.canvas.draw() + + assert returned_xlabel is ax.xaxis.label + assert returned_ylabel is ax.yaxis.label + assert ax._taylor_xlabel_artist.get_text() == "direct x" + assert ax._taylor_ylabel_artist.get_text() == "direct y" + assert ax._taylor_xlabel_artist.get_color() == "red" + assert ax._taylor_ylabel_artist.get_color() == "blue" + assert ax._taylor_labelpad == 7 + assert np.allclose(ax._taylor_corrs, [1.0, 0.5, 0.0]) + + +def test_taylor_projection_std_ticklabels_update_and_hide(): + fig, axs = uplt.subplots(proj="taylor") + ax = axs[0] + ax.format(quadrant=4, rlim=(0, 2), rlines=[0, 1, 2]) + fig.canvas.draw() + + artists = ax._taylor_yticklabel_artists + visible_artists = [artist for artist in artists if artist.get_visible()] + assert visible_artists + assert all( + artist.get_position()[0] == pytest.approx(-np.pi / 2) + for artist in visible_artists + ) + assert all(artist.get_horizontalalignment() == "left" for artist in visible_artists) + + ax.set_yticks([0, 1]) + ax._update_taylor_std_ticklabels() + assert artists[0].get_text() == "1" + assert all(not artist.get_visible() for artist in artists[1:]) + + +def test_taylor_projection_std_ticklabels_formatter_fallback(): + class RaisingFormatter(mticker.Formatter): + def format_ticks(self, values): + raise RuntimeError("force scalar formatting") + + def __call__(self, value, pos=None): + return f"tick-{pos}:{value:g}" + + fig, axs = uplt.subplots(proj="taylor") + ax = axs[0] + ax.format(rlim=(0, 2), rlines=[0, 1]) + ax.yaxis.set_major_formatter(RaisingFormatter()) + ax._update_taylor_std_ticklabels() + + assert ax._taylor_yticklabel_artists[0].get_text() == "tick-0:1" + + +def test_taylor_projection_validation_errors(): + fig, axs = uplt.subplots(proj="taylor") + ax = axs[0] + + assert ax._parse_quadrant(None) is None + assert np.allclose(ax.correlation_to_angle([-2, 0, 2]), [np.pi, np.pi / 2, 0]) + with pytest.raises(ValueError, match="Invalid Taylor quadrant"): + ax.format(quadrant="sideways") + with pytest.raises(ValueError, match="Invalid thetaunit"): + ax.format(thetaunit="turns") + with pytest.raises(ValueError, match="tick step must be positive"): + ax.format(corrlines=0) + with pytest.raises(ValueError, match="between -1 and 1"): + ax.format(corrticks=[1.2]) + ax._taylor_thetaunit = "turns" + with pytest.raises(ValueError, match="Invalid thetaunit"): + ax._format_correlation(0.5) + + +def test_taylor_single_axes_skips_shared_ticklabel_baseline(): + fig, axs = uplt.subplots(proj="taylor") + baseline, skip = fig._compute_baseline_tick_state( + [axs[0]], "x", ("labelbottom", "labeltop") + ) + + assert baseline == {} + assert skip + + +def test_taylor_projection_via_figure_format_dispatch(): + fig, axs = uplt.subplots(ncols=2, proj="taylor") + axs.format(xlabel="Common X", ylabel="Common Y") + for ax in axs: + assert ax.get_xlabel() == "Common X" + assert ax.get_ylabel() == "Common Y" + + def test_polar_format_thetalabel_rlabel(): """ `thetalabel` and `rlabel` both create CurvedText artists.