import json
from collections import deque
from dataclasses import dataclass
from typing import List, Literal, Tuple

import pandas as pd
from pydantic import BaseModel

from multimodal_fin.processing.basics import LLMClient, UncertaintyMixin
from multimodal_fin.processing.metadata.prompt_builder import PromptBuilder
class Category10K(BaseModel):
    """Pydantic schema for the LLM's structured reply when classifying SEC 10-K topics.

    The single ``category`` field is restricted to the four allowed labels.
    The model's JSON schema (``model_json_schema()``) can be passed to the LLM
    to constrain its output, and the model itself can validate a parsed reply.
    """

    # The only labels a 10-K classification may produce.
    category: Literal['Business', 'Risk Factors', 'MD&A', 'Other']
@dataclass
class SEC10KAnalyzer(UncertaintyMixin):
    """Classify intervention text into SEC 10-K categories using an LLM.

    A text is classified ``NUM_EVALUATIONS`` times, and the repeated
    predictions are aggregated by ``get_result_and_uncertainty`` (provided by
    ``UncertaintyMixin``) to obtain a final label plus an uncertainty estimate.
    """

    #: The name of the LLM model to use (passed to ``LLMClient``).
    model: str = "llama3"
    #: Number of times the classification is repeated to estimate uncertainty.
    NUM_EVALUATIONS: int = 10

    def __post_init__(self):
        # Created after dataclass init so the client uses the configured model.
        self.llm = LLMClient(self.model)

    def classify_text(self, text: str) -> str:
        """Classify a given text into one of the 10-K categories.

        The LLM is asked for structured output following ``Category10K``'s JSON
        schema. The reply is then validated against that same model, so only
        one of the allowed labels can be returned.

        Args:
            text (str): Text to classify.

        Returns:
            str: One of ['Business', 'Risk Factors', 'MD&A', 'Other'].

        Raises:
            json.JSONDecodeError: If the LLM reply is not valid JSON.
            pydantic.ValidationError: If the reply has no ``category`` field
                or its value is not one of the allowed labels.
        """
        messages = PromptBuilder.prompt_10k(text)
        response = self.llm.chat(messages, schema=Category10K.model_json_schema())
        # The schema is only a hint to the LLM; validate so the documented
        # return set actually holds instead of passing through arbitrary strings.
        return Category10K.model_validate(json.loads(response)).category

    def explain_other_category(self, text: str) -> str:
        """Ask the LLM to explain why a text was classified as 'Other'.

        Args:
            text (str): The text classified as 'Other'.

        Returns:
            str: Explanation generated by the LLM (free-form, no schema).
        """
        messages = PromptBuilder.explain_why_other(text)
        return self.llm.chat(messages)

    def get_pred(self, text: str) -> Tuple[str, float, List[str]]:
        """Predict the category for a text using repeated sampling for uncertainty.

        ``classify_text`` is called ``NUM_EVALUATIONS`` times up front. The
        predictions are then served, in order, to ``get_result_and_uncertainty``
        through a callable that ignores its argument.

        Args:
            text (str): Text to classify.

        Returns:
            Tuple[str, float, List[str]]: Whatever ``get_result_and_uncertainty``
            returns, described by this API as the most likely category, a
            confidence score, and the list of predictions.
        """
        predictions = deque(
            self.classify_text(text) for _ in range(self.NUM_EVALUATIONS)
        )
        # popleft() is O(1) (list.pop(0) is O(n)) and, like list.pop, raises
        # IndexError if more predictions are requested than were produced.
        return self.get_result_and_uncertainty(
            lambda _: predictions.popleft(),
            text,
            self.NUM_EVALUATIONS,
        )

    def classify_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Classify every row of a DataFrame by applying ``get_pred`` to its text.

        Args:
            df (pd.DataFrame): DataFrame with a 'text' column.

        Returns:
            pd.DataFrame: A copy of ``df`` with a 'classification' column holding
            the first element of ``get_pred`` for each row. The input DataFrame
            is left unmodified.
        """
        result = df.copy()
        result['classification'] = result['text'].apply(
            lambda t: self.get_pred(t)[0]
        )
        return result