Source code for multimodal_fin.processing.preprocessing.qa_classifier

from dataclasses import dataclass
from typing import Literal, Tuple

import json
import pandas as pd
from pydantic import BaseModel

from multimodal_fin.processing.basics import LLMClient, UncertaintyMixin
from multimodal_fin.processing.metadata.prompt_builder import PromptBuilder

from multimodal_fin.utils.logging import get_logger

logger = get_logger(__name__)


[docs] class CategoryQA(BaseModel): """Pydantic schema for Q&A classification output.""" category: Literal['Question', 'Answer', 'Procedure']
[docs] @dataclass class QAClassifier(UncertaintyMixin): """ Classifier for identifying Q&A intervention types using an LLM with uncertainty estimation. """ model: str = "llama3" """LLM model name.""" NUM_EVALUATIONS: int = 5 """Number of times to sample the model for uncertainty estimation.""" model: str = 'llama3' NUM_EVALUATIONS: int = 5 def __post_init__(self) -> None: """Initializes the LLM client.""" self.llm = LLMClient(self.model) logger.info(f"Initialized QAClassifier with model: {self.model}")
[docs] def classify_text(self, text: str) -> str: """Classifies a single intervention as 'Question', 'Answer' or 'Procedure'. Args: text: The input intervention text. Returns: A string category predicted by the LLM. """ logger.debug(f"Classifying QA text: {text[:80]}...") messages = PromptBuilder.prompt_qa(text) response = self.llm.chat(messages, schema=CategoryQA.model_json_schema()) try: category = json.loads(response)['category'] logger.debug(f"Predicted category: {category}") return category except (json.JSONDecodeError, KeyError) as e: logger.error(f"Failed to parse response: {response}", exc_info=True) raise ValueError(f"Invalid response from LLM: {response}") from e
[docs] def get_pred(self, text: str) -> Tuple[str, float]: """Performs multiple classifications and computes uncertainty score. Args: text: Input intervention text. Returns: A tuple of (predicted_category, confidence_score). """ logger.info("Running ensemble prediction for QA text...") predicted_categories = [self.classify_text(text) for _ in range(self.NUM_EVALUATIONS)] return self.get_result_and_uncertainty( predict_fn=lambda _: predicted_categories.pop(0), text=text, n=self.NUM_EVALUATIONS )
[docs] def classify_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: """Classifies all interventions in a DataFrame. Args: df: DataFrame containing a 'text' column. Returns: DataFrame with an added 'classification' column. """ logger.info("Classifying QA dataframe...") df['classification'] = df['text'].apply(lambda text: self.get_pred(text)[0]) return df