Source code for impedance_agent.fitters.ecm

# src/fitters/ecm.py
import jax
import jaxopt
import jax.numpy as jnp
import numpy as np
import logging
from typing import Optional, Tuple
from ..core.models import ImpedanceData, FitResult, FitQualityMetrics


class ECMFitter:
    """
    Equivalent Circuit Model (ECM) fitting for electrochemical impedance data.

    The model is fitted by minimizing a weighted residual sum of squares with
    BFGS. Parameter bounds are enforced through a smooth change of variables
    (see :meth:`encode` / :meth:`decode`), so the optimizer itself runs
    unconstrained. Real and imaginary parts are fitted jointly:
    ``model_func(p, freq)`` must return a real vector of length ``2N`` holding
    the ``N`` real parts followed by the ``N`` imaginary parts.

    The objective is the weighted residual sum of squares over both parts:

    .. math::

        WRSS = \\sum_{i=1}^N \\left[
            \\frac{(\\mathrm{Re}(Z_{\\mathrm{exp},i}) - \\mathrm{Re}(Z_{\\mathrm{model},i}))^2}{\\sigma_{\\mathrm{Re},i}^2}
            + \\frac{(\\mathrm{Im}(Z_{\\mathrm{exp},i}) - \\mathrm{Im}(Z_{\\mathrm{model},i}))^2}{\\sigma_{\\mathrm{Im},i}^2}
        \\right]

    Supported weighting schemes:

    .. math::

        (\\sigma_{\\mathrm{Re},i}, \\sigma_{\\mathrm{Im},i}) = \\begin{cases}
            (1, 1) & \\text{for unit weighting} \\\\
            (|\\mathrm{Re}(Z_{\\mathrm{exp},i})|, |\\mathrm{Im}(Z_{\\mathrm{exp},i})|) & \\text{for proportional weighting} \\\\
            (|Z_{\\mathrm{exp},i}|, |Z_{\\mathrm{exp},i}|) & \\text{for modulus weighting}
        \\end{cases}

    Alternatively, a custom array of standard deviations (one per frequency,
    applied to both the real and the imaginary part) may be supplied.

    Parameters
    ----------
    model_func : callable
        ``model_func(p, freq)`` returning the stacked real/imaginary impedance
        vector of length ``2 * len(freq)``. It is wrapped in ``jax.jit``.
    p0 : array_like
        Initial parameter values; must lie strictly between ``lb`` and ``ub``.
    freq : array_like
        Frequency values.
    impedance_data : ImpedanceData
        Object exposing the experimental ``real`` and ``imaginary`` parts.
    lb : array_like
        Lower bounds for parameters.
    ub : array_like
        Upper bounds for parameters (the bound transform requires ``ub > 0``;
        ``inf`` is allowed).
    param_info : list
        List of dictionaries describing the parameters; each needs a ``name`` key.
    weighting : str or array_like, optional
        Weighting scheme ('unit', 'proportional', 'modulus') or custom sigma
        array (default: 'modulus').
    """

    def __init__(
        self,
        model_func,
        p0,
        freq,
        impedance_data: "ImpedanceData",
        lb,
        ub,
        param_info,
        weighting="modulus",
    ):
        """
        Initialize ECM fitter with model and data.

        Parameters
        ----------
        model_func : callable
            ``model_func(p, freq)`` returning the stacked real/imaginary vector.
        p0 : array_like
            Initial parameter values.
        freq : array_like
            Frequency values.
        impedance_data : ImpedanceData
            Object containing experimental impedance data.
        lb : array_like
            Lower bounds for parameters.
        ub : array_like
            Upper bounds for parameters.
        param_info : list
            List of dictionaries containing parameter information.
        weighting : str or array_like, optional
            Weighting scheme ('unit', 'proportional', 'modulus') or custom
            weights (default: 'modulus').

        Raises
        ------
        ValueError
            If any lower bound exceeds its upper bound, or the weighting
            specification is invalid.
        TypeError
            If ``weighting`` is neither a string nor an array-like.
        """
        self.logger = logging.getLogger(__name__)
        # Enable double precision globally before any array below is created:
        # the log-space transform and the QR/Hessian uncertainty estimates
        # need float64.
        jax.config.update("jax_enable_x64", True)

        # Validate bounds. np.asarray makes the comparison element-wise even
        # for plain lists (list > list would be a single lexicographic bool).
        if np.any(np.asarray(lb, dtype=float) > np.asarray(ub, dtype=float)):
            raise ValueError("Lower bounds must be less than upper bounds")

        # Store data and setup
        self.impedance_data = impedance_data
        self.freq = jnp.asarray(freq)
        self.data = jnp.array(impedance_data.real) + 1j * jnp.array(
            impedance_data.imaginary
        )
        self.model = jax.jit(model_func)
        # Arrays (not lists) so encode/decode arithmetic works for any input.
        self.p0 = jnp.asarray(p0, dtype=float)
        self.lb = jnp.asarray(lb, dtype=float)
        self.ub = jnp.asarray(ub, dtype=float)
        self.num_params = len(self.p0)
        self.num_freq = len(self.freq)
        # Real and imaginary parts each contribute one data point per frequency.
        self.dof = 2 * self.num_freq - self.num_params
        if self.dof <= 0:
            self.logger.warning(
                "Non-positive degrees of freedom (%d): WRMS and parameter "
                "uncertainties will not be meaningful",
                self.dof,
            )
        self.param_info = param_info

        # Set up weighting
        self._setup_weighting(weighting)

    def _experimental_impedance(self) -> np.ndarray:
        """Return the experimental impedance as a complex NumPy array."""
        # np.asarray lets list-valued ``real``/``imaginary`` fields work too.
        real = np.asarray(self.impedance_data.real, dtype=float)
        imag = np.asarray(self.impedance_data.imaginary, dtype=float)
        return real + 1j * imag

    def compute_normalized_residuals(
        self, Z_fit: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute residuals normalized by the experimental impedance modulus.

        .. math::

            r_{real} = \\frac{Z_{fit,real} - Z_{exp,real}}{\\sqrt{Z_{exp,real}^2 + Z_{exp,imag}^2}}

            r_{imag} = \\frac{Z_{fit,imag} - Z_{exp,imag}}{\\sqrt{Z_{exp,real}^2 + Z_{exp,imag}^2}}

        Parameters
        ----------
        Z_fit : np.ndarray
            Complex array containing the fitted impedance values.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            ``(residuals_real, residuals_imag)``.
        """
        Z_fit = np.asarray(Z_fit)
        Z_exp = self._experimental_impedance()
        Z_mod = np.abs(Z_exp)

        residuals_real = (Z_fit.real - Z_exp.real) / Z_mod
        residuals_imag = (Z_fit.imag - Z_exp.imag) / Z_mod
        return residuals_real, residuals_imag

    def _setup_weighting(self, weighting):
        """
        Configure the per-point standard deviations used in the objective.

        Sets ``self.weighting_name`` and the arrays ``self.zerr_Re`` /
        ``self.zerr_Im`` (sigma for the real and imaginary parts).

        Parameters
        ----------
        weighting : str or array_like
            'unit', 'proportional' (|Re Z| and |Im Z| separately), 'modulus'
            (|Z| for both parts), or a custom sigma array with the same shape
            as the data (used for both parts).

        Raises
        ------
        ValueError
            Unknown weighting name, or custom array shape mismatch.
        TypeError
            ``weighting`` is neither a string nor an array-like.
        """
        if isinstance(weighting, (jnp.ndarray, np.ndarray, list, tuple)):
            self.logger.debug("Using custom sigma weighting array")
            self.weighting_name = "sigma"
            sigma = jnp.array(weighting)
            if self.data.shape != sigma.shape:
                raise ValueError("Shape mismatch between data and weight array")
            self.zerr_Re = sigma
            self.zerr_Im = sigma
        elif isinstance(weighting, str):
            name = weighting.lower()
            if name not in ("unit", "proportional", "modulus"):
                # Explicit raise: an ``assert`` would vanish under ``python -O``
                # and silently fall through to modulus weighting.
                raise ValueError(f"Invalid weighting type: {weighting}")
            self.weighting_name = name
            self.logger.info("Using %s weighting", self.weighting_name)

            if name == "unit":
                self.zerr_Re = jnp.ones(self.num_freq)
                self.zerr_Im = jnp.ones(self.num_freq)
            elif name == "proportional":
                self.zerr_Re = jnp.abs(self.data.real)
                self.zerr_Im = jnp.abs(self.data.imag)
            else:  # modulus weighting
                self.zerr_Re = jnp.abs(self.data)
                self.zerr_Im = jnp.abs(self.data)
        else:
            raise TypeError(
                f"weighting must be a string or an array, got {type(weighting).__name__}"
            )

        # A zero sigma gives an infinite weight (e.g. proportional weighting
        # where Im(Z) crosses zero); surface it instead of failing silently.
        if bool(jnp.any(self.zerr_Re == 0) | jnp.any(self.zerr_Im == 0)):
            self.logger.warning(
                "Weighting contains zero standard deviations; the objective "
                "will contain infinite weights"
            )

    def encode(self, p: jnp.ndarray) -> jnp.ndarray:
        """
        Convert external (bounded) parameters to internal (unbounded) ones.

        .. math::

            p_{int} = \\log_{10}\\left(\\frac{p - lb}{1 - p/ub}\\right)

        The argument of the logarithm is positive only for ``lb < p < ub``
        with ``ub > 0``; values outside that range yield NaN/inf.

        Parameters
        ----------
        p : jnp.ndarray
            External parameters.

        Returns
        -------
        jnp.ndarray
            Internal parameters.
        """
        return jnp.log10((p - self.lb) / (1 - p / self.ub))

    def decode(self, p: jnp.ndarray) -> jnp.ndarray:
        """
        Convert internal parameters back to external parameters.

        Exact inverse of :meth:`encode`; maps any real input into ``(lb, ub)``.

        .. math::

            p_{ext} = \\frac{lb + 10^p}{1 + 10^p/ub}

        Parameters
        ----------
        p : jnp.ndarray
            Internal parameters.

        Returns
        -------
        jnp.ndarray
            External parameters.
        """
        return (self.lb + 10**p) / (1 + 10**p / self.ub)

    def obj_fun(self, p_log: jnp.ndarray) -> float:
        """
        Calculate the weighted residual sum of squares (WRSS).

        Parameters
        ----------
        p_log : jnp.ndarray
            Parameters in the internal (transformed) space.

        Returns
        -------
        float
            Weighted residual sum of squares over real and imaginary parts.
        """
        p_norm = self.decode(p_log)
        # Stack real over imaginary to match the model's output layout.
        z_concat = jnp.concatenate([self.data.real, self.data.imag])
        sigma = jnp.concatenate([self.zerr_Re, self.zerr_Im])
        z_model = self.model(p_norm, self.freq)
        return jnp.sum((1 / sigma**2) * (z_concat - z_model) ** 2)

    def compute_aic(self, wrss: float) -> float:
        """
        Compute the Akaike Information Criterion (AIC).

        ``N`` is the number of frequencies (``2N`` real data points) and ``k``
        the number of model parameters; ``w_i = 1/sigma_i^2`` runs over both
        the real and imaginary parts.

        For unit weighting:

        .. math::

            \\mathrm{AIC} = 2N\\ln(2\\pi) - 2N\\ln(2N) + 2N + 2N\\ln(\\mathrm{WRSS}) + 2k

        For modulus/proportional weighting (an overall variance scale is
        estimated, hence ``k + 1``):

        .. math::

            \\mathrm{AIC} = 2N\\ln(2\\pi) - 2N\\ln(2N) + 2N - \\sum\\ln(w_i) + 2N\\ln(\\mathrm{WRSS}) + 2(k+1)

        For sigma weighting (known standard deviations):

        .. math::

            \\mathrm{AIC} = 2N\\ln(2\\pi) + \\sum\\ln(\\sigma_i^2) + \\mathrm{WRSS} + 2k

        Parameters
        ----------
        wrss : float
            Weighted residual sum of squares.

        Returns
        -------
        float
            Computed AIC value.
        """
        n_points = 2 * self.num_freq
        if self.weighting_name == "sigma":
            m2lnL = (
                n_points * jnp.log(2 * jnp.pi)
                + jnp.sum(jnp.log(self.zerr_Re**2))
                + jnp.sum(jnp.log(self.zerr_Im**2))
                + wrss
            )
            return m2lnL + 2 * self.num_params

        if self.weighting_name == "unit":
            # NOTE(review): unit weighting also estimates a variance scale from
            # WRSS (it is the w_i = 1 case of the formula below), yet counts
            # only k parameters. Kept as-is to preserve existing AIC values.
            m2lnL = (
                n_points * jnp.log(2 * jnp.pi)
                - n_points * jnp.log(n_points)
                + n_points
                + n_points * jnp.log(wrss)
            )
            return m2lnL + 2 * self.num_params

        wt_re = 1 / self.zerr_Re**2
        wt_im = 1 / self.zerr_Im**2
        m2lnL = (
            n_points * jnp.log(2 * jnp.pi)
            - n_points * jnp.log(n_points)
            + n_points
            - jnp.sum(jnp.log(wt_re))
            - jnp.sum(jnp.log(wt_im))
            + n_points * jnp.log(wrss)
        )
        return m2lnL + 2 * (self.num_params + 1)

    def _calculate_uncertainties(self, popt: jnp.ndarray, wrms: float) -> jnp.ndarray:
        """
        Calculate parameter uncertainties via QR decomposition of the Jacobian.

        With the weighted Jacobian ``J_w = QR`` the covariance is
        ``R^{-1} R^{-T} * WRMS``, so

        .. math::

            \\sigma_j = \\|R^{-1}_{j,:}\\| \\sqrt{WRMS}

        Parameters
        ----------
        popt : jnp.ndarray
            Optimal parameters (external space).
        wrms : float
            WRSS divided by the degrees of freedom.

        Returns
        -------
        jnp.ndarray
            Array of parameter uncertainties.
        """
        # Jacobian of the stacked (real, imag) model output w.r.t. parameters.
        jac = jax.jacfwd(self.model)(popt, self.freq)
        # Row scaling by 1/sigma via broadcasting; equivalent to multiplying
        # by diag(1/sigma) without building an N x N matrix.
        weighted = jnp.concatenate(
            [
                jac[: self.num_freq] / self.zerr_Re[:, None],
                jac[self.num_freq :] / self.zerr_Im[:, None],
            ],
            axis=0,
        )
        _, R = jnp.linalg.qr(weighted)
        R_inv = jnp.linalg.inv(R)
        return jnp.linalg.norm(R_inv, axis=1) * jnp.sqrt(wrms)

    def _calculate_correlation_matrix(self, popt: jnp.ndarray) -> jnp.ndarray:
        """
        Calculate the parameter correlation matrix from the objective's Hessian.

        .. math::

            C_{ij} = \\frac{H^{-1}_{ij}}{\\sqrt{H^{-1}_{ii}H^{-1}_{jj}}}

        The Hessian is taken in the internal (encoded) parameter space. Because
        :meth:`decode` is monotone element-wise, the correlations agree with
        those of the external parameters to first order at a stationary point.

        Parameters
        ----------
        popt : jnp.ndarray
            Optimal parameters (external space).

        Returns
        -------
        jnp.ndarray
            Correlation matrix.
        """
        hessian = jax.hessian(self.obj_fun)(self.encode(popt))

        # SVD-based pseudo-inverse with small singular values discarded.
        _, s, Vt = jnp.linalg.svd(hessian, full_matrices=False)
        rcond = jnp.finfo(s.dtype).eps * max(hessian.shape)
        s_inv = jnp.where(s > rcond * s[0], 1 / s, 0)
        # V diag(1/s) V^T: the Moore-Penrose inverse for a PSD Hessian, and a
        # PSD surrogate (inverse of |H|) otherwise.
        cov = (Vt.T * s_inv) @ Vt

        std = jnp.sqrt(jnp.diag(cov))
        return cov / jnp.outer(std, std)

    def calculate_fitted_impedance(self, parameters: jnp.ndarray) -> jnp.ndarray:
        """
        Calculate complex fitted impedance values from parameters.

        Parameters
        ----------
        parameters : jnp.ndarray
            Model parameters (external space).

        Returns
        -------
        jnp.ndarray
            Complex array of fitted impedance values.
        """
        z_model = self.model(parameters, self.freq)
        return z_model[: self.num_freq] + 1j * z_model[self.num_freq :]

    @staticmethod
    def _grade_deviation(value: float) -> str:
        """Map a relative deviation to a label: <5% excellent, <10% acceptable."""
        if value < 0.05:
            return "excellent"
        if value < 0.10:
            return "acceptable"
        # Also reached for NaN, since every comparison with NaN is False.
        return "poor"

    @staticmethod
    def _path_deviation(Z_fit: np.ndarray, Z_exp: np.ndarray) -> float:
        """Mean distance between unit step directions of fitted and measured paths."""
        dZ_exp = np.diff(Z_exp)
        dZ_fit = np.diff(Z_fit)
        mag_exp = np.abs(dZ_exp)
        mag_fit = np.abs(dZ_fit)
        # Zero-length steps (duplicate points) have no direction; skip them
        # instead of letting 0/0 turn the whole mean into NaN.
        valid = (mag_exp > 0) & (mag_fit > 0)
        if not np.any(valid):
            return float("nan")
        unit_fit = dZ_fit[valid] / mag_fit[valid]
        unit_exp = dZ_exp[valid] / mag_exp[valid]
        return float(np.mean(np.abs(unit_fit - unit_exp)))

    def compute_fit_quality_metrics(self, Z_fit: np.ndarray) -> "FitQualityMetrics":
        """
        Compute fit quality using vector difference and path deviation metrics.

        Vector difference (point-by-point agreement):

        .. math::

            \\mathrm{VD} = \\frac{1}{N}\\sum_{i=1}^N \\frac{|Z_{\\mathrm{fit},i} - Z_{\\mathrm{exp},i}|}{|Z_{\\mathrm{exp},i}|}

        Path deviation (trajectory agreement; steps of zero length are skipped):

        .. math::

            \\mathrm{PD} = \\frac{1}{N-1}\\sum_{i=1}^{N-1} \\left|\\frac{\\Delta Z_{\\mathrm{fit},i}}{|\\Delta Z_{\\mathrm{fit},i}|} - \\frac{\\Delta Z_{\\mathrm{exp},i}}{|\\Delta Z_{\\mathrm{exp},i}|}\\right|

        Each metric is graded excellent (<0.05), acceptable (<0.10) or poor;
        the overall grade is excellent only if both are, poor if either is.

        Parameters
        ----------
        Z_fit : np.ndarray
            Complex fitted impedance array.

        Returns
        -------
        FitQualityMetrics
            Vector difference, path deviation and their quality grades.
        """
        Z_fit = np.asarray(Z_fit)
        Z_exp = self._experimental_impedance()

        vector_diff = float(np.mean(np.abs(Z_fit - Z_exp) / np.abs(Z_exp)))
        vector_quality = self._grade_deviation(vector_diff)

        path_diff = self._path_deviation(Z_fit, Z_exp)
        path_quality = self._grade_deviation(path_diff)

        if vector_quality == "excellent" and path_quality == "excellent":
            overall_quality = "excellent"
        elif vector_quality == "poor" or path_quality == "poor":
            overall_quality = "poor"
        else:
            overall_quality = "acceptable"

        return FitQualityMetrics(
            vector_difference=vector_diff,
            vector_quality=vector_quality,
            path_deviation=path_diff,
            path_quality=path_quality,
            overall_quality=overall_quality,
        )

    def fit(self) -> Optional["FitResult"]:
        """
        Perform impedance fitting on the data.

        Steps:
        1. BFGS optimization in the internal (encoded) parameter space
        2. Uncertainty calculation via QR decomposition
        3. Fit quality metrics
        4. Correlation matrix
        5. Model selection metric (AIC)

        Returns
        -------
        Optional[FitResult]
            Optimized parameters and uncertainties, correlation matrix,
            goodness-of-fit metrics (chi-square, AIC, WRMS), fitted impedance
            and fit quality. Returns None if any step fails; the exception is
            logged with traceback.
        """
        try:
            self.logger.info("Starting ECM fitting")

            p_log = self.encode(self.p0)
            # encode() yields NaN/inf when p0 is not strictly inside the
            # bounds; the optimizer would otherwise run on garbage.
            if not bool(jnp.all(jnp.isfinite(p_log))):
                raise ValueError(
                    "Initial parameters must lie strictly between lb and ub"
                )
            self.logger.debug("Initial parameters (log scale): %s", p_log)

            solver = jaxopt.ScipyMinimize(method="BFGS", fun=jax.jit(self.obj_fun))
            sol = solver.run(p_log)
            if not bool(getattr(sol.state, "success", True)):
                self.logger.warning(
                    "Optimizer did not report convergence; results may be unreliable"
                )

            popt = self.decode(sol.params)
            wrss = sol.state.fun_val
            wrms = wrss / self.dof
            self.logger.info("Optimization complete: WRMS = %.6e", float(wrms))

            param_info_str = "\n".join(
                f"{p['name']}: {float(val):.6e}" for p, val in zip(self.param_info, popt)
            )
            self.logger.debug("Optimal parameters:\n%s", param_info_str)

            perr = self._calculate_uncertainties(popt, wrms)
            Z_fit = np.array(self.calculate_fitted_impedance(popt))
            aic = self.compute_aic(wrss)
            fit_quality_metrics = self.compute_fit_quality_metrics(Z_fit)
            correlation_matrix = np.array(self._calculate_correlation_matrix(popt))

            result = FitResult(
                parameters=popt.tolist(),
                errors=perr.tolist(),
                param_info=self.param_info,
                correlation_matrix=correlation_matrix,
                chi_square=float(wrss),
                aic=float(aic),
                wrms=float(wrms),
                dof=self.dof,
                Z_fit=Z_fit,
                fit_quality=fit_quality_metrics,
            )

            self.logger.info(
                "Fit metrics: χ² = %.6e, AIC = %.6e", float(wrss), float(aic)
            )
            self.logger.debug(
                "Parameter uncertainties:\n"
                + "\n".join(
                    f"{p['name']}: {float(err):.6e}"
                    for p, err in zip(self.param_info, perr)
                )
            )
            return result

        except Exception:
            # Deliberate best-effort boundary: callers receive None on failure.
            self.logger.error("ECM fitting failed", exc_info=True)
            return None