Source code for multimodal_fin.processing.preprocessing.monologue_classifier

from dataclasses import dataclass
from typing import Literal, Tuple

import json
import pandas as pd
from pydantic import BaseModel

from multimodal_fin.processing.basics import LLMClient, UncertaintyMixin
from multimodal_fin.processing.metadata.prompt_builder import PromptBuilder

import logging
logger = logging.getLogger(__name__)


[docs] class CategoryPresentation(BaseModel): """Pydantic schema for monologue classification output.""" category: Literal['Monologue', 'Procedure']
[docs] @dataclass class MonologueClassifier(UncertaintyMixin): """Classifier for identifying monologue categories using an LLM and uncertainty estimation.""" model: str = 'llama3' NUM_EVALUATIONS: int = 5 output_path: str = 'output' def __post_init__(self) -> None: """Initializes the LLM client with the specified model.""" self.llm = LLMClient(self.model) logger.info(f"Initialized MonologueClassifier with model: {self.model}")
[docs] def classify_text(self, text: str) -> str: """Classifies a single text as either 'Monologue' or 'Procedure'. Args: text: Input string representing a conference intervention. Returns: Category string: 'Monologue' or 'Procedure'. """ logger.debug(f"Classifying text: {text[:80]}...") messages = PromptBuilder.prompt_monologue(text) response = self.llm.chat(messages, schema=CategoryPresentation.model_json_schema()) try: category = json.loads(response)['category'] logger.debug(f"Predicted category: {category}") return category except (json.JSONDecodeError, KeyError) as e: logger.error(f"Failed to parse response: {response}", exc_info=True) raise ValueError(f"Invalid response from LLM: {response}") from e
[docs] def get_pred(self, text: str) -> Tuple[str, float]: """Performs multiple evaluations and computes uncertainty score. Args: text: Input intervention text. Returns: A tuple of (predicted_category, confidence_score). """ logger.info("Running ensemble prediction...") predicted_categories = [self.classify_text(text) for _ in range(self.NUM_EVALUATIONS)] return self.get_result_and_uncertainty( predict_fn=lambda _: predicted_categories.pop(0), text=text, n=self.NUM_EVALUATIONS )
[docs] def classify_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: """Classifies all interventions in a DataFrame. Args: df: DataFrame containing a 'text' column with interventions. Returns: DataFrame with an added 'classification' column. """ logger.info("Classifying monologue dataframe...") df['classification'] = df['text'].apply(lambda text: self.get_pred(text)[0]) return df