Source code for multimodal_fin.embeddings.builder.node_encoder

import logging
import torch
import torch.nn as nn
from multimodal_fin.embeddings.builder.sentence_attention_encoder import SentenceAttentionEncoder
from multimodal_fin.utils.logging import get_logger

# Module-level logger, obtained from the project's `get_logger` helper and
# keyed by this module's dotted name.
logger = get_logger(__name__)

class NodeEncoder(nn.Module):
    """
    Encodes individual nodes in a conference tree using multimodal features and metadata.

    Attributes:
        device (str): Device on which every sub-module of the encoder is placed.
        d_output (int): Size of the final node embedding.
        meta_dim (int): Size of the projected metadata vector.
        frase_encoder (nn.Module): Encoder for sentence-level features using attention.
        meta_proj (nn.Linear): Projection layer for metadata features.
        output_proj (nn.Linear): Final projection layer to produce node embedding.
        categories_10k (List[str]): List of 10-K classification categories.
        qa_categories (List[str]): List of QA response categories.
        max_num_coherences (int): Maximum number of coherence entries per node.
        weights_path (str): Path the sentence-encoder weights were requested from.
    """

    def __init__(
        self,
        device: str = "cpu",
        input_dim: int = 21,  # 7 (text) + 7 (audio) + 7 (video)
        hidden_dim: int = 128,
        meta_dim: int = 32,
        d_output: int = 512,
        n_heads: int = 4,
        categories_10k: list = None,
        qa_categories: list = None,
        weights_path: str = "weights/node_encoder.pt",
    ):
        """
        Build the sentence encoder and the two projection layers.

        Args:
            device: Target device for all sub-modules (e.g. ``"cpu"``).
            input_dim: Per-sentence feature size fed to the sentence encoder.
            hidden_dim: Hidden size of the sentence encoder; also the sentence
                part of the input width of ``output_proj``.
            meta_dim: Output size of the metadata projection.
            d_output: Output size of the final node embedding.
            n_heads: Number of attention heads passed to the sentence encoder.
            categories_10k: 10-K categories. ``None`` (or any falsy value,
                including an empty list) selects the built-in defaults.
            qa_categories: QA response categories. ``None`` (or any falsy
                value) selects the built-in defaults.
            weights_path: File holding a state dict for the sentence encoder.
                A falsy value skips loading. A load failure is logged as a
                warning and the encoder keeps its freshly initialised weights.
        """
        super().__init__()
        self.device = device
        self.d_output = d_output
        self.meta_dim = meta_dim
        # `or` is kept on purpose: an empty list falls back to the defaults,
        # exactly as a missing argument does.
        self.categories_10k = categories_10k or ["MD&A", "Risk Factors", "Business", "Other"]
        self.qa_categories = qa_categories or ["yes", "no", "partially"]
        self.max_num_coherences = 5
        self.weights_path = weights_path

        self.frase_encoder = SentenceAttentionEncoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
        )

        if weights_path:
            # Best-effort load: a missing or incompatible checkpoint must not
            # prevent the encoder from being constructed.
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints (consider `weights_only=True` where the
            # installed torch version supports it).
            try:
                self.frase_encoder.load_state_dict(
                    torch.load(weights_path, map_location=device)
                )
                logger.info(f"✅ Weights loaded from {weights_path}")
            except Exception as e:
                logger.warning(f"❌ Failed to load weights from {weights_path}: {e}")
        else:
            logger.warning("⚠️ No pretrained weights path provided for NodeEncoder")

        self.meta_proj = nn.Linear(self._get_meta_input_size(), meta_dim)
        self.output_proj = nn.Linear(hidden_dim + meta_dim, d_output)

        # Move the whole module, not just the sentence encoder, so that the
        # projection layers live on the same device as `frase_encoder`.
        self.to(device)

    def _get_meta_input_size(self) -> int:
        """Returns the dimensionality of the metadata feature vector.

        The size is the sum of: one scalar slot, one slot per 10-K category,
        one more scalar slot, one slot per QA category, and two slots for each
        of the ``max_num_coherences`` coherence entries.
        """
        return (
            1
            + len(self.categories_10k)
            + 1
            + len(self.qa_categories)
            + 2 * self.max_num_coherences
        )