Source code for impedance_agent.agent.base

# impedance_agent/agent/base.py
import ast
import json
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional

from ..core.models import ImpedanceData
from ..agent.tools.fitter_tools import FitterTools


class BaseAgent(ABC):
    """Abstract base class for provider-specific agents.

    Subclasses supply the provider-specific pieces: system/user prompts,
    tool definitions, message formatting and the chat-completion API call.
    This base class sets up a logger and a shared ``FitterTools`` instance
    and implements tool-call dispatch, which is common to all providers.
    """

    # Tool name the LLM may request -> (FitterTools method name, log label).
    # Method names are resolved lazily with getattr at call time.
    _TOOL_DISPATCH = {
        "fit_ecm": ("run_ecm_fit", "ECM fit"),
        "fit_drt": ("run_drt_fit", "DRT analysis"),
        "fit_linkk": ("run_linkk_fit", "Lin-KK validation"),
    }

    def __init__(self):
        """Initialize base agent components."""
        self.logger = logging.getLogger(__name__)
        self.fitter_tools = FitterTools()
        # Built last so a subclass's _get_tools() may use the attributes above.
        self.tools = self._get_tools()

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        """Get provider-specific system prompt."""

    @abstractmethod
    def get_user_prompt(self, data: ImpedanceData, model_config: Optional[Dict]) -> str:
        """Get provider-specific user prompt.

        Args:
            data: Impedance data to analyze.
            model_config: Optional ECM configuration.

        Returns:
            Formatted prompt string.
        """

    @abstractmethod
    def _get_tools(self) -> List[Dict[str, Any]]:
        """Get provider-specific tool definitions.

        Returns:
            List of tool definitions in provider-specific format.
        """

    @abstractmethod
    def _format_message(self, message: Dict) -> Dict:
        """Format a message according to provider requirements.

        Args:
            message: Message to format.

        Returns:
            Formatted message for the specific provider.
        """

    @abstractmethod
    def create_chat_completion(
        self,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ) -> Any:
        """Make the API call to the provider.

        Args:
            messages: List of conversation messages.
            tools: List of available tools, or ``None``.
            tool_choice: How tools should be selected.

        Returns:
            Provider-specific response object.
        """

    def _parse_tool_arguments(self, arguments: Any) -> Any:
        """Decode raw tool-call arguments into keyword arguments.

        Args:
            arguments: Arguments as delivered by the provider. A string is
                decoded; ``None`` and blank strings mean "no arguments";
                any other value is returned unchanged.

        Returns:
            A dict for string/``None`` input, otherwise ``arguments`` as-is.

        Raises:
            ValueError, SyntaxError (or another ``ast.literal_eval`` error):
                If a string is neither valid JSON nor a Python literal.
            TypeError: If a string decodes to something other than a dict.
        """
        if arguments is None:
            return {}
        if not isinstance(arguments, str):
            return arguments
        if not arguments.strip():
            return {}

        try:
            parsed = json.loads(arguments)
        except json.JSONDecodeError:
            # These arguments come from the LLM, i.e. untrusted input, so they
            # must never reach eval(). ast.literal_eval only evaluates
            # literals; it keeps Python-style payloads (single quotes,
            # True/False/None) working as they did under the old eval().
            try:
                parsed = ast.literal_eval(arguments)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
                self.logger.error("Error parsing arguments: %s", e)
                self.logger.error("Raw arguments: %s", arguments)
                raise

        if not isinstance(parsed, dict):
            self.logger.error("Error parsing arguments: expected an object")
            self.logger.error("Raw arguments: %s", arguments)
            raise TypeError(
                f"Tool arguments must decode to an object, got {type(parsed).__name__}"
            )

        self.logger.debug("Parsed arguments: %s", parsed)
        return parsed

    def _handle_tool_call(self, tool_call: Any, data: ImpedanceData) -> Dict[str, Any]:
        """Handle a tool call; shared by all providers.

        Args:
            tool_call: Tool call object from the LLM. It must expose
                ``function.name`` and ``function.arguments``.
            data: Impedance data to analyze.

        Returns:
            Result of the matching ``FitterTools`` method.

        Raises:
            ValueError: If the tool is not recognized, or its string
                arguments cannot be parsed.
            TypeError: If string arguments decode to a non-object.
            Exception: Any error raised by the fitter is logged and re-raised.
        """
        tool_name = tool_call.function.name
        arguments = self._parse_tool_arguments(tool_call.function.arguments)

        try:
            method_name, label = self._TOOL_DISPATCH[tool_name]
        except KeyError:
            # Raised outside the execution try-block below so an unknown
            # tool is logged once, not again as an "execution error".
            error_msg = f"Unknown tool: {tool_name}"
            self.logger.error(error_msg)
            raise ValueError(error_msg) from None

        try:
            runner = getattr(self.fitter_tools, method_name)
            self.logger.debug("Starting %s", label)
            result = runner(data, **arguments)
            self.logger.debug("%s completed", label)
            return result
        except Exception as e:
            self.logger.exception("Error in tool execution: %s", e)
            raise