Source code for fips.visualization

"""
Visualization and plotting utilities for inverse problem results.

This module provides functions for plotting error norms, comparing multiple series,
and computing credible intervals for uncertainty visualization.
"""

from collections.abc import Iterable, Sequence

import numpy as np
import pandas as pd

from fips.covariance import CovarianceMatrix
from fips.matrix import Matrix
from fips.vector import Block, Vector

# Alias used in this module's annotations for raw numeric input: any sequence
# or a NumPy array.
ArrayLike = Sequence | np.ndarray


def _require_matplotlib():
    """
    Import ``matplotlib.pyplot`` on demand and return the module.

    The import is done inside the function so that matplotlib is only needed
    when a plotting helper is actually called.

    Returns
    -------
    module
        The ``matplotlib.pyplot`` module.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported; the original import error is chained
        as the cause.
    """
    try:
        import matplotlib.pyplot as pyplot  # type: ignore
    except ImportError as exc:
        message = (
            "matplotlib is required for plotting; install with `pip install matplotlib`."
        )
        raise ImportError(message) from exc
    else:
        return pyplot


def plot_error_norm(
    prior: ArrayLike | Vector | Block,
    posterior: ArrayLike | Vector | Block,
    truth: ArrayLike | Vector | Block,
    t: Iterable[float] | None = None,
    norm: str = "l2",
    figsize: tuple[float, float] | None = None,
):
    """
    Plot the per-row error norm of the prior and posterior relative to truth.

    Each input is converted to an array (through its ``values`` attribute when
    it has one). One-dimensional input is turned into a single column, so every
    element becomes its own row. The chosen norm is then taken across axis 1 of
    ``estimate - truth``.

    Parameters
    ----------
    prior : ArrayLike or Vector or Block
        Prior estimates.
    posterior : ArrayLike or Vector or Block
        Posterior estimates.
    truth : ArrayLike or Vector or Block
        Reference values the estimates are compared against.
    t : Iterable[float], optional
        Values for the x-axis. When None, ``0, 1, ..., n_rows - 1`` is used.
    norm : {'l2', 'l1', 'linf'}, default 'l2'
        Norm applied to each row of the error.
    figsize : tuple[float, float], optional
        Figure size (width, height) in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object.
    ax : matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    ValueError
        If ``norm`` is not one of the supported names.
    """
    plt = _require_matplotlib()

    def _as_rows(obj):
        # Unwrap containers exposing ``values``; promote 1-D data to a column.
        raw = np.asarray(getattr(obj, "values", obj))
        if raw.ndim == 1:
            return raw[:, np.newaxis]
        return raw

    prior_rows = _as_rows(prior)
    posterior_rows = _as_rows(posterior)
    truth_rows = _as_rows(truth)

    # One row-wise reducer per supported norm name.
    reducers = {
        "l2": lambda diff: np.linalg.norm(diff, axis=1),
        "l1": lambda diff: np.sum(np.abs(diff), axis=1),
        "linf": lambda diff: np.max(np.abs(diff), axis=1),
    }
    if norm not in reducers:
        raise ValueError("norm must be one of {'l2', 'l1', 'linf'}")
    reduce_rows = reducers[norm]

    prior_err = reduce_rows(prior_rows - truth_rows)
    post_err = reduce_rows(posterior_rows - truth_rows)

    if t is None:
        x = np.arange(len(prior_err))
    else:
        x = np.asarray(list(t))

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(
        x,
        prior_err,
        label=f"prior |.|{norm}",
        linestyle="--",
        color="tab:blue",
    )
    ax.plot(
        x,
        post_err,
        label=f"posterior |.|{norm}",
        linestyle="-",
        color="tab:orange",
    )
    ax.set(xlabel="step", ylabel=f"{norm} error")
    ax.legend(loc="best")
    fig.tight_layout()
    return fig, ax
def _error_to_std(err, name):
    """Turn one entry of ``errors`` into standard deviations for column ``name``."""
    if err is None:
        return None
    if isinstance(err, CovarianceMatrix):
        return pd.Series(np.sqrt(err.variances.values), index=err.index, name=name)
    if isinstance(err, Matrix):
        return pd.Series(np.sqrt(np.diag(err.values)), index=err.index, name=name)
    std = getattr(err, "data", err)
    if isinstance(std, pd.DataFrame):
        # A frame of errors must carry a column labelled like the series.
        return std[name]
    # Returned untouched: the caller's object is never renamed in place.
    return std


def plot_comparison(
    *series: pd.Series | Vector | Block,
    x: str | int | None = None,
    truth: pd.Series | Vector | Block | None = None,
    errors: Sequence[pd.DataFrame | CovarianceMatrix | Matrix | pd.Series | None]
    | None = None,
    kind: str | None = None,
):
    """
    Compare multiple aligned series with optional errors and truth values.

    Plots multiple series (e.g., prior, posterior, observations) on the same
    axes with optional error bars/bands and truth values for comparison.

    Parameters
    ----------
    *series : pd.Series or Vector or Block
        Variable number of series to compare. Objects exposing a ``data``
        attribute are unwrapped through it.
    x : str or int, optional
        Name or level of the index to use for x-axis. Required if series have
        MultiIndex. When given, it is also used as the x-axis label
        (including level ``0``).
    truth : pd.Series or Vector or Block, optional
        True values to plot as reference.
    errors : Sequence of pd.DataFrame or CovarianceMatrix or Matrix or pd.Series or None, optional
        Error estimates, matched to ``series`` by position. An entry of None,
        or a sequence shorter than ``series``, leaves the corresponding series
        without an error band/bar. If CovarianceMatrix or Matrix, the square
        root of the diagonal is used as standard deviations.
    kind : {'line', 'bar'}, optional
        Plot type. If None, 'line' is chosen for numeric or datetime x values
        and 'bar' otherwise.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object.
    ax : matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    ValueError
        If no series is given, or if the series have a MultiIndex and ``x`` is
        not specified.
    """
    plt = _require_matplotlib()

    if not series:
        raise ValueError("at least one series is required")

    # Unpack data
    df = pd.concat([getattr(s, "data", s) for s in series], axis=1)

    # Standard deviations are kept per column position instead of being
    # concatenated: pd.concat drops None entries, which would leave series
    # without an error estimate missing from the frame and break the lookup.
    stds = [None] * df.shape[1]
    if errors:
        for i, (col, err) in enumerate(zip(df.columns, errors, strict=False)):
            stds[i] = _error_to_std(err, col)

    # Determine X-axis values
    if isinstance(df.index, pd.MultiIndex):
        if x is None:
            raise ValueError("x must be specified when series have a MultiIndex")
        x_vals = df.index.get_level_values(x)
    else:
        x_vals = df.index

    # Compare against None so that an explicit level 0 is kept as the label.
    x_label = x if x is not None else (df.index.name or "index")

    # Auto-detect plot kind
    if kind is None:
        is_continuous = pd.api.types.is_numeric_dtype(
            x_vals
        ) or pd.api.types.is_datetime64_any_dtype(x_vals)
        kind = "line" if is_continuous else "bar"

    # Plot
    fig, ax = plt.subplots()
    colors = plt.cm.tab10(np.linspace(0, 1, len(series)))
    x_pos = np.arange(len(x_vals))  # Used for bar chart offsets
    group_center = x_pos + (len(series) / 2 - 0.5) * 0.2

    for i, col in enumerate(df.columns):
        # Positional access keeps one curve per input even if labels repeat.
        y = df.iloc[:, i]
        e = stds[i]

        if kind == "line":
            ax.plot(x_vals, y, label=str(col), color=colors[i])
            if e is not None:
                ax.fill_between(x_vals, y - e, y + e, color=colors[i], alpha=0.3)
        else:
            ax.bar(
                x_pos + i * 0.2,
                y,
                width=0.2,
                label=str(col),
                color=colors[i],
                yerr=e,
                capsize=5,
            )

    if truth is not None:
        t_data = getattr(truth, "data", truth)
        if kind == "line":
            ax.plot(x_vals, t_data, label="truth", color="black", linestyle="--")
        else:
            ax.scatter(
                group_center,
                t_data,
                label="truth",
                color="black",
                marker="x",
                s=100,
                zorder=5,
            )

    if kind == "bar":
        ax.set_xticks(group_center)
        ax.set_xticklabels(x_vals)

    ax.legend()
    ax.set(xlabel=str(x_label), ylabel="value")
    return fig, ax
def compute_credible_interval(
    samples: ArrayLike, q: tuple[float, float] = (0.05, 0.95)
) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the lower and upper quantiles of ``samples`` taken over axis 0.

    Parameters
    ----------
    samples : ArrayLike
        Samples with shape (n_samples, ...). The first axis is reduced.
    q : tuple[float, float], default (0.05, 0.95)
        Lower and upper quantile levels, each in [0, 1].

    Returns
    -------
    lower : np.ndarray
        Quantile of the samples at level ``q[0]``.
    upper : np.ndarray
        Quantile of the samples at level ``q[1]``.
    """
    draws = np.asarray(samples)
    lower_level = q[0]
    upper_level = q[1]
    lower = np.quantile(draws, lower_level, axis=0)
    upper = np.quantile(draws, upper_level, axis=0)
    return lower, upper