Source code for impedance_agent.fitters.drt

# src/fitters/drt.py
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative
from scipy.signal import find_peaks, peak_widths
from scipy.optimize import least_squares
from ..core.models import DRTResult
import logging


class DRTFitter:
    """
    Distribution of Relaxation Times (DRT) analysis for electrochemical impedance data.

    This class implements the combined Tikhonov regularization and projected
    gradient method for calculating DRT from impedance data as described in:
    "PEM fuel cell distribution of relaxation times: a method for the
    calculation and behavior of an oxygen transport peak"

    The fundamental DRT equation relating impedance to the distribution
    function is:

    .. math::
        Z(\\omega) - R_\\infty = R_{pol}\\int_0^\\infty \\frac{\\gamma(\\tau)d\\tau}{1 + i\\omega\\tau}

    With normalization condition:

    .. math::
        \\int_0^\\infty \\gamma(\\tau)d\\tau = 1

    Parameters
    ----------
    zexp_re : np.ndarray
        Real part of experimental impedance
    zexp_im : np.ndarray
        Imaginary part of experimental impedance
    omg : np.ndarray
        Angular frequencies
    lam_t0 : float
        Initial Tikhonov regularization parameter
    lam_pg0 : float
        Initial projected gradient parameter
    lower_bounds : np.ndarray
        Lower bounds for regularization parameters
    upper_bounds : np.ndarray
        Upper bounds for regularization parameters
    mode : str, optional
        'real' or 'imag' for using real or imaginary part of impedance
        (default: 'real'). Any value other than 'real' selects the
        imaginary-part kernel.
    """

    def __init__(
        self,
        zexp_re,
        zexp_im,
        omg,
        lam_t0,
        lam_pg0,
        lower_bounds,
        upper_bounds,
        mode="real",
    ):
        """
        Store the data, build the relaxation-time grid and the Tikhonov
        system, and compute the initial (pure Tikhonov) solution.

        Side effects: enables 64-bit mode in JAX globally via
        ``jax.config.update("jax_enable_x64", True)``.
        """
        self.logger = logging.getLogger(__name__)
        jax.config.update("jax_enable_x64", True)

        # Scale factor: difference between the last and first real-part
        # samples; falls back to 1.0 so the normalization below never
        # divides by (nearly) zero.
        self.rpol = zexp_re[-1] - zexp_re[0]
        if abs(self.rpol) < 1e-10:
            self.rpol = 1.0
            self.logger.warning("Rpol near zero, using default value of 1.0")

        # Raw and normalized data
        self.zexp_re = zexp_re
        self.zexp_im = zexp_im
        self.zexp_re_norm = zexp_re / self.rpol
        self.zexp_im_norm = zexp_im / self.rpol

        self.omg = omg
        self.mode = mode
        self.lam_t0 = lam_t0
        self.lam_pg0 = lam_pg0
        self.lower_bounds = lower_bounds
        self.upper_bounds = upper_bounds

        # Iteration settings: the projected-gradient solver is capped at
        # niter * 1000 iterations; flagiter is set to 1 by fit() when the
        # final solve hits that cap.
        self.niter = 80
        self.flagiter = 0

        # Relaxation-time grid (one tau per measured frequency) and spacings
        self.tau = 1.0 / self.omg
        self.ln_tau = jnp.log(self.tau)
        self.d_ln_tau = self._create_dmesh(self.ln_tau)
        self.d_tau = self._create_dmesh(self.tau)
        self.id_matrix = jnp.identity(self.omg.size, dtype=jnp.float64)

        self.logger.info("Initialized DRT fitter with mode=%s", mode)
        self.logger.debug(
            "Initial parameters: λT=%.2e, λPG=%.2e", lam_t0, lam_pg0
        )

        # Build matrices (a_matrix, a_mat_t_a, a_tikh, b_rhs) and get the
        # initial guess.
        self._create_tikhonov_matrix()
        self.gfun_init = self.tikh_solver(
            self.lam_t0, self.a_mat_t_a, self.b_rhs, self.id_matrix
        )

    def compute_normalized_residuals(
        self, Z_fit: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute normalized residuals using impedance modulus for normalization.

        .. math::
            r_{real} = \\frac{Z_{fit,real} - Z_{exp,real}}{\\sqrt{Z_{exp,real}^2 + Z_{exp,imag}^2}}

            r_{imag} = \\frac{Z_{fit,imag} - Z_{exp,imag}}{\\sqrt{Z_{exp,real}^2 + Z_{exp,imag}^2}}

        Parameters
        ----------
        Z_fit : np.ndarray
            Complex array containing the fitted impedance values

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Tuple containing:
            - residuals_real: Normalized residuals of the real component
            - residuals_imag: Normalized residuals of the imaginary component
        """
        Z_mod = np.sqrt(self.zexp_re**2 + self.zexp_im**2)
        residuals_real = (Z_fit.real - self.zexp_re) / Z_mod
        residuals_imag = (Z_fit.imag - self.zexp_im) / Z_mod
        return residuals_real, residuals_imag

    def _create_dmesh(self, grid):
        """
        Creates a mesh spacing array for the given grid using central differences.

        .. math::
            dh_j = \\begin{cases}
            0.5(grid_{j+1} - grid_{j-1}) & \\text{for } j = 1,\\ldots,N-2 \\\\
            0.5(grid_1 - grid_0) & \\text{for } j = 0 \\\\
            0.5(grid_{N-1} - grid_{N-2}) & \\text{for } j = N-1
            \\end{cases}

        Parameters
        ----------
        grid : np.ndarray
            Array containing grid points

        Returns
        -------
        np.ndarray
            Array containing mesh spacing values (same length as ``grid``)
        """
        grid = jnp.asarray(grid, dtype=jnp.float64)
        if grid.size < 2:
            # No neighbours to difference against.
            return jnp.zeros(grid.size, dtype=jnp.float64)

        # Built in one shot instead of N functional `.at[j].set` updates,
        # each of which would copy the whole array.
        interior = 0.5 * (grid[2:] - grid[:-2])
        first = 0.5 * (grid[1] - grid[0])
        last = 0.5 * (grid[-1] - grid[-2])
        return jnp.concatenate([first[None], interior, last[None]])

    def _am_row(self, omega_value):
        """
        Builds a single row of the system matrix for a given frequency.

        .. math::
            A_{m,n} = \\begin{cases}
            \\frac{d\\ln\\tau}{1 + (\\omega\\tau)^2} & \\text{for real mode} \\\\
            \\frac{\\omega\\tau d\\ln\\tau}{1 + (\\omega\\tau)^2} & \\text{for imaginary mode}
            \\end{cases}

        Parameters
        ----------
        omega_value : float
            Angular frequency value

        Returns
        -------
        np.ndarray
            Row of the system matrix
        """
        prod = omega_value * self.tau
        if self.mode == "real":
            return self.d_ln_tau / (1.0 + prod**2)
        return prod * self.d_ln_tau / (1.0 + prod**2)

    def _create_tikhonov_matrix(self):
        """
        Builds the system matrix A, the Tikhonov matrix A^T A + λT I, and RHS.

        .. math::
            A^TA\\vec{\\gamma} + \\lambda_T I\\vec{\\gamma} = A^T\\vec{b}

        Where b is either Z' or Z'' (normalized by ``rpol``) depending on mode.
        Sets ``a_matrix``, ``a_matrix_t``, ``a_mat_t_a``, ``a_tikh`` (using
        ``lam_t0``) and ``b_rhs`` on the instance.
        """
        self.logger.debug("Building Tikhonov matrix")
        self.a_matrix = jax.vmap(self._am_row)(self.omg)
        self.a_matrix_t = self.a_matrix.transpose()
        self.a_mat_t_a = jnp.matmul(self.a_matrix_t, self.a_matrix)
        self.a_tikh = self.a_mat_t_a + self.lam_t0 * self.id_matrix

        if self.mode == "real":
            self.b_rhs = jnp.matmul(self.a_matrix_t, self.zexp_re_norm)
        else:
            self.b_rhs = jnp.matmul(self.a_matrix_t, self.zexp_im_norm)

    def tikh_solver(self, lam_t, a_mat_t_a, b_rhs, id_matrix):
        """
        Solve the Tikhonov-regularized equation using least squares.

        .. math::
            (A^TA + \\lambda_T I)\\vec{\\gamma} = A^T\\vec{b}

        Parameters
        ----------
        lam_t : float
            Tikhonov regularization parameter
        a_mat_t_a : np.ndarray
            A^T A matrix
        b_rhs : np.ndarray
            Right-hand side vector
        id_matrix : np.ndarray
            Identity matrix

        Returns
        -------
        np.ndarray
            Solution vector γ
        """
        lhs_matrix = a_mat_t_a + lam_t * id_matrix
        # lstsq also returns residuals, rank and singular values; only the
        # solution is needed here.
        sol, _residuals, _rank, _sv = jnp.linalg.lstsq(lhs_matrix, b_rhs, rcond=0)
        return sol

    def objective_function(self, g_vector, lhs_matrix, b_rhs):
        """
        Objective function for projected gradient optimization.

        .. math::
            f(\\vec{\\gamma}) = \\|(A^TA + \\lambda_T I)\\vec{\\gamma} - A^T\\vec{b}\\|^2

        Parameters
        ----------
        g_vector : np.ndarray
            Current γ vector
        lhs_matrix : np.ndarray
            Left-hand side matrix
        b_rhs : np.ndarray
            Right-hand side vector

        Returns
        -------
        float
            Objective function value
        """
        residuals = jnp.matmul(lhs_matrix, g_vector) - b_rhs
        return jnp.sum(residuals**2)

    def find_lambda(self):
        """
        Scans over lambda_T range to find optimal regularization parameters.

        The scan covers 25 values of lambda_T, each ten times the previous,
        starting at 1e-24.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
            - resid: Residual norms
            - solnorm: Solution norms
            - lam_t_arr: Tikhonov parameter values
            - lam_pg_arr: Projected gradient parameter values
        """
        self.logger.debug("Scanning lambda parameter range")
        kmax, lam_val = 25, 1e-25

        # NOTE(review): this value is built from self.a_tikh (i.e. lam_t0),
        # not from the lambda of the current scan step, so every entry of
        # lam_pg_arr is identical. Behaviour kept as-is; it is loop-invariant
        # and therefore computed once.
        lam_pg = 1.0 / jnp.linalg.norm(self.a_tikh)

        lam_values, resid_values, solnorm_values = [], [], []
        for _ in range(kmax):
            lam_val = lam_val * 10
            gfun = self.tikh_solver(
                lam_val, self.a_mat_t_a, self.b_rhs, self.id_matrix
            )
            lam_values.append(lam_val)
            resid_values.append(self.residual_norm(gfun))
            solnorm_values.append(jnp.sqrt(jnp.sum(gfun**2)))

        resid = jnp.stack(resid_values).astype(jnp.float64)
        solnorm = jnp.stack(solnorm_values).astype(jnp.float64)
        lam_t_arr = jnp.array(lam_values, dtype=jnp.float64)
        lam_pg_arr = jnp.full(kmax, lam_pg, dtype=jnp.float64)
        return resid, solnorm, lam_t_arr, lam_pg_arr

    def residual_norm(self, g_vector):
        """
        Calculate the norm of the residual using (A^T A)g - A^T b.

        Parameters
        ----------
        g_vector : np.ndarray
            Current γ vector

        Returns
        -------
        float
            Residual norm
        """
        lhs_prod = jnp.matmul(self.a_mat_t_a, g_vector)
        return jnp.sqrt(jnp.sum((lhs_prod - self.b_rhs) ** 2))

    def encode(self, p, lb, ub):
        """
        Converts external parameters to internal parameters in log10 scale.

        .. math::
            p_{int} = \\log_{10}\\left(\\frac{p - lb}{1 - p/ub}\\right)
        """
        return jnp.log10((p - lb) / (1 - p / ub))

    def decode(self, p, lb, ub):
        """
        Converts internal parameters to external parameters.

        This is the inverse of :meth:`encode`.

        .. math::
            p_{ext} = \\frac{lb + 10^p}{1 + 10^p/ub}
        """
        return (lb + 10**p) / (1 + 10**p / ub)

    def jacobian_lsq(
        self, pvec, lhs_matrix, a_mat_t_a, b_rhs, d_ln_tau, id_matrix, lb, ub
    ):
        """Compute the Jacobian of the Tikhonov residual function."""
        return jax.jacobian(self.tikh_residual)(
            jnp.array(pvec), lhs_matrix, a_mat_t_a, b_rhs, d_ln_tau, id_matrix, lb, ub
        )

    def pg_solver(self, lamvec, lhs_matrix, a_mat_t_a, b_rhs, d_ln_tau, id_matrix):
        """
        Projected gradient solver using Tikhonov solution as initial guess.

        Implements the iterations:

        .. math::
            \\vec{\\gamma}^{k+1} = P_+[\\vec{\\gamma}^k - \\lambda_{pg}((A^TA + \\lambda_TI)\\vec{\\gamma}^k - A^T\\vec{b})]

        where P_+ is the projection onto non-negative orthant.

        Parameters
        ----------
        lamvec : array-like of length 2
            (lambda_T, lambda_PG)
        lhs_matrix : np.ndarray
            Not used; the left-hand side is rebuilt from ``a_mat_t_a`` and
            the lambda_T in ``lamvec``.
        a_mat_t_a : np.ndarray
            A^T A matrix
        b_rhs : np.ndarray
            Right-hand side vector
        d_ln_tau : np.ndarray
            Spacing of the ln(tau) grid
        id_matrix : np.ndarray
            Identity matrix

        Returns
        -------
        tuple
            (solution vector, sum of solution * d_ln_tau, iteration count)
        """
        # NOTE(review): lambda_PG is unpacked but never handed to the solver
        # (no `stepsize` is passed to ProjectedGradient), so it currently has
        # no influence on the result — confirm whether that is intended.
        lam_t, _lam_pg = lamvec
        g_init = self.tikh_solver(lam_t, a_mat_t_a, b_rhs, id_matrix)
        lhs_matrix_new = a_mat_t_a + lam_t * id_matrix

        pg = ProjectedGradient(
            fun=jax.jit(self.objective_function),
            projection=projection_non_negative,
            tol=1e-8,
            maxiter=self.niter * 1000,
            implicit_diff=True,
            jit=True,
        )
        self.logger.debug("Running projected gradient optimization")
        solution = pg.run(init_params=g_init, lhs_matrix=lhs_matrix_new, b_rhs=b_rhs)
        r_poly = jnp.sum(solution.params * d_ln_tau)
        return solution.params, r_poly, solution.state.iter_num

    def tikh_residual(
        self, lamvec_log, lhs_matrix, a_mat_t_a, b_rhs, d_ln_tau, id_matrix, lb, ub
    ):
        """
        Compute Tikhonov residual for given regularization parameters.

        ``lamvec_log`` holds the encoded (see :meth:`encode`) regularization
        parameters; they are decoded, the projected-gradient problem is
        solved, and the residual vector of the solution is returned.
        """
        lamvec_norm = self.decode(lamvec_log, lb, ub)
        g_vector, _rpoly, _iterations = self.pg_solver(
            lamvec_norm, lhs_matrix, a_mat_t_a, b_rhs, d_ln_tau, id_matrix
        )
        # NOTE(review): the residual is evaluated with self.a_tikh (built
        # with lam_t0) and self.b_rhs, not with the `lhs_matrix` / `b_rhs`
        # arguments or the decoded lambda_T. Kept as in the original.
        return jnp.matmul(self.a_tikh, g_vector) - self.b_rhs

    def rpol_peaks(self, g_vector):
        """
        Find peaks in the DRT spectrum and calculate their parameters.

        Peaks are located with ``scipy.signal.find_peaks`` (prominence 0.01);
        each peak is integrated over ln(tau) between the base intersection
        points reported by ``scipy.signal.peak_widths`` (rel_height=1.0).

        Parameters
        ----------
        g_vector : np.ndarray
            DRT solution vector

        Returns
        -------
        np.ndarray
            Array of shape (2, n_peaks): row 0 holds peak frequencies
            ``1 / (2π τ_peak)``, row 1 the peak integrals; both rows are in
            reversed peak order.
        """
        self.logger.debug("Analyzing DRT peaks")
        g_np = np.array(g_vector, copy=True)
        peaks, _ = find_peaks(g_np, prominence=0.01)
        widths = peak_widths(g_np, peaks, rel_height=1.0)

        g_jax = jnp.asarray(g_vector)
        # widths[2] / widths[3] are the (fractional) left/right intersection
        # positions; they are truncated to integer indices for slicing.
        integral_values = [
            jnp.sum(g_jax[int(lo):int(hi)] * self.d_ln_tau[int(lo):int(hi)])
            for lo, hi in zip(widths[2], widths[3])
        ]
        integrals = jnp.asarray(integral_values, dtype=jnp.float64).reshape(
            peaks.size
        )
        frequencies = jnp.asarray(
            1.0 / (2.0 * jnp.pi * self.tau[peaks]), dtype=jnp.float64
        )
        return jnp.stack([jnp.flip(frequencies), jnp.flip(integrals)])

    def z_model_imre(self, g_vector):
        """
        Calculate the model impedance from the DRT solution.

        .. math::
            Z(\\omega) = R_{pol}\\int_0^\\infty \\frac{\\gamma(\\tau)}{1 + i\\omega\\tau}d\\tau

        Only the component selected by ``mode`` is evaluated (the real-part
        kernel for 'real', the imaginary-part kernel otherwise), so the
        result is a real-valued array.

        Parameters
        ----------
        g_vector : np.ndarray
            DRT solution vector

        Returns
        -------
        np.ndarray
            ``rpol`` times the quadrature over ln(tau), in reversed
            frequency order.
        """
        g_row = jnp.asarray(g_vector)[None, :]
        # prod[i, j] = omega_i * tau_j — one row per frequency, evaluated in
        # a single broadcast instead of a Python loop of `.at[i].set` copies.
        prod = jnp.asarray(self.omg)[:, None] * jnp.asarray(self.tau)[None, :]
        denom = 1.0 + prod**2
        if self.mode == "real":
            integrand = g_row / denom
        else:
            integrand = prod * g_row / denom
        z_mod = jnp.sum(self.d_ln_tau * integrand, axis=1)
        # NOTE(review): the result is flipped relative to self.omg; callers
        # compare it against unflipped experimental data — confirm ordering.
        return jnp.flip(self.rpol * z_mod)

    def fit(self) -> Optional["DRTResult"]:
        """
        Perform the complete DRT fitting process.

        Returns
        -------
        Optional[DRTResult]
            Complete DRT analysis results or None if fitting fails

        Notes
        -----
        1. Finds optimal regularization parameters
        2. Performs Tikhonov regularization with projected gradient
        3. Analyzes peaks in the DRT spectrum
        4. Calculates fitted impedance and residuals

        Any exception raised along the way is logged with its traceback and
        turned into a ``None`` return value.
        """
        try:
            self.logger.info("Starting DRT analysis")

            # Lambda scan. Its outputs are not consumed below; the call is
            # kept so the diagnostic run (and its log output) is unchanged.
            self.find_lambda()
            self.logger.debug("Completed lambda parameter search")

            # Prepare for least squares: optimize in the encoded (log) space
            lamvec_init = jnp.array([self.lam_t0, self.lam_pg0], dtype=jnp.float64)
            lamvec_init_log = self.encode(
                lamvec_init, self.lower_bounds, self.upper_bounds
            )

            # Perform optimization
            res_parm = least_squares(
                jax.jit(self.tikh_residual),
                lamvec_init_log,
                method="lm",
                jac=self.jacobian_lsq,
                args=(
                    self.a_tikh,
                    self.a_mat_t_a,
                    self.b_rhs,
                    self.d_ln_tau,
                    self.id_matrix,
                    self.lower_bounds,
                    self.upper_bounds,
                ),
            )

            final_lamvec = self.decode(res_parm.x, self.lower_bounds, self.upper_bounds)
            self.logger.info(
                "Final parameters: λT=%.2e, λPG=%.2e",
                float(final_lamvec[0]),
                float(final_lamvec[1]),
            )

            # Get final DRT distribution
            gfun_final, rpoly, n_iters = self.pg_solver(
                final_lamvec,
                self.a_tikh,
                self.a_mat_t_a,
                self.b_rhs,
                self.d_ln_tau,
                self.id_matrix,
            )
            n_iters = int(n_iters)
            # Record whether the solver stopped on its iteration cap (the
            # same cap pg_solver passes as `maxiter`); previously this flag
            # was never set, so the warning below could not fire.
            self.flagiter = 1 if n_iters >= self.niter * 1000 else 0

            self.logger.info("Completed projected gradient in %s iterations", n_iters)
            self.logger.info(
                "Rpol = %.6g, Final rpoly = %.6g", float(self.rpol), float(rpoly)
            )

            # Validate early so an unusable solution fails before any
            # further analysis is attempted.
            gamma = np.array(gfun_final)
            if np.any(np.isnan(gamma)) or np.any(np.isinf(gamma)):
                raise ValueError("Invalid values detected in final DRT solution")

            # Calculate residuals
            res_init, _lhs_init = self.tikh_residual_norm(
                self.gfun_init, self.lam_t0, self.a_mat_t_a, self.b_rhs, self.id_matrix
            )
            res_fin, _lhs_fin = self.tikh_residual_norm(
                gfun_final, final_lamvec[0], self.a_mat_t_a, self.b_rhs, self.id_matrix
            )
            self.logger.info(
                "Residuals: initial = %.6e, final = %.6e",
                float(res_init),
                float(res_fin),
            )

            if res_parm.status > 0:
                self.logger.info(
                    "Optimization successful: %s Jacobian evaluations", res_parm.njev
                )
            if self.flagiter == 1:
                self.logger.warning(
                    "Maximum iteration limit reached in projected gradient"
                )

            # Find peaks
            peak_params = self.rpol_peaks(gfun_final)
            if peak_params.size > 0:
                peak_freqs = [float(f) for f in peak_params[0]]
                peak_pols = [float(p) for p in peak_params[1]]
                self.logger.info("Identified DRT peaks:")
                for f, p in zip(peak_freqs, peak_pols):
                    self.logger.info(" f = %.2f Hz, polarization = %.5f", f, p)
            else:
                self.logger.warning("No peaks detected in DRT spectrum")
                peak_freqs, peak_pols = [], []

            # Convert JAX arrays to numpy arrays
            tau = np.array(self.tau)

            # Calculate model impedance.
            # NOTE(review): z_model_imre returns a real array holding only
            # the component selected by `mode`, so Z_fit.imag is all zeros
            # in the residuals below — confirm this is what consumers expect.
            Z_fit = np.array(self.z_model_imre(gfun_final))

            # Calculate normalized residuals
            residuals_real, residuals_imag = self.compute_normalized_residuals(Z_fit)

            # Create result object.
            # NOTE(review): `residual` receives rpoly (sum of gamma * d_ln_tau),
            # not res_fin — kept as in the original.
            result = DRTResult(
                tau=tau,
                gamma=gamma,
                peak_frequencies=peak_freqs,
                peak_polarizations=peak_pols,
                regularization_param=float(final_lamvec[0]),
                residual=float(rpoly),
                Z_fit=Z_fit,
                residuals_real=residuals_real,
                residuals_imag=residuals_imag,
            )
            self.logger.debug(
                "Created DRTResult object with %s peaks", len(peak_freqs)
            )
            return result

        except Exception:
            # Top-level boundary: report the failure with traceback and
            # signal it to the caller with None.
            self.logger.error("DRT fitting failed", exc_info=True)
            return None

    def tikh_residual_norm(self, g_vector, lam_t, a_mat_t_a, b_rhs, id_matrix):
        """
        Compute norm of Tikhonov residual and LHS norm.

        .. math::
            \\|res\\| = \\|(A^TA + \\lambda_T I)\\vec{\\gamma} - A^T\\vec{b}\\|

            \\|lhs\\| = \\|(A^TA + \\lambda_T I)\\vec{\\gamma}\\|

        Parameters
        ----------
        g_vector : np.ndarray
            Current γ vector
        lam_t : float
            Tikhonov regularization parameter
        a_mat_t_a : np.ndarray
            A^T A matrix
        b_rhs : np.ndarray
            Right-hand side vector
        id_matrix : np.ndarray
            Identity matrix

        Returns
        -------
        Tuple[float, float]
            Residual norm and LHS norm
        """
        lhs_matrix = a_mat_t_a + lam_t * id_matrix
        lhs_prod = jnp.matmul(lhs_matrix, g_vector)
        sum_res = jnp.sqrt(jnp.sum((lhs_prod - b_rhs) ** 2))
        sum_lhs = jnp.sqrt(jnp.sum(lhs_prod**2))
        return sum_res, sum_lhs