diff --git a/CHANGELOG.md b/CHANGELOG.md index b28a550b0c60a5033cd19dc16151a1c963500ede..fd84ea91728cb20e983fe5daaec47a97e35ac5d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ SPDX-License-Identifier: GPL-3.0-or-later [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.2.1...develop) ### Added - add option to not overwrite existing flags to `concatFlags` +- add option to pass existing axis object to `plot` ### Changed ### Removed ### Fixed diff --git a/saqc/funcs/tools.py b/saqc/funcs/tools.py index 0967a823a62b5266f6a35dfd89aaf698a2cf0c5f..87ad8ec3e7ba078f9678fe17a639df9420e2a736 100644 --- a/saqc/funcs/tools.py +++ b/saqc/funcs/tools.py @@ -234,6 +234,7 @@ class ToolsMixin: xscope: Optional[slice] = None, phaseplot: Optional[str] = None, store_kwargs: Optional[dict] = None, + ax: mpl.axes.Axes | None = None, ax_kwargs: Optional[dict] = None, dfilter: float = FILTER_NONE, **kwargs, @@ -297,7 +298,6 @@ class ToolsMixin: """ data, flags = self._data.copy(), self._flags.copy() - interactive = path is None level = kwargs.get("flag", UNFLAGGED) if dfilter < np.inf: @@ -309,9 +309,8 @@ class ToolsMixin: if ax_kwargs is None: ax_kwargs = {} - if interactive: + if not path: mpl.use(_MPL_DEFAULT_BACKEND) - else: mpl.use("Agg") @@ -324,13 +323,14 @@ class ToolsMixin: history=history, xscope=xscope, phaseplot=phaseplot, + ax=ax, ax_kwargs=ax_kwargs, ) - if interactive: + if ax is None: plt.show() - else: + if path: if store_kwargs.pop("pickle", False): with open(path, "wb") as f: pickle.dump(fig, f) diff --git a/saqc/lib/plotting.py b/saqc/lib/plotting.py index edd2df9c2d6698ea3788ce966d7ddd2e16425444..d3f20acb791cac6a7c78d87854a42066c71af37b 100644 --- a/saqc/lib/plotting.py +++ b/saqc/lib/plotting.py @@ -6,6 +6,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import itertools from typing import Optional, Union @@ -58,11 +60,12 @@ def makeFig( field: str, flags: Flags, level: float, - max_gap: Optional[str] = None, - history: Union[Optional[Literal["valid", "complete"]], list] = "valid", - xscope: Optional[slice] = None, - phaseplot: Optional[str] = None, - ax_kwargs: Optional[dict] = None, + max_gap: str | None = None, + history: Literal["valid", "complete"] | None | list[str] = "valid", + xscope: slice | None = None, + phaseplot: str | None = None, + ax: mpl.axes.Axes | None = None, + ax_kwargs: dict | None = None, ): """ Returns a figure object, containing data graph with flag marks for field. @@ -152,9 +155,10 @@ def makeFig( d = _insertBlockingNaNs(d, max_gap) # figure composition - fig = mpl.pyplot.figure(constrained_layout=True, **FIG_KWARGS) - grid = fig.add_gridspec() - ax = fig.add_subplot(grid[0]) + if ax is None: + fig = mpl.pyplot.figure(constrained_layout=True, **FIG_KWARGS) + grid = fig.add_gridspec() + ax = fig.add_subplot(grid[0]) _plotVarWithFlags( ax, @@ -172,7 +176,7 @@ def makeFig( ) plt.rcParams["font.size"] = default - return fig + return ax.figure def _plotVarWithFlags(