# Source code for skultrafast.dataset

import functools
import typing
import warnings
from collections import namedtuple
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Type, Union, cast

import attr
import lmfit
import matplotlib.pyplot as plt
import numpy as np
from lmfit.minimizer import MinimizerResult
from matplotlib.lines import Line2D
from scipy.interpolate import UnivariateSpline, interp1d

import skultrafast.dv as dv
import skultrafast.plot_helpers as ph
from skultrafast import filter, fitter, lifetimemap, zero_finding
from skultrafast.data_io import save_txt
from skultrafast.kinetic_model import Model
from skultrafast.utils import linreg_std_errors, sigma_clip

ndarray: Type[np.ndarray] = np.ndarray

EstDispResult = namedtuple("EstDispResult", "correct_ds tn polynomial")
EstDispResult.__doc__ = """
Tuple containing the results from a dispersion estimation.

Attributes
----------
correct_ds : TimeResSpec
    A dataset where linear interpolation was used to remove the dispersion.
tn : array
    Array containing the results of the applied heuristic.
polynomial : function
    Function which maps wavenumbers to time-zeros.
"""

# FitExpResult = namedtuple("FitExpResult", "lmfit_mini lmfit_res fitter")


@attr.s(auto_attribs=True)
class FitExpResult:
    lmfit_mini: lmfit.Minimizer
    lmfit_res: MinimizerResult
    fitter: fitter.Fitter
    pol_resolved: bool = False
    std_errs: Optional[np.ndarray] = None
    var: Optional[np.ndarray] = None
    r2: Optional[np.ndarray] = None

    def calculate_stats(self):
        f = self.fitter
        std_errs, variances, r2s = linreg_std_errors(f.x_vec, f.data)
        self.std_errs = std_errs
        # The original assigned `self.vars` and `self.r2s`, which do not
        # exist; the declared attributes are `var` and `r2`.
        self.var = variances
        self.r2 = r2s

    def make_sas(self,
                 model: Model,
                 QYs: Dict[str, float] = {},
                 y0: Optional[np.ndarray] = None):
        """
        Generate the species associated spectra from a given model, using
        the current DAS.

        Parameters
        ----------
        model : Model
            Model describing the kinetics. The number of transition rates
            should be identical to the number of DAS-rates. Currently, the
            function assumes that the transitions are added in a sorted
            way, e.g. fastest rates first.
        QYs : dict
            Values for the yields.
        y0 : ndarray
            Starting concentrations. If None, y0 = [1, 0, 0, ...].

        Returns
        -------
        (ndarray, ndarray)
            The species associated spectra and the corresponding
            concentration profiles.
        """
        f = self.fitter
        taus = f.last_para[-f.num_exponentials:]
        idx = np.argsort(taus)
        taus = taus[idx]
        kvals = 1 / taus
        if y0 is None:
            y0 = np.zeros(len(taus))
            y0[0] = 1
        func = model.build_mat_func()
        K = func(*kvals, **QYs)
        vals, vecs = np.linalg.eig(K)
        if np.any(np.subtract.outer(vals, vals) > 1e10):
            raise ValueError("Multivalued eigenvalue")
        vecs = vecs[:, ::-1]
        A = (vecs @ np.diag(np.linalg.solve(vecs, y0))).T
        sas = np.linalg.solve(A, f.c[:, idx].T)
        ct = f.x_vec[:, :f.num_exponentials] @ A
        return sas, ct
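
    # Usage sketch (hedged): `make_sas` is normally reached through a
    # finished fit. The kinetic `model` below is assumed to be a
    # skultrafast.kinetic_model.Model with as many rates as fitted decays;
    # its construction is omitted since it depends on the Model API.
    #
    #   result = ds.fit_exp([0.08, 1.0, 20.0])  # returns a FitExpResult
    #   sas, ct = result.make_sas(model, QYs={})
    #   # `sas` holds the species associated spectra, `ct` the
    #   # corresponding concentration profiles.
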
@attr.s(eq=False)
class LDMResult:
    skmodel: object = attr.ib()
    coefs: np.ndarray = attr.ib()
    fit: np.ndarray = attr.ib()
    alpha: np.ndarray = attr.ib()


class TimeResSpec:
    def __init__(
        self,
        wl,
        t,
        data,
        err=None,
        name=None,
        freq_unit="nm",
        disp_freq_unit=None,
        auto_plot=True,
    ):
        """
        Class for working with time-resolved spectra. It offers methods for
        analyzing and pre-processing the data. To visualize the data, each
        `TimeResSpec` object has an instance of a `TimeResSpecPlotter`
        object accessible under `plot`.

        Parameters
        ----------
        wl : array of shape(n)
            Array of the spectral dimension
        t : array of shape(m)
            Array with the delay times.
        data : array of shape(m, n)
            Array with the data for each point.
        err : array of shape(m, n) or None (optional)
            Contains the std err of the data, can be `None`.
        name : str (optional)
            Identifier for data set.
        freq_unit : 'nm' or 'cm' (optional)
            Unit of the wavelength array, default is 'nm'.
        disp_freq_unit : 'nm', 'cm' or None (optional)
            Unit which is used by default for plotting, masking and cutting
            the dataset. If `None`, it defaults to `freq_unit`.

        Attributes
        ----------
        wavelengths, wavenumbers, t, data : ndarray
            Arrays with the data itself.
        plot : TimeResSpecPlotter
            Helper class which can plot the dataset using `matplotlib`.
        t_idx : function
            Helper function to find the nearest index in t for a given time.
        wl_idx : function
            Helper function to search for the nearest wavelength index for
            a given wavelength.
        wn_idx : function
            Helper function to search for the nearest wavenumber index for
            a given wavenumber.
        auto_plot : bool
            When True, some functions will display their results
            automatically.
        """
        correct_shape = (t.shape[0], wl.shape[0]) == data.shape
        assert correct_shape, f"Data shapes do not match: {t.shape}, {wl.shape} != {data.shape}"
        t = t.copy()
        wl = wl.copy()
        data = data.copy()

        if freq_unit == "nm":
            self._wavelengths = wl
            self._wavenumbers = 1e7 / wl
        else:
            self._wavelengths = 1e7 / wl
            self._wavenumbers = wl
        self.wn = self.wavenumbers
        self.wl = self.wavelengths
        self.t = t
        self.data = data
        self.err = err
        if name is not None:
            self.name = name

        # Sort wavelengths and data.
        idx = np.argsort(self._wavelengths)
        self._wavelengths = self._wavelengths[idx]
        self._wavenumbers = self._wavenumbers[idx]
        self.data = self.data[:, idx]
        if err is not None:
            self.err = self.err[:, idx]

        self.auto_plot = auto_plot
        self.plot = TimeResSpecPlotter(self)
        self.t_idx = lambda x: dv.fi(self.t, x)
        self.wl_idx = lambda x: dv.fi(self.wavelengths, x)
        self.wn_idx = lambda x: dv.fi(self.wavenumbers, x)

        if disp_freq_unit is None:
            self.disp_freq_unit = freq_unit
        else:
            self.disp_freq_unit = disp_freq_unit
        self.plot.freq_unit = self.disp_freq_unit

        self.trans = self.plot.trans
        self.spec = self.plot.spec
        self.map = self.plot.map

    @property
    def wavelengths(self):
        return self._wavelengths

    @wavelengths.setter
    def wavelengths(self, wavelengths):
        self._wavelengths = wavelengths
        self._wavenumbers = 1e7 / wavelengths

    @property
    def wavenumbers(self):
        return self._wavenumbers

    @wavenumbers.setter
    def wavenumbers(self, wavenumbers):
        self._wavelengths = 1e7 / wavenumbers
        self._wavenumbers = wavenumbers
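
    # Construction sketch with synthetic arrays (hypothetical values); note
    # the shape convention checked by the assert above,
    # data.shape == (len(t), len(wl)):
    #
    #   wl = np.linspace(400, 700, 64)             # nm
    #   t = np.linspace(-1, 100, 200)              # ps
    #   data = np.zeros((t.size, wl.size))
    #   ds = TimeResSpec(wl, t, data, freq_unit="nm")
    #   ds.plot.map()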

    def __iter__(self):
        """For compatibility with dv.tup"""
        return iter((self.wavelengths, self.t, self.data))

    def wl_d(self, wl: float):
        """
        Returns the nearest transient for given wavelength.
        """
        idx = self.wl_idx(wl)
        return self.data[:, idx]

    def wn_i(self, wn1, wn2, method='trapz'):
        """
        Integrates the signal from wn1 to wn2.

        Parameters
        ----------
        wn1 : float
            Wavenumber of the first edge
        wn2 : float
            Wavenumber of the second edge
        method : {'trapz', 'spline'}
            Method used to integrate.
        """
        wn_min, wn_max = sorted([wn1, wn2])
        idx_min, idx_max = self.wn_idx(wn_min), self.wn_idx(wn_max)
        if wn_max > self.wavenumbers.max():
            idx_max = None
        if wn_min < self.wavenumbers.min():
            idx_min = None
        sl = slice(idx_min, idx_max)
        x = self.wavenumbers[sl]
        y = self.data[:, sl]
        if method == 'trapz':
            return np.trapz(x=x, y=y, axis=1)
        elif method == 'spline':
            sp = UnivariateSpline(x, y)
            return sp.antiderivative(1)(x)
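
    # Usage sketch (hypothetical band edges in cm-1): integrate a band and
    # look at its kinetics.
    #
    #   band = ds.wn_i(1600, 1650)   # one integrated value per delay time
    #   plt.plot(ds.t, band)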

    def wn_d(self, wn: float):
        """
        Returns the nearest transient for given wavenumber.
        """
        idx = self.wn_idx(wn)
        return self.data[:, idx]

    def t_d(self, t):
        """
        Returns the nearest spectrum for given delay time.
        """
        idx = self.t_idx(t)
        return self.data[idx, :]

    def copy(self) -> "TimeResSpec":
        """Returns a copy of the TimeResSpec."""
        return TimeResSpec(
            self.wavelengths,
            self.t,
            self.data,
            disp_freq_unit=self.disp_freq_unit,
            err=self.err,
            auto_plot=self.auto_plot,
        )

    @classmethod
    def from_txt(cls,
                 fname,
                 freq_unit="nm",
                 time_div=1.0,
                 transpose=False,
                 disp_freq_unit=None,
                 loadtxt_kws=None):
        """
        Directly create a dataset from a text file.

        Parameters
        ----------
        fname : str
            Name of the file. This function assumes the data is given by a
            (n+1, m+1) table. Excluding the [0, 0] value, the first row
            gives the frequencies and the first column gives the
            delay-times.
        freq_unit : {'nm', 'cm'}
            Unit of the frequencies.
        time_div : float
            Since `skultrafast` prefers to work with picoseconds and
            programs may use different units, it divides the time-values by
            `time_div`. Use `1`, the default, to not change the time
            values.
        transpose : bool
            Transposes the loaded array.
        disp_freq_unit : Optional[str]
            See class documentation.
        loadtxt_kws : dict
            Dict containing keyword arguments to `np.loadtxt`.
        """
        if loadtxt_kws is None:
            loadtxt_kws = {}
        tmp = np.loadtxt(fname, **loadtxt_kws)
        if transpose:
            tmp = tmp.T
        t = tmp[1:, 0] / time_div
        freq = tmp[0, 1:]
        data = tmp[1:, 1:]
        return cls(freq, t, data, freq_unit=freq_unit,
                   disp_freq_unit=disp_freq_unit)
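
    # Usage sketch (hypothetical file name): load a table whose times are
    # stored in femtoseconds, converting them to picoseconds via time_div.
    #
    #   ds = TimeResSpec.from_txt("scan.txt", freq_unit="cm", time_div=1000)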

    def save_txt(self, fname, freq_unit="wl"):
        """
        Saves the dataset as a text file.

        Parameters
        ----------
        fname : str
            Filename (can include path)
        freq_unit : 'wl' or 'wn' (default 'wl')
            Whether the frequency axis is saved as wavelengths ('wl') or
            wavenumbers.
        """
        wl = self.wavelengths if freq_unit == "wl" else self.wavenumbers
        save_txt(fname, wl, self.t, self.data)
        if self.err is not None:
            save_txt(str(fname) + '.stderr', wl, self.t, self.err)

    def cut_freq(self, lower=-np.inf, upper=np.inf, invert_sel=False,
                 freq_unit=None) -> "TimeResSpec":
        """
        Removes channels inside (or outside) of given frequency ranges.

        Parameters
        ----------
        lower : float
            Lower bound of the region
        upper : float
            Upper bound of the region
        invert_sel : bool
            Invert the final selection.
        freq_unit : 'nm', 'cm' or None
            Unit of the given edges.

        Returns
        -------
        : TimeResSpec
            TimeResSpec with the given region removed (or, with
            `invert_sel`, containing only that region).
        """
        if freq_unit is None:
            freq_unit = self.disp_freq_unit
        arr = self.wavelengths if freq_unit == "nm" else self.wavenumbers
        idx = np.logical_and(lower <= arr, arr < upper)
        if not invert_sel:
            idx = ~idx
        if self.err is not None:
            err = self.err[:, idx]
        else:
            err = None
        return TimeResSpec(
            self.wavelengths[idx],
            self.t,
            self.data[:, idx],
            err,
            "nm",
            disp_freq_unit=self.disp_freq_unit,
        )

    def mask_freq_idx(self, idx):
        """Masks given freq idx array

        Parameters
        ----------
        idx : array
            Boolean array, same shape as the freqs. Where it is `True`, the
            freqs will be masked.
        """
        if self.err is not None:
            self.err = np.ma.MaskedArray(self.err)
            self.err[:, idx] = np.ma.masked
        self.data = np.ma.MaskedArray(self.data)
        self.data[:, idx] = np.ma.masked

    def mask_freqs(self, freq_ranges, invert_sel=False, freq_unit=None):
        """
        Mask channels inside of given frequency ranges.

        Parameters
        ----------
        freq_ranges : list of (float, float)
            List containing the edges (lower, upper) of the frequency
            regions to mask.
        invert_sel : bool
            When True, it inverts the selection. Can be used to mask
            everything outside the selected ranges.
        freq_unit : 'nm', 'cm' or None
            Unit of the given edges.

        Returns
        -------
        : None
        """
        idx = np.zeros_like(self.wavelengths, dtype=bool)
        if freq_unit is None:
            freq_unit = self.disp_freq_unit
        arr = self.wavelengths if freq_unit == "nm" else self.wavenumbers

        for (lower, upper) in freq_ranges:
            idx ^= np.logical_and(arr > lower, arr < upper)
        if invert_sel:
            idx = ~idx
        if self.err is not None:
            self.err = np.ma.MaskedArray(self.err)
            self.err[:, idx] = np.ma.masked
        self.data = np.ma.MaskedArray(self.data)
        self.data[:, idx] = np.ma.masked

    def cut_time(self, lower=-np.inf, upper=np.inf,
                 invert_sel=False) -> "TimeResSpec":
        """
        Remove spectra inside (or outside) of given time-ranges.

        Parameters
        ----------
        lower : float
            Lower bound of the region
        upper : float
            Upper bound of the region
        invert_sel : bool
            Inverts the final selection.

        Returns
        -------
        : TimeResSpec
            TimeResSpec containing only the requested regions.
        """
        idx = np.zeros_like(self.t, dtype=bool)
        arr = self.t
        idx ^= np.logical_and(arr > lower, arr < upper)
        if not invert_sel:
            idx = ~idx
        if self.err is not None:
            err = self.err[idx, :]
        else:
            err = None
        return TimeResSpec(
            self.wavelengths,
            self.t[idx],
            self.data[idx, :],
            err,
            "nm",
            disp_freq_unit=self.disp_freq_unit,
        )

    def scale_and_shift(self, scale: float = 1, t_shift: float = 0,
                        wl_shift: float = 0) -> "TimeResSpec":
        """
        Return a dataset which is scaled and/or has shifted times and
        frequencies.

        Parameters
        ----------
        scale : float
            Scales the whole dataset by given factor.
        t_shift : float
            Shifts the time-axis of a dataset.
        wl_shift : float
            Shifts the wavelengths axis and updates the wavenumbers too.

        Returns
        -------
        TimeResSpec
            A modified new dataset
        """
        cpy = self.copy()
        cpy.data *= scale
        if cpy.err is not None:
            cpy.err *= scale

        cpy.t += t_shift
        cpy.wavelengths += wl_shift
        cpy.wavenumbers = 1e7 / cpy.wavelengths
        return cpy

    def mask_times(self, time_ranges, invert_sel=False):
        """
        Mask spectra outside (or, with `invert_sel`, inside) of given
        time-ranges.

        Parameters
        ----------
        time_ranges : list of (float, float)
            List containing the edges of the time-regions to keep.
        invert_sel : bool
            Invert the selection.

        Returns
        -------
        : None
        """
        idx = np.zeros_like(self.t, dtype=bool)
        arr = self.t
        for (lower, upper) in time_ranges:
            idx ^= np.logical_and(arr > lower, arr < upper)
        if not invert_sel:
            idx = ~idx
        if self.err is not None:
            self.err[idx, :] = np.ma.masked
        # self.t = np.ma.MaskedArray(self.t, idx)
        # Note: the original wrote `self.data.mask[:, idx]`, indexing the
        # wavelength axis, but `idx` runs along the time axis.
        self.data[idx, :] = np.ma.masked

    def subtract_background(self, n: int = 10):
        """Subtracts the first n-spectra from the dataset"""
        self.data -= np.mean(self.data[:n, :], 0, keepdims=True)

    def bin_freqs(self, n: int, freq_unit=None,
                  use_err: bool = True) -> "TimeResSpec":
        """
        Bins down the dataset by averaging over several transients.

        Parameters
        ----------
        n : int
            The number of bins. The edges are calculated by
            np.linspace(freq.min(), freq.max(), n+1).
        freq_unit : 'nm', 'cm' or None
            Whether to calculate the bin-borders in frequency or wavelength
            space. If `None`, it defaults to `self.disp_freq_unit`.
        use_err : bool
            If true, use error for weighting.

        Returns
        -------
        TimeResSpec
            Binned down `TimeResSpec`
        """
        # We use the negative of the wavenumbers to make the array sorted
        if freq_unit is None:
            freq_unit = self.disp_freq_unit
        arr = self.wavelengths if freq_unit == "nm" else -self.wavenumbers
        # Slightly offset edges to include themselves.
        edges = np.linspace(arr.min() - 0.002, arr.max() + 0.002, n + 1)
        idx = np.searchsorted(arr, edges)
        binned = np.empty((self.data.shape[0], n))
        binned_wl = np.empty(n)
        binned_err = np.empty_like(binned)
        for i in range(n):
            if self.err is None or not use_err:
                weights = None
            else:
                weights = 1 / self.err[:, idx[i]:idx[i + 1]]**2
            vals = self.data[:, idx[i]:idx[i + 1]]
            binned[:, i] = np.average(vals, 1, weights=weights)
            if weights is not None:
                binned_err[:, i] = np.average(
                    (vals - binned[:, i, None])**2, 1, weights=weights)
            binned_wl[i] = np.mean(arr[idx[i]:idx[i + 1]])
        if freq_unit == "cm":
            binned_wl = -binned_wl
        # Without errors nothing meaningful was written into `binned_err`
        # (the original left it uninitialized), so drop it.
        if self.err is None or not use_err:
            binned_err = None
        return TimeResSpec(
            binned_wl,
            self.t,
            binned,
            err=binned_err,
            freq_unit=freq_unit,
            disp_freq_unit=self.disp_freq_unit,
        )

    def bin_times(self, n, start_index=0) -> "TimeResSpec":
        """
        Bins down the dataset by binning `n` sequential spectra together.

        Parameters
        ----------
        n : int
            How many spectra are binned together.
        start_index : int
            Determines the starting index of the binning

        Returns
        -------
        TimeResSpec
            Binned down `TimeResSpec`
        """
        out = []
        out_t = []
        m = len(self.t)
        for i in range(start_index, m, n):
            end_idx = min(i + n, m)
            out.append(
                sigma_clip(self.data[i:end_idx, :], sigma=2.5, max_iter=1,
                           axis=0).mean(0))
            out_t.append(self.t[i:end_idx].mean())

        new_data = np.array(out)
        new_t = np.array(out_t)
        out_ds = self.copy()
        out_ds.t = new_t
        out_ds.data = new_data
        return out_ds

    def estimate_dispersion(self,
                            heuristic="abs",
                            heuristic_args=(),
                            deg: int = 2,
                            shift_result: float = 0,
                            t_parameter: float = 1.3):
        """
        Estimates the dispersion from a dataset by first applying a
        heuristic to each channel. The results are then robustly fitted
        with a polynomial of given order.

        Parameters
        ----------
        heuristic : {'abs', 'diff', 'gauss_diff', 'max'} or func
            Determines which heuristic to use on each channel. Can also be
            a function which follows `func(t, y, *args)` and returns a
            `t0`-value. The heuristics are described in `zero_finding`.
        heuristic_args : tuple
            Arguments which are given to the heuristic.
        deg : int (optional)
            Degree of the polynomial used to fit the dispersion (defaults
            to 2).
        shift_result : float
            The resulting dispersion curve is shifted by this value.
            Default 0.
        t_parameter : float
            Determines the robustness of the fit. See statsmodels
            documentation for more info.

        Returns
        -------
        EstDispResult
            Tuple containing the dispersion corrected version of the
            dataset, an array with time-zeros from the heuristic, and the
            polynomial function resulting from the robust fit.
        """
        func_dict = {
            "abs": zero_finding.use_first_abs,
            "diff": zero_finding.use_diff,
            "gauss_diff": zero_finding.use_gaussian,
            "max": zero_finding.use_max,
        }
        if callable(heuristic):
            idx = heuristic(self.t, self.data, *heuristic_args)
        elif heuristic in func_dict:
            idx = func_dict[heuristic](self.data, *heuristic_args)
        else:
            raise ValueError("`heuristic` must be either a callable or"
                             " one of `max`, `abs`, `diff` or `gauss_diff`.")

        vals, coefs = zero_finding.robust_fit_tz(self.wavenumbers,
                                                 self.t[idx],
                                                 deg,
                                                 t=t_parameter)
        coefs[-1] += shift_result
        func = np.poly1d(coefs)
        result = EstDispResult(
            correct_ds=self.interpolate_disp(func),
            tn=self.t[idx] + shift_result,
            polynomial=func,
        )
        if self.auto_plot:
            self.plot.plot_disp_result(result)
        self.disp_result_ = result
        return result
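
    # Usage sketch: estimate the chirp with the default heuristic and keep
    # working with the corrected dataset.
    #
    #   disp = ds.estimate_dispersion(heuristic="abs", deg=2)
    #   ds_corr = disp.correct_ds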

    def interpolate_disp(self, polyfunc: Union[Callable,
                                               Iterable]) -> "TimeResSpec":
        """
        Correct for dispersion by linear interpolation.

        Parameters
        ----------
        polyfunc : Union[Callable, Iterable]
            Function which takes wavenumbers and returns time-zeros.

        Returns
        -------
        TimeResSpec
            New TimeResSpec where the data is interpolated so that all
            channels have the same delay points.
        """
        c = self.copy()
        if callable(polyfunc):
            zeros = polyfunc(self.wavenumbers)
        else:
            zeros = polyfunc
        ntc = zero_finding.interpol(self, zeros)
        c.data = ntc.data
        if self.err is not None:
            # The original built the tuple from `self.data`, which made the
            # interpolated error a copy of the data; interpolate the stored
            # errors instead.
            tmp_tup = dv.tup(self.wavelengths, self.t, self.err)
            c.err = zero_finding.interpol(tmp_tup, zeros).data
        return c

    def fit_exp(
        self,
        x0,
        fix_sigma=True,
        fix_t0=True,
        fix_last_decay=True,
        model_coh=False,
        lower_bound=0.1,
        verbose=True,
        use_error=False,
        fixed_names=None,
        from_t=None,
    ):
        """
        Fit a sum of exponentials to the dataset. This function assumes
        the dataset is already corrected for dispersion.

        Parameters
        ----------
        x0 : list of floats or array
            Starting values of the fit. The first value is the estimate of
            the system response time omega. If `fix_t0` is true, the second
            float is the guess of the time-zero. All other floats are
            interpreted as the guessing values for exponential decays.
        fix_sigma : bool (optional)
            Whether to fix the IRF duration sigma.
        fix_t0 : bool (optional)
            Whether to fix the time-zero.
        fix_last_decay : bool (optional)
            Fixes the value of the last tau of the initial guess. It can be
            used to add a constant by setting the last tau to a large value
            and fixing it.
        model_coh : bool (optional)
            If coherent contributions should be modeled. If `True`, a
            gaussian with a width equal to the system response time and its
            derivatives are added to the linear model.
        lower_bound : float (optional)
            Lower bound for decay-constants.
        verbose : bool
            Prints the results out if True.
        use_error : bool
            If the errors are used in the fit.
        fixed_names : list of str
            Can be used to fix time-constants.
        from_t : float or None
            Can be used to cut off early times.
        """
        if from_t is None:
            ds = self
        else:
            ds = self.cut_time(upper=from_t)
        f = fitter.Fitter(ds, model_coh=model_coh, model_disp=1)
        if use_error:
            assert self.err is not None
            f.weights = 1 / self.err
        f.res(x0)
        if fixed_names is None:
            fixed_names = list()
        if fix_sigma:
            fixed_names.append("w")

        lm_model = f.start_lmfit(
            x0,
            fix_long=fix_last_decay,
            fix_disp=fix_t0,
            lower_bound=lower_bound,
            full_model=False,
            fixed_names=fixed_names,
        )
        ridge_alpha = abs(self.data).max() * 1e-4
        f.lsq_method = "ridge"
        fitter.alpha = ridge_alpha
        result = lm_model.leastsq()
        result_tuple = FitExpResult(lm_model, result, f)
        result_tuple.calculate_stats()
        self.fit_exp_result_ = result_tuple
        if verbose:
            # fit_report only builds the string; it has to be printed.
            print(lmfit.fit_report(result))
        return result_tuple
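
    # Usage sketch (hypothetical starting values): the first entry of x0 is
    # the IRF width, the following entries are decay guesses; the large
    # final tau together with `fix_last_decay` acts as a constant offset.
    #
    #   res = ds.fit_exp([0.08, 0.5, 5.0, 10000], fix_last_decay=True)
    #   ds.plot.das()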

    def lifetime_density_map(self, taus=None, alpha=1e-4, cv=True,
                             maxiter=30000, **kwargs):
        """Calculates the LDM from a dataset by regularized regression.

        Parameters
        ----------
        taus : array or None
            List with potential decays for building the basis. If `None`,
            use automatic determination.
        alpha : float
            The regularization factor.
        cv : bool
            If to apply cross-validation, by default True.
        maxiter : int
            Maximal number of iterations passed to the solver.
        """
        if taus is None:
            dt = self.t[self.t_idx(0)] - self.t[self.t_idx(0) - 1]
            max_t = self.t.max()
            start = np.floor(np.log10(dt))
            end = np.ceil(np.log10(max_t))
            # The original passed the log10 exponents directly to geomspace;
            # the intended basis runs from 10**start to 10**end.
            taus = np.geomspace(10**start, 10**end, int(5 * (end-start)))
        result = lifetimemap.start_ltm(self,
                                       taus,
                                       use_cv=cv,
                                       add_const=False,
                                       alpha=alpha,
                                       add_coh=False,
                                       max_iter=maxiter,
                                       **kwargs)
        result = LDMResult(*result)
        return result
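
    # Usage sketch (explicit tau basis, hypothetical values):
    #
    #   ldm = ds.lifetime_density_map(taus=np.geomspace(0.1, 1000, 64),
    #                                 alpha=1e-3, cv=False)
    #   # ldm.coefs holds the regression coefficients, ldm.fit the
    #   # reconstructed dataset.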

    def concat_datasets(self, other_ds):
        """
        Merge the dataset with another dataset. The other dataset needs to
        have the same time axis.

        Parameters
        ----------
        other_ds : TimeResSpec
            The dataset to merge with

        Returns
        -------
        TimeResSpec
            The merged dataset.
        """
        all_wls = np.hstack((self.wavelengths, other_ds.wavelengths))
        all_data = np.hstack((self.data, other_ds.data))

        if not (self.err is None or other_ds.err is None):
            all_err = np.hstack((self.err, other_ds.err))
        else:
            all_err = None
        return TimeResSpec(
            all_wls,
            self.t,
            all_data,
            err=all_err,
            freq_unit="nm",
            disp_freq_unit=self.disp_freq_unit,
        )

    def merge_nearby_channels(self, distance: float = 8,
                              use_err: bool = False) -> "TimeResSpec":
        """Merges sequential channels together if their distance is smaller
        than given.

        Parameters
        ----------
        distance : float, optional
            The minimal distance allowed between two channels. If smaller,
            they will be merged together, by default 8.
        use_err : bool
            Whether to use the stored errors as weights when averaging.

        Returns
        -------
        TimeResSpec
            The merged dataset.
        """
        skiplist = []
        nwl = self.wavelengths.copy()
        nspec = self.data.copy()
        nerr = self.err.copy() if self.err is not None else None
        weights = 1 / self.err.copy() if self.err is not None else None
        for i in range(nwl.size - 1):
            if i in skiplist:
                continue
            if abs(nwl[i + 1] - nwl[i]) < distance:
                if self.err is not None and use_err:
                    w = weights[:, i:i + 2]**2
                else:
                    w = None
                mean = np.average(nspec[:, i:i + 2], 1, weights=w)
                err = np.sqrt(
                    np.average((nspec[:, i:i + 2] - mean[:, None])**2, 1,
                               weights=w))
                nspec[:, i] = mean
                if nerr is not None:
                    nerr[:, i] = err
                nwl[i] = np.mean(nwl[i:i + 2])
                skiplist.append(i + 1)

        nwl = np.delete(nwl, skiplist)
        nspec = np.delete(nspec, skiplist, axis=1)
        if nerr is not None:
            nerr = np.delete(nerr, skiplist, axis=1)
        new_ds = self.copy()
        if nerr is not None and use_err:
            new_ds.err = nerr
        else:
            new_ds.err = None
        new_ds.wavelengths = nwl
        new_ds.wavenumbers = 1e7 / nwl
        new_ds.data = nspec
        return new_ds

    def apply_filter(self, kind, args) -> 'TimeResSpec':
        """Apply a filter to the data. Will always return a copy of the
        data.

        Parameters
        ----------
        kind : callable or one of ('svd', 'uniform', 'gaussian')
            What kind of filter to use. Either a string indicating an
            inbuilt filter or a callable.
        args : any
            Argument to the filter. Depends on the kind.

        Returns
        -------
        TimeResSpec
            The filtered dataset.
        """
        filtered_ds = self.copy()
        if callable(kind):
            tup = kind(filtered_ds.data, *args)
        elif kind == 'svd':
            tup = filter.svd_filter(filtered_ds, args)
        elif kind == 'uniform':
            tup = filter.uniform_filter(filtered_ds, args)
        elif kind == "gaussian":
            tup = filter.gaussian_filter(filtered_ds, args)
        filtered_ds.data = tup.data
        return filtered_ds


class PolTRSpec:
    def __init__(self,
                 para: TimeResSpec,
                 perp: TimeResSpec,
                 iso: Optional[TimeResSpec] = None):
        """
        Class for working with polarization-resolved datasets. Assumes the
        same frequency and time axis for both polarizations.

        Parameters
        ----------
        para : TimeResSpec
            The dataset with parallel pump/probe pol.
        perp : TimeResSpec
            The TimeResSpec with perpendicular pump/probe
        iso : Optional[TimeResSpec]
            Iso dataset, if None it will be calculated from para and perp.

        Attributes
        ----------
        plot : PolTRSpecPlotter
            Helper class containing the plotting methods.
        """
        assert para.data.shape == perp.data.shape
        self.para = para
        self.perp = perp
        if iso is None:
            self.iso = para.copy()
            self.iso.data = (2 * perp.data + para.data) / 3
        else:
            self.iso = iso
        self.wavenumbers = para.wavenumbers
        self.wavelengths = para.wavelengths
        self.wn, self.wl = self.wavenumbers, self.wavelengths
        self.t = para.t
        self.disp_freq_unit = para.disp_freq_unit
        self.plot = PolTRSpecPlotter(self, self.disp_freq_unit)

        trs = TimeResSpec
        self._copy = delegator(self, trs.copy)
        self.bin_times = delegator(self, trs.bin_times)
        self.bin_freqs = delegator(self, trs.bin_freqs)
        self.cut_time = delegator(self, trs.cut_time)
        self.scale_and_shift = delegator(self, trs.scale_and_shift)
        self.cut_freq = delegator(self, trs.cut_freq)
        self.mask_freqs = delegator(self, trs.mask_freqs)
        self.mask_times = delegator(self, trs.mask_times)
        self.subtract_background = delegator(self, trs.subtract_background)
        self.merge_nearby_channels = delegator(self, trs.merge_nearby_channels)
        self.interpolate_disp = delegator(self, trs.interpolate_disp)
        self.apply_filter = delegator(self, trs.apply_filter)
        self.t_idx = para.t_idx
        self.wn_idx = para.wn_idx
        self.wl_idx = para.wl_idx
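
    # Construction sketch (assumes two dispersion-corrected TimeResSpec
    # objects with identical axes):
    #
    #   pol = PolTRSpec(ds_para, ds_perp)
    #   pol.iso.plot.map()          # magic-angle signal
    #   pol2 = pol.bin_freqs(64)    # delegated to TimeResSpec.bin_freqs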

    def copy(self) -> 'PolTRSpec':
        new_ds = cast(PolTRSpec, self._copy())
        new_ds.plot.para_ls = self.plot.para_ls
        new_ds.plot.perp_ls = self.plot.perp_ls
        return new_ds

    def wl_d(self, wl):
        idx = self.wl_idx(wl)
        # The original indexed the TimeResSpec objects directly; the
        # transients live in `.data`.
        return self.para.data[:, idx], self.perp.data[:, idx]

    def wn_d(self, wn):
        idx = self.wn_idx(wn)
        return self.para.data[:, idx], self.perp.data[:, idx]

    def t_d(self, t):
        idx = self.t_idx(t)
        return self.para.data[idx, :], self.perp.data[idx, :]

    def fit_exp(
        self,
        x0,
        fix_sigma=True,
        fix_t0=True,
        fix_last_decay=True,
        from_t=None,
        model_coh=False,
        lower_bound=0.1,
        use_error=False,
        fixed_names=None,
    ) -> FitExpResult:
        """
        Fit a sum of exponentials to the dataset. This function assumes
        that both datasets are already corrected for dispersion.

        Parameters
        ----------
        x0 : list of floats or array
            Starting values of the fit. The first value is the estimate of
            the system response time omega. If `fix_t0` is true, the second
            float is the guess of the time-zero. All other floats are
            interpreted as the guessing values for exponential decays.
        fix_sigma : bool (optional)
            Whether to fix the IRF duration sigma.
        fix_t0 : bool (optional)
            Whether to fix the time-zero.
        fix_last_decay : bool (optional)
            Fixes the value of the last tau of the initial guess. It can be
            used to add a constant by setting the last tau to a large value
            and fixing it.
        from_t : float or None
            If not None, data with t<from_t will be ignored for the fit.
        model_coh : bool (optional)
            If coherent contributions should be modeled. If `True`, a
            gaussian with a width equal to the system response time and its
            derivatives are added to the linear model.
        lower_bound : float (optional)
            Lower bound for decay-constants.
        use_error : bool
            Whether to use the error to weight the residuals.
        fixed_names : list of str
            Can be used to fix time-constants.
        """
        pa, pe = self.para, self.perp
        if from_t is not None:
            pa = pa.cut_time(-np.inf, from_t)
            pe = pe.cut_time(-np.inf, from_t)
        all_data = np.hstack((pa.data, pe.data))
        all_wls = np.hstack((pa.wavelengths, pe.wavelengths))
        all_tup = dv.tup(all_wls, pa.t, all_data)
        f = fitter.Fitter(all_tup, model_coh=model_coh, model_disp=1)
        if use_error:
            all_err = np.hstack((pa.err, pe.err))
            f.weights = 1 / all_err
        f.res(x0)
        if fixed_names is None:
            fixed_names = []
        if fix_sigma:
            fixed_names.append("w")

        lm_model = f.start_lmfit(
            x0,
            fix_long=fix_last_decay,
            fix_disp=fix_t0,
            lower_bound=lower_bound,
            full_model=False,
            fixed_names=fixed_names,
        )
        ridge_alpha = abs(all_data).max() * 1e-4
        f.lsq_method = "ridge"
        fitter.alpha = ridge_alpha
        result = lm_model.leastsq()
        self.fit_exp_result_ = FitExpResult(lm_model, result, f)
        self.fit_exp_result_.calculate_stats()
        return self.fit_exp_result_

    def save_txt(self, fname, freq_unit="wl"):
        """
        Saves the dataset as a text file.

        Parameters
        ----------
        fname : str
            Filename (can include path). This function appends '.para.txt',
            '.perp.txt' and '.iso.txt' for the corresponding datasets.
        freq_unit : 'wl' or 'wn' (default 'wl')
            Which frequency unit is used.
        """
        fname = Path(fname)
        self.para.save_txt(fname.with_suffix(fname.suffix + '.para.txt'),
                           freq_unit)
        self.perp.save_txt(fname.with_suffix(fname.suffix + '.perp.txt'),
                           freq_unit)
        self.iso.save_txt(fname.with_suffix(fname.suffix + '.iso.txt'),
                          freq_unit)

    def concat_datasets(self, other_ds: 'PolTRSpec'):
        new_ds = self.copy()
        for i in ['para', 'perp', 'iso']:
            o = getattr(other_ds, i)
            setattr(new_ds, i, getattr(self, i).concat_datasets(o))
        new_ds.wavelengths = new_ds.para.wavelengths
        return new_ds


def delegator(pol_tr: PolTRSpec,
              method: typing.Callable) -> typing.Callable[..., Optional[PolTRSpec]]:
    """
    Helper function to delegate method calls from PolTRSpec to the methods
    of TimeResSpec.

    Parameters
    ----------
    pol_tr : PolTRSpec
    method : method of TimeResSpec
        The method to wrap. Uses function annotations to check if the
        method returns a new TimeResSpec.
    """
    name = method.__name__
    hints = typing.get_type_hints(method)
    if "return" in hints:
        do_return = hints["return"] == TimeResSpec
    else:
        do_return = False

    if do_return:

        @functools.wraps(method)
        def func(*args, **kwargs) -> PolTRSpec:
            para = method(pol_tr.para, *args, **kwargs)
            perp = method(pol_tr.perp, *args, **kwargs)
            iso = method(pol_tr.iso, *args, **kwargs)
            return PolTRSpec(para, perp, iso=iso)
    else:

        @functools.wraps(method)
        def func(*args, **kwargs) -> None:
            para = method(pol_tr.para, *args, **kwargs)
            perp = method(pol_tr.perp, *args, **kwargs)
            iso = method(pol_tr.iso, *args, **kwargs)

    func.__doc__ = method.__doc__
    func.__name__ = name
    return func
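
# How the delegation plays out (sketch): `TimeResSpec.cut_freq` is annotated
# with `-> "TimeResSpec"`, so its wrapper applies the method to para, perp
# and iso and re-wraps the three results, while a method without that
# annotation is applied purely for its side effects.
#
#   pol_cut = pol.cut_freq(1500, 1600)   # returns a new PolTRSpec
#   pol.subtract_background()            # modifies pol in place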


class PlotterMixin:
    @property
    def x(self):
        if self.freq_unit == "cm":
            return self._get_wn()
        else:
            return self._get_wl()

    def lbl_spec(self, ax=None, add_legend=True):
        if ax is None:
            ax = plt.gca()
        ax.set_xlabel(ph.freq_label)
        ax.set_ylabel(ph.sig_label)
        ax.autoscale(1, "x", 1)
        ax.axhline(0, color="k", lw=0.5, zorder=1.9)
        if add_legend:
            ax.legend(loc='best', ncol=2, title='Delay time')
        ax.minorticks_on()

    def upsample_spec(self, y, kind='cubic', factor=4):
        x = self.x
        assert y.shape[0] == x.size
        inter = interp1d(x, y, kind=kind, assume_sorted=False)
        fac = factor + 1
        diff = np.diff(x) / fac
        new_points = x[:-1, None] + np.arange(1, fac)[None, :] * diff[:, None]
        xn = np.sort(np.concatenate((x, new_points.ravel())))
        return xn, inter(xn)

    def univariate_spline(self, y):
        # The original left `w` undefined when no errors were available and
        # never returned the spline.
        w = 1 / self.dataset.err if self.dataset.err is not None else None
        return UnivariateSpline(x=self.x, y=y, w=w)


class TimeResSpecPlotter(PlotterMixin):
    _ds_name = "self.pol_ds.para"

    def __init__(self, dataset: TimeResSpec, disp_freq_unit="nm"):
        """
        Class which can plot a `TimeResSpec` using matplotlib.

        Parameters
        ----------
        dataset : TimeResSpec
            The TimeResSpec to work with.
        disp_freq_unit : {'nm', 'cm'} (optional)
            The default unit of the plots. To change the unit afterwards,
            set the attribute directly.
        """
        self.dataset = dataset
        self.freq_unit = disp_freq_unit

    def _get_wl(self):
        return self.dataset.wavelengths

    def _get_wn(self):
        return self.dataset.wavenumbers

    def map(self, symlog=True, equal_limits=True, plot_con=True,
            con_step=None, con_filter=None, ax=None, **kwargs):
        """
        Plot a colormap of the dataset with optional contour lines.

        Parameters
        ----------
        symlog : bool
            Determines if the yscale is symmetric logarithmic.
        equal_limits : bool
            If true, it makes the colors symmetric around zero. Note this
            also sets the middle of the colormap to zero. Default is
            `True`.
        plot_con : bool
            Plot additional contour lines if `True` (default).
        con_step : float, array or None
            Controls the contour-levels. If `con_step` is a float, it is
            used as the step size between two levels. If it is an array,
            its elements are the levels. If `None`, it defaults to 20
            levels.
        con_filter : None, int, tuple of int or `TimeResSpec`
            Since contours are strongly affected by noise, it can be
            preferred to filter the dataset before calculating the
            contours. If `con_filter` is a dataset, the data of that set
            will be used for the contours. If it is a tuple of int, the
            data will be filtered with a uniform filter before calculating
            the contours. If `None`, no data preprocessing will be applied.
        ax : plt.Axis or None
            Takes a matplotlib axis. If none, it uses `plt.gca()` to get
            the current axes. The lines are plotted in this axis.
        """
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        ds = self.dataset
        x = ds.wavelengths if is_nm else ds.wavenumbers
        cmap = kwargs.pop("colormap", "bwr")
        if equal_limits:
            m = np.max(np.abs(ds.data))
            vmin, vmax = -m, m
        else:
            # The original swapped min and max here.
            vmin, vmax = ds.data.min(), ds.data.max()
        mesh = ax.pcolormesh(x, ds.t, ds.data, shading='auto', vmin=vmin,
                             vmax=vmax, cmap=cmap, **kwargs)
        if symlog:
            ax.set_yscale("symlog", linthresh=1)
            ph.symticks(ax, axis="y")
            ax.set_ylim(-0.5)
        plt.colorbar(mesh, ax=ax)

        con = None
        if plot_con:
            if con_step is None:
                levels = 20
            elif isinstance(con_step, np.ndarray):
                levels = con_step
            else:
                # TODO This assumes data has positive and negative elements.
                pos = np.arange(0, ds.data.max(), con_step)
                neg = np.arange(0, -ds.data.min(), con_step)
                levels = np.hstack((-neg[::-1][:-1], pos))

            if isinstance(con_filter, TimeResSpec):
                data = con_filter.data
            elif con_filter is not None:  # must be int or tuple of int
                if isinstance(con_filter, tuple):
                    data = filter.uniform_filter(ds, con_filter).data
                else:
                    data = filter.svd_filter(ds, con_filter).data
            else:
                data = ds.data
            con = ax.contour(
                x,
                ds.t,
                data,
                levels=levels,
                linestyles="solid",
                colors="k",
                linewidths=0.5,
            )
        ph.lbl_map(ax, symlog)
        if not is_nm:
            ax.set_xlim(*ax.get_xlim()[::-1])
        return mesh, con

    def spec(self, *args, norm=False, ax=None, n_average=0, upsample=1,
             use_weights=False, offset=0., add_legend=False, **kwargs):
        """
        Plot spectra at given times.

        Parameters
        ----------
        *args : list or ndarray
            List of the times where the spectra are plotted.
        norm : bool or float
            If true, each spectrum will be normalized. If given a float,
            each spectrum will be normalized to the given position.
        ax : plt.Axis or None
            Axis where the spectra are plotted. If none, the current axis
            will be used.
        n_average : int
            For noisy data it may be preferred to average multiple spectra
            together. This function plots the average of `n_average`
            spectra around the specific time-points.
        upsample : int
            If upsample is >1, it will plot an upsampled version of the
            spectrum using cubic spline interpolation.
        use_weights : bool
            If given a tuple, the function will plot the average of the
            given range. `use_weights` determines if error weights are
            used in calculating the average.
        offset : float or 'auto'
            If non-zero, each spectrum will be shifted by `offset` relative
            to the last one. 'auto' is not yet implemented.
        add_legend : bool
            Whether to add a legend.

        Returns
        -------
        list of `Lines2D`
            List containing the Line2D objects belonging to the spectra.
        """
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        cur_offset = 0.
        ds = self.dataset
        x = ds.wavelengths if is_nm else ds.wavenumbers
        li = []
        for i in args:
            if isinstance(i, tuple):
                if ds.err is not None and use_weights:
                    weights = 1 / ds.err[ds.t_idx(i[0]):ds.t_idx(i[1]), :]**2
                else:
                    weights = None
                dat = np.average(ds.data[ds.t_idx(i[0]):ds.t_idx(i[1]), :],
                                 weights=weights,
                                 axis=0)
                label = '%.1f ps to %.1f ps' % (i[0], i[1])
            else:
                idx = dv.fi(ds.t, i)
                if n_average > 0:
                    dat = filter.uniform_filter(
                        ds, (2*n_average + 1, 1)).data[idx, :]
                elif n_average == 0:
                    dat = ds.data[idx, :]
                else:
                    raise ValueError("n_average must be an Integer >= 0.")
                label = ph.time_formatter(ds.t[idx], ph.time_unit)

            if upsample > 1:
                x, dat = self.upsample_spec(dat, factor=upsample)
            if isinstance(norm, bool) and norm:
                dat = dat / abs(dat).max()
            elif isinstance(norm, (float, int)) and not isinstance(norm, bool):
                if norm in (0, 1):
                    warnings.warn(
                        "0 and 1 are not interpreted as a bool here. "
                        "Use True and False")
                dat = dat / dat[dv.fi(x, norm)]
            markevery = None if upsample == 1 else upsample + 1
            li += ax.plot(x, dat + cur_offset, markevery=markevery,
                          label=label, **kwargs)
            cur_offset += offset

        self.lbl_spec(ax, add_legend)
        if not is_nm:
            ax.set_xlim(x.max(), x.min())
        return li

    def trans_integrals(self, *args, symlog: bool = True, norm=False,
                        ax=None, **kwargs) -> typing.List[plt.Line2D]:
        """
        Plot the transients of integrated regions. The integration will
        use np.trapz in wavenumber-space.

        Parameters
        ----------
        args : tuples of floats
            Tuple of wavenumbers determining the region to be integrated.
        symlog : bool
            If to use a symlog scale for the delay-time.
        norm : bool or float
            If `True`, normalize the transients. If it is a float, the
            transients are normalized to the value at the delay time
            `norm`.
        ax : plt.Axes or None
            Takes a matplotlib axes. If none, it uses `plt.gca()` to get
            the current axes. The lines are plotted in this axis.
        kwargs : dict
            Further arguments passed to plt.plot.

        Returns
        -------
        list of Line2D
            List containing the plotted lines.
        """
        if ax is None:
            ax = plt.gca()
        ph.ir_mode()
        ds = self.dataset
        lines = []
        for (a, b) in args:
            a, b = sorted([a, b])
            idx = (a < ds.wavenumbers) & (ds.wavenumbers < b)
            dat = np.trapz(-ds.data[:, idx], ds.wavenumbers[idx], axis=1)
            if norm is True:
                dat = np.sign(dat[np.argmax(abs(dat))]) * dat / abs(dat).max()
            elif norm is False:
                pass
            else:
                dat = dat / dat[ds.t_idx(norm)]
            lines.extend(
                ax.plot(ds.t, dat, label=f"{a: .0f} cm-1 to {b: .0f}",
                        **kwargs))

        if symlog:
            ax.set_xscale("symlog", linthresh=1.0)
        ph.lbl_trans(ax=ax, use_symlog=symlog)
        ax.legend(loc="best", ncol=2)
        ax.set_xlim(right=ds.t.max())
        ax.yaxis.set_tick_params(which="minor", left=True)
        return lines

    def trans(self, *args, symlog=True, norm=False, ax=None,
              freq_unit="auto", linscale=1, add_legend=True, **kwargs):
        """
        Plot the nearest transients for given frequencies.

        Parameters
        ----------
        *args : list or ndarray
            Spectral positions, should be given in the same unit as
            `self.freq_unit`.
        symlog : bool
            Determines if the x-scale is symlog.
        norm : bool or float
            If `False`, no normalization is used. If `True`, each
            transient is divided by the maximum absolute value. If `norm`
            is a float, all transients are normalized by their signal at
            the time `norm`.
        ax : plt.Axes or None
            Takes a matplotlib axes. If none, it uses `plt.gca()` to get
            the current axes. The lines are plotted in this axis.
        freq_unit : 'auto', 'cm' or 'nm'
            How to interpret the given frequencies. If 'auto' it defaults
            to the plotters freq_unit.
        linscale : float
            If symlog is True, determines the ratio of linear to
            log-space.
        add_legend : bool
            If to add the legend automatically.

        All other kwargs are forwarded to the plot function.

        Returns
        -------
        list of Line2D
            List containing the plotted lines.
        """
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        if ax is None:
            ax = plt.gca()

        tmp = self.freq_unit if freq_unit == "auto" else freq_unit
        is_nm = tmp == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        ds = self.dataset
        x = ds.wavelengths if is_nm else ds.wavenumbers
        t, d = ds.t, ds.data
        l, plotted_vals = [], []
        for i in args:
            idx = dv.fi(x, i)
            dat = d[:, idx]
            if norm is True:
                dat = np.sign(dat[np.argmax(abs(dat))]) * dat / abs(dat).max()
            elif norm is False:
                pass
            else:
                dat = dat / dat[dv.fi(t, norm)]
            plotted_vals.append(dat)
            l.extend(
                ax.plot(t, dat, label="%.0f %s" % (x[idx], ph.freq_unit),
                        **kwargs))

        if symlog:
            ax.set_xscale("symlog", linthresh=1.0, linscale=linscale)
        ph.lbl_trans(ax=ax, use_symlog=symlog)
        if add_legend:
            ax.legend(loc="best", ncol=max(1, len(l) // 3))
        ax.set_xlim(right=t.max())
        ax.yaxis.set_tick_params(which="minor", left=True)
        return l

    def trans_fit(self, *args, symlog=True, freq_unit='auto',
                  add_legend=True, ax=None, **kwargs):
        """
        Plot the fitted transients for given frequencies. Requires a
        previous call to `fit_exp`.

        Parameters
        ----------
        *args : list or ndarray
            Spectral positions, should be given in the same unit as
            `self.freq_unit`.
        symlog : bool
            Determines if the x-scale is symlog.
        ax : plt.Axes or None
            Takes a matplotlib axes. If none, it uses `plt.gca()` to get
            the current axes. The lines are plotted in this axis.
        freq_unit : 'auto', 'cm' or 'nm'
            How to interpret the given frequencies. If 'auto' it defaults
            to the plotters freq_unit.
        add_legend : bool
            If to add the legend automatically.
        """
        ds = self.dataset
        if getattr(ds, "fit_exp_result_", None) is None:
            raise ValueError("No fit available")
        if ax is None:
            ax = plt.gca()
        tmp = self.freq_unit if freq_unit == "auto" else freq_unit
        is_nm = tmp == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        x = ds.wavelengths if is_nm else ds.wavenumbers
        t, mod = ds.fit_exp_result_.fitter.t, ds.fit_exp_result_.fitter.model
        lines = []
        for i in args:
            idx = dv.fi(x, i)
            lines.append(ax.plot(t, mod[:, idx], **kwargs))
        if symlog:
            ax.set_xscale("symlog", linthresh=1.0)
        ph.lbl_trans(ax=ax, use_symlog=symlog)
        if add_legend:
            ax.legend(loc="best", ncol=max(1, len(lines) // 3))
        return lines

    def overview(self):
        """
        Plots an overview figure.
        """
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        ds = self.dataset
        x = ds.wavelengths if is_nm else ds.wavenumbers
        fig, axs = plt.subplots(
            3, 1, figsize=(5, 12), gridspec_kw=dict(height_ratios=(2, 1, 1)))
        self.map(ax=axs[0])
        times = np.hstack((0, np.geomspace(0.1, ds.t.max(), 6)))
        sp = self.spec(times, ax=axs[1])
        freqs = np.unique(np.linspace(x.min(), x.max(), 6))
        tr = self.trans(freqs, ax=axs[2])
        OverviewPlot = namedtuple("OverviewPlot", "fig axs trans spec")
        return OverviewPlot(fig, axs, tr, sp)

    def svd(self, n=5):
        """
        Plot the SVD-components of the dataset.

        Parameters
        ----------
        n : int or list of int
            Determines the plotted SVD-components. If `n` is an int, it
            plots the first n components. If `n` is a list of ints, then
            every number is a SVD-component to be plotted.
        """
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        ds = self.dataset
        x = ds.wavelengths if is_nm else ds.wavenumbers
        fig, axs = plt.subplots(3, 1, figsize=(4, 5))
        u, s, v = np.linalg.svd(ds.data)
        axs[0].stem(s)
        axs[0].set_xlim(0, 11)
        try:
            len(n)
            comps = n
        except TypeError:
            comps = range(n)
        for i in comps:
            axs[1].plot(ds.t, u.T[i], label="%d" % i)
            axs[2].plot(x, v[i])
        ph.lbl_trans(axs[1], use_symlog=True)
        self.lbl_spec(axs[2])

    def das(self, first_comp=0, ax=None, add_legend=True, **kwargs):
        """
        Plot a DAS, if available.

        Parameters
        ----------
        first_comp : int
            Index of the first shown component, useful if fast components
            model the coherent artefact and should not be shown.
        ax : plt.Axes or None
            Axes to plot.
        kwargs : dict
            Keyword args given to the plot function
        add_legend : bool
            If true, add legend automatically.

        Returns
        -------
        Tuple of (List of Lines2D)
        """
        ds = self.dataset
        if not hasattr(ds, "fit_exp_result_"):
            raise ValueError("The TimeResSpec must have successfully fit "
                             "the data first")
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        f = ds.fit_exp_result_.fitter
        num_exp = f.num_exponentials
        leg_text = [
            ph.nsf(i) + " " + ph.time_unit for i in f.last_para[-num_exp:]
        ]
        if max(f.last_para) > 5 * f.t.max():
            leg_text[-1] = "const."
        l1 = ax.plot(self.x, f.c[:, first_comp:num_exp], **kwargs)
        for i, l in enumerate(l1):
            l.set_label(leg_text[i + first_comp])
        if add_legend:
            ax.legend(title="Decay\nConstants")
        ph.lbl_spec(ax)
        return l1

    def edas(self, ax=None, legend=True, **kwargs):
        """
        Plot the EDAS, if an exponential fit is available.

        Parameters
        ----------
        ax : plt.Axes or None
            Axes to plot.
        kwargs : dict
            Keyword args given to the plot function

        Returns
        -------
        Tuple of (List of Lines2D)
        """
        ds = self.dataset
        if not hasattr(ds, "fit_exp_result_"):
            raise ValueError("The TimeResSpec must have successfully fit "
                             "the data first")
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        f = ds.fit_exp_result_.fitter
        num_exp = f.num_exponentials
        taus = f.last_para[-num_exp:]
        das = f.c[:, :num_exp]
        if np.any(np.diff(taus) < 0):
            raise ValueError("EDAS expects sorted time-constants")
        leg_text = [
            ph.nsf(i) + " " + ph.time_unit for i in f.last_para[-num_exp:]
        ]
        if max(f.last_para) > 5 * f.t.max():
            leg_text[-1] = "const."
        edas = np.cumsum(das[:, ::-1], axis=1)
        l1 = ax.plot(self.x, edas[:, ::-1], **kwargs)
        for i, l in enumerate(l1):
            l.set_label(leg_text[i])
        if legend:
            ax.legend(title="Species")
        ph.lbl_spec(ax)
        return l1

    def interactive(self):
        """
        Generates a jupyter widgets UI for exploring the spectra.
        """
        import ipywidgets as wid
        from IPython.display import display
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        ds = self.dataset
        x = ds.wavelengths if is_nm else ds.wavenumbers
        # fig, ax = plt.subplots()

        wl_slider = wid.FloatSlider(
            None,
            min=x.min(),
            max=x.max(),
            step=1,
            description="Freq (%s)" % self.freq_unit,
        )

        def func(x):
            # ax.cla()
            self.trans([x])
            # plt.show()

        ui = wid.interactive(func, x=wl_slider, continuous_update=False)
        display(ui)
        return ui

    def plot_disp_result(self, result: EstDispResult):
        """Visualize the result of a dispersion correction; creates a
        figure."""
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex="col", figsize=(3, 4))
        ds = self.dataset
        tmp_unit = self.freq_unit, result.correct_ds.plot.freq_unit
        self.freq_unit = "cm"
        result.correct_ds.plot.freq_unit = "cm"
        self.map(symlog=False, plot_con=False, ax=ax1)
        ylim = max(ds.t.min(), -2), min(2, ds.t.max())
        ax1.set_ylim(*ylim)
        ax1.plot(ds.wavenumbers, result.tn)
        ax1.plot(ds.wavenumbers, result.polynomial(ds.wavenumbers))
        # Draw the corrected dataset into the second axis.
        result.correct_ds.plot.map(symlog=True, con_filter=3, con_step=None,
                                   ax=ax2)
        self.freq_unit = tmp_unit[0]
        result.correct_ds.plot.freq_unit = tmp_unit[1]


class PolTRSpecPlotter(PlotterMixin):
    perp_ls = dict(marker='s', markersize=3, linewidth=1, markerfacecolor='w')
    para_ls = dict(marker='o', markersize=3, linewidth=1)

    def __init__(self, pol_dataset: PolTRSpec, disp_freq_unit=None):
        """
        Plotting commands for a PolTRSpec

        Parameters
        ----------
        pol_dataset : PolTRSpec
            The Data
        disp_freq_unit : {'nm', 'cm'} (optional)
            The default unit of the plots. To change the unit afterwards,
            set the attribute directly.
        """
        self.pol_ds = pol_dataset
        if disp_freq_unit is not None:
            self.freq_unit = disp_freq_unit
        self.perp_ls = PolTRSpecPlotter.perp_ls.copy()
        self.para_ls = PolTRSpecPlotter.para_ls.copy()

    def _get_wl(self):
        return self.pol_ds.para.wavelengths

    def _get_wn(self):
        return self.pol_ds.para.wavenumbers

    def spec(self, *times, norm=False, ax=None, n_average=0,
             add_legend=True, **kwargs):
        """
        Plot spectra at given times.

        Parameters
        ----------
        *times : list or ndarray
            List of the times where the spectra are plotted.
        norm : bool
            If true, each spectrum will be normalized.
        ax : plt.Axis or None
            Axis where the spectra are plotted. If none, the current axis
            will be used.
        n_average : int
            For noisy data it may be preferred to average multiple spectra
            together. This function plots the average of `n_average`
            spectra around the specific time-points.
        upsample : int
            If >1, upsample the spectrum using cubic interpolation
            (forwarded via kwargs).
        add_legend : bool
            Add legend automatically.

        Returns
        -------
        tuple of (List of `Lines2D`)
            List containing the Line2D objects belonging to the spectra.
        """
        if ax is None:
            ax = plt.gca()
        pa, pe = self.pol_ds.para, self.pol_ds.perp
        l1 = pa.plot.spec(*times, norm=norm, ax=ax, n_average=n_average,
                          **self.para_ls, **kwargs)
        l2 = pe.plot.spec(*times, norm=norm, ax=ax, n_average=n_average,
                          **self.perp_ls, **kwargs)
        dv.equal_color(l1, l2)
        colored_lines = [
            Line2D([0], [0], color=l.get_color(), label=l.get_label())
            for l in l1
        ]
        pol_lines = [
            Line2D([0], [0], color='0.3', label=r'$\parallel$-pol.',
                   **self.para_ls),
            Line2D([0], [0], color='0.3', label=r'$\perp$-pol.',
                   **self.perp_ls)
        ]
        all_lines = colored_lines + pol_lines
        self.lbl_spec(ax, add_legend=False)
        if add_legend:
            ax.legend(all_lines, [l.get_label() for l in all_lines])
        return l1, l2

    def trans(self, *args, symlog=True, norm=False, ax=None,
              add_legend=True, **kwargs):
        """
        Plot the nearest transients for given frequencies.

        Parameters
        ----------
        *args : list or ndarray
            Spectral positions, should be given in the same unit as
            `self.freq_unit`.
        symlog : bool
            Determines if the x-scale is symlog.
        norm : bool or float
            If `False`, no normalization is used. If `True`, each
            transient is divided by the maximum absolute value. If `norm`
            is a float, all transients are normalized by their signal at
            the time `norm`.
        ax : plt.Axes or None
            Takes a matplotlib axes. If none, it uses `plt.gca()` to get
            the current axes. The lines are plotted in this axis.
        add_legend : bool
            If true, it will add the legend automatically.

        All other kwargs are forwarded to the plot function.

        Returns
        -------
        list of Line2D
            Tuple of lists containing the plotted lines.
        """
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        if ax is None:
            ax = plt.gca()
        pa, pe = self.pol_ds.para, self.pol_ds.perp
        # Avoid duplicated keywords
        duplicated_para = {}
        duplicated_perp = {}
        for k in kwargs:
            if k in self.para_ls:
                self.para_ls.pop(k)
                duplicated_para[k] = kwargs[k]
            if k in self.perp_ls:
                self.perp_ls.pop(k)
                duplicated_perp[k] = kwargs[k]
        l1 = pa.plot.trans(*args, symlog=symlog, norm=norm, ax=ax,
                           add_legend=False, **kwargs, **self.para_ls)
        l2 = pe.plot.trans(*args, symlog=symlog, norm=norm, ax=ax,
                           add_legend=False, **kwargs, **self.perp_ls)
        self.para_ls.update(**duplicated_para)
        # The original restored the perpendicular keywords into `para_ls`.
        self.perp_ls.update(**duplicated_perp)
        dv.equal_color(l1, l2)
        colored_lines = [
            Line2D([0], [0], color=l.get_color(), label=l.get_label())
            for l in l1
        ]
        pol_lines = [
            Line2D([0], [0], color='0.3', label=r'$\parallel$-pol.',
                   **self.para_ls),
            Line2D([0], [0], color='0.3', label=r'$\perp$-pol.',
                   **self.perp_ls)
        ]
        all_lines = colored_lines + pol_lines
        if add_legend:
            ax.legend(all_lines, [l.get_label() for l in all_lines])
        return l1, l2

    def das(self, ax=None, plot_first_das=True, **kwargs):
        """
        Plot a DAS, if available.

        Parameters
        ----------
        ax : plt.Axes or None
            Axes to plot.
        plot_first_das : bool
            If false, the first DAS is omitted. This is useful when the
            first component is very fast and only models coherent
            contributions.
        kwargs : dict
            Keyword args given to the plot function

        Returns
        -------
        Tuple of (List of Lines2D)
        """
        ds = self.pol_ds
        if not hasattr(self.pol_ds, "fit_exp_result_"):
            raise ValueError("The PolTRSpec must have successfully fit the "
                             "data")
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        f = ds.fit_exp_result_.fitter
        num_exp = f.num_exponentials
        leg_text = [
            ph.nsf(i) + " " + ph.time_unit for i in f.last_para[-num_exp:]
        ]
        if max(f.last_para) > 5 * f.t.max():
            leg_text[-1] = "const."
        n = ds.para.wavelengths.size
        x = ds.para.wavelengths if is_nm else ds.para.wavenumbers
        start = 0 if plot_first_das else 1
        palines = []
        pelines = []
        for c, i in enumerate(range(start, num_exp)):
            l1 = ax.plot(x, f.c[:n, i], **kwargs, **self.para_ls,
                         label=leg_text[i], color='C%d' % c)
            l2 = ax.plot(x, f.c[n:, i], **kwargs, **self.perp_ls)
            dv.equal_color(l1, l2)
            palines += l1
            pelines += l2
        ph.lbl_spec(ax=ax)
        ncol = max(num_exp // 3, 1)
        ax.legend(title="Decay\nConstants", ncol=ncol)
        return palines, pelines

    def edas(self, ax=None, *, add_legend=True, **kwargs):
        """
        Plots the EDAS, if available.

        Parameters
        ----------
        ax : plt.Axes or None
            Axes to plot.
        kwargs : dict
            Keyword args given to the plot function

        Returns
        -------
        Tuple of (List of Lines2D)
        """
        ds = self.pol_ds
        if not hasattr(self.pol_ds, "fit_exp_result_"):
            raise ValueError("The PolTRSpec must have successfully fit the "
                             "data")
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        f = ds.fit_exp_result_.fitter
        num_exp = f.num_exponentials
        taus = f.last_para[-num_exp:]
        leg_text = [ph.nsf(i) + " " + ph.time_unit for i in taus]
        if max(f.last_para) > 5 * f.t.max():
            leg_text[-1] = "const."
        n = ds.para.wavelengths.size
        x = ds.para.wavelengths if is_nm else ds.para.wavenumbers
        start = 0
        if any(np.diff(taus) < 0):
            raise ValueError("EDAS assumes sorted taus")
        das = f.c[:, :num_exp]
        edas_pa = np.cumsum(das[:n, ::-1], axis=1)[:, ::-1]
        edas_pe = np.cumsum(das[n:, ::-1], axis=1)[:, ::-1]
        palines = []
        pelines = []
        for c, i in enumerate(range(start, num_exp)):
            l1 = ax.plot(x, edas_pa.T[i], **kwargs, **self.para_ls,
                         label=leg_text[i], color='C%d' % c)
            l2 = ax.plot(x, edas_pe.T[i], **kwargs, **self.perp_ls)
            dv.equal_color(l1, l2)
            palines += l1
            pelines += l2
        ph.lbl_spec(ax=ax)
        ncol = max(num_exp // 3, 1)
        if add_legend:
            ax.legend(title="EDAS\nConstants", ncol=ncol)
        return palines, pelines

    def sas(self, model: Model, QYs: Dict[str, float] = {},
            y0: Optional[np.ndarray] = None, ax=None, *, add_legend=True,
            **kwargs):
        """
        Plots the SAS from a kinetic model, if a fit is available.

        Parameters
        ----------
        model : Model
            Kinetic model.
        QYs : dict
            Dict with the yields.
        y0 : ndarray or None
            Starting concentrations.
        ax : plt.Axes or None
            Axes to plot.
        kwargs : dict
            Keyword args given to the plot function

        Returns
        -------
        Tuple of (List of Lines2D)
        """
        ds = self.pol_ds
        if not hasattr(self.pol_ds, "fit_exp_result_"):
            raise ValueError("The PolTRSpec must have successfully fit the "
                             "data")
        if ax is None:
            ax = plt.gca()
        is_nm = self.freq_unit == "nm"
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        f = ds.fit_exp_result_.fitter
        num_exp = f.num_exponentials
        taus = f.last_para[-num_exp:]
        if any(np.diff(taus) < 0):
            raise ValueError("SAS assumes sorted taus")
        n = ds.para.wavelengths.size
        sas, _ = ds.fit_exp_result_.make_sas(model, QYs, y0)
        sas_pa = sas[:, :n]
        sas_pe = sas[:, n:]
        leg_text = list(map(str, model.get_compartments()))
        x = ds.para.wavelengths if is_nm else ds.para.wavenumbers
        palines = []
        pelines = []
        for c, i in enumerate(range(num_exp)):
            l1 = ax.plot(x, sas_pa[i], **kwargs, **self.para_ls,
                         label=leg_text[i], color='C%d' % c)
            l2 = ax.plot(x, sas_pe[i], **kwargs, **self.perp_ls)
            dv.equal_color(l1, l2)
            palines += l1
            pelines += l2
        ph.lbl_spec(ax=ax)
        ncol = max(num_exp // 3, 1)
        if add_legend:
            ax.legend(title="SAS\nConstants", ncol=ncol)
        return palines, pelines

    def trans_anisotropy(self, *wls: float, symlog: bool = True,
                         ax: Optional[plt.Axes] = None,
                         freq_unit: typing.Literal['auto', 'nm', 'cm'] = 'auto',
                         mode: typing.Literal['aniso', 'dichro'] = 'aniso'):
        """
        Plots the anisotropy over time for given frequencies.

        Parameters
        ----------
        wls : floats
            Which frequencies are plotted.
        symlog : bool
            Use symlog scale
        ax : plt.Axes or None
            Matplotlib Axes, if `None`, defaults to `plt.gca()`.
        freq_unit : {'auto', 'nm', 'cm'}
            Unit of the frequencies.
        mode : {'aniso', 'dichro'}
            Plot anisotropy or dichroism.

        Returns
        -------
        : list of Line2D
            List with the line objects.
        """
        if ax is None:
            ax = plt.gca()
        ds = self.pol_ds
        tmp = self.freq_unit if freq_unit == "auto" else freq_unit
        is_nm = tmp == "nm"
        x = ds.wavelengths if is_nm else ds.wavenumbers
        if is_nm:
            ph.vis_mode()
        else:
            ph.ir_mode()
        l = []
        for i in wls:
            idx = dv.fi(x, i)
            pa, pe = ds.para.data[:, idx], ds.perp.data[:, idx]
            aniso = (pa-pe) / (2*pe + pa)
            dichro = pa / pe
            if mode == "aniso":
                data = aniso
            elif mode == "dichro":
                data = dichro
            l += ax.plot(ds.para.t, data,
                         label="%.0f %s" % (x[idx], ph.freq_unit))
        ph.lbl_trans(use_symlog=symlog)
        if symlog:
            ax.set_xscale("symlog")
        ax.set_xlim(-1)
        return l


class DataSetInteractiveViewer:
    def __init__(self, dataset, fig_kws=None):
        """
        Class showing an interactive matplotlib window for exploring a
        dataset.
        """
        if fig_kws is None:
            fig_kws = {}
        self.dataset = ds = dataset
        self.figure, axs = plt.subplots(3, 1, **fig_kws)
        self.ax_img, self.ax_trans, self.ax_spec = axs
        self.ax_img.pcolormesh(dataset.wn, dataset.t, dataset.data)
        self.ax_img.set_yscale("symlog", linscale=1)
        ph.lbl_spec(self.ax_spec)
        ph.lbl_trans(self.ax_trans)
        self.ax_trans.set_xscale('symlog', linthresh=1)
        self.trans_line = self.ax_trans.plot([])[0]
        self.spec_line = self.ax_spec.plot([])[0]
        self.ax_trans.set_xlim(-1, ds.t.max())
        self.ax_trans.set_ylim(ds.data.min(), ds.data.max())
        self.ax_spec.set_ylim(ds.data.min(), ds.data.max())
        self.ax_spec.set_xlim(ds.wn.max(), ds.wn.min())
        self._events = []
        self.init_events()

    def init_events(self):
        """Connect mpl events"""
        connect = self.figure.canvas.mpl_connect
        self._events.append(connect("motion_notify_event", self.update_lines))

    def update_lines(self, event):
        """If the mouse cursor is over the 2D image, update the dynamic
        transient and spectrum"""
        # Debug prints and a stray `set_title(self.ax_spec)` were removed
        # from the original.
        if event.inaxes is self.ax_img:
            ds = self.dataset
            wn_idx = ds.wn_idx(event.xdata)
            t_idx = ds.t_idx(event.ydata)
            spec = ds.data[t_idx, :]
            trans = ds.data[:, wn_idx]
            self.ax_trans.set_title('%.1f' % ds.wn[wn_idx])
            self.ax_trans.set_ylim(trans.min(), trans.max())
            self.ax_spec.set_title('%.1f' % ds.t[t_idx])
            self.ax_spec.set_ylim(spec.min(), spec.max())
            self.trans_line.set_data(ds.t, ds.data[:, wn_idx])
            self.spec_line.set_data(ds.wn, ds.data[t_idx, :])
        self.figure.canvas.draw()


@attr.s(auto_attribs=True)
class SasViewer:
    fit_result: FitExpResult
    model: Model
    fig: plt.Figure
    lines: List[Line2D]