Source code for multimodal_fin.embeddings.builder.pipeline

import torch
import logging
from anytree import PreOrderIter

from multimodal_fin.embeddings.builder.feature_extractor import FeatureExtractor
from multimodal_fin.embeddings.builder.node_encoder import NodeEncoder
from multimodal_fin.embeddings.builder.conference_encoder import ConferenceEncoder

from multimodal_fin.embeddings.speech_tree.conference_tree_builder import ConferenceTreeBuilder

from multimodal_fin.embeddings.visualizer.conference_tree_visualizer import ConferenceTreeVisualizer
from multimodal_fin.embeddings.visualizer.tree_attention_visualizer import TreeAttentionVisualizer
from multimodal_fin.embeddings.visualizer.node_embeddings_visualizer import NodeEmbeddingVisualizer

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class ConferenceEmbeddingPipeline:
    """Orchestrates the generation and visualization of conference-level embeddings.

    The pipeline builds a conference tree from a JSON file, encodes every
    eligible leaf node into an embedding, and aggregates the stacked leaf
    embeddings with the conference encoder. The intermediate results of the
    most recent :meth:`generate_embedding` call are kept on the instance so
    that :meth:`visualize` can plot them.
    """

    # Leaf node types that are turned into node embeddings.
    _EMBEDDABLE_NODE_TYPES = frozenset({"monologue", "question", "answer"})

    def __init__(
        self,
        node_encoder_params: dict,
        conference_encoder_params: dict,
        device: str = "cpu"
    ):
        """Initializes the pipeline with encoders and feature extractor.

        Args:
            node_encoder_params (dict): Parameters for the NodeEncoder.
            conference_encoder_params (dict): Parameters for the ConferenceEncoder.
            device (str): Torch device to use ("cpu" or "cuda").
        """
        self.device = torch.device(device)
        self.node_encoder = NodeEncoder(self.device, **node_encoder_params).to(self.device)
        self.conference_encoder = ConferenceEncoder(self.device, **conference_encoder_params).to(self.device)
        # The extractor is configured from the node encoder's own settings so
        # both components agree on categories and coherence count.
        self.extractor = FeatureExtractor(
            categories_10k=self.node_encoder.categories_10k,
            qa_categories=self.node_encoder.qa_categories,
            max_num_coherences=self.node_encoder.max_num_coherences
        )

        # Per-conference state, filled in by generate_embedding(). Initialised
        # here so that visualize() can detect "nothing generated yet".
        self.root = None
        self._node_embeddings = []
        self._node_names = []
        self._node_types = []
        self._categories_10k = []
        self._attn_weights = None

        logger.info("ConferenceEmbeddingPipeline initialized.")

    def generate_embedding(self, json_path: str, return_attn: bool = False) -> torch.Tensor:
        """Generates the embedding for a given conference JSON.

        Every call replaces the state kept from the previous call (tree,
        node embeddings, names, types, categories and attention weights).

        Args:
            json_path (str): Path to the JSON file describing the conference.
            return_attn (bool): If True, the conference encoder is asked for its
                attention weights and they are stored on the instance for
                :meth:`visualize`. They are not part of the return value.

        Returns:
            torch.Tensor: Embedding vector for the full conference. If the tree
            has no eligible leaf node, a zero vector of length
            ``node_encoder.d_output`` on the pipeline device is returned.
        """
        builder = ConferenceTreeBuilder(json_path)
        self.root = builder.build_tree()

        self._node_embeddings = []
        self._node_names = []
        self._node_types = []
        self._categories_10k = []
        # Drop attention weights from any earlier conference; otherwise a later
        # visualize() could pair old weights with the new tree and node names.
        self._attn_weights = None

        for node in PreOrderIter(self.root):
            if not (node.is_leaf and node.node_type in self._EMBEDDABLE_NODE_TYPES):
                continue
            self._node_embeddings.append(self._encode_leaf(node))
            self._node_names.append(node.name)
            self._node_types.append(node.node_type)
            self._categories_10k.append(self._leaf_category(node))

        if not self._node_embeddings:
            logger.warning("No valid leaf nodes found for embedding computation.")
            # Allocate the fallback on the pipeline device, like the tensors
            # produced on the regular path.
            # NOTE(review): the length is the node encoder's d_output; confirm
            # it matches the conference encoder's output size.
            return torch.zeros(self.node_encoder.d_output, device=self.device)

        stacked = torch.stack(self._node_embeddings, dim=0)

        if return_attn:
            conference_embedding, attn_weights = self.conference_encoder(stacked, return_attn=True)
            self._attn_weights = attn_weights
            return conference_embedding

        return self.conference_encoder(stacked)

    def _encode_leaf(self, node) -> torch.Tensor:
        """Encodes one leaf node of the conference tree into a node embedding."""
        frases, mask, meta_vec = self.extractor.extract(node)
        frase_summary = self.node_encoder.frase_encoder(frases.to(self.device), mask.to(self.device))
        # unsqueeze(0) adds a leading dimension to the metadata vector; it is
        # removed again by squeeze(0) after the output projection.
        meta_tensor = torch.tensor(meta_vec, dtype=torch.float32, device=self.device).unsqueeze(0)
        meta_summary = self.node_encoder.meta_proj(meta_tensor)
        combined = torch.cat([frase_summary, meta_summary], dim=-1)
        return self.node_encoder.output_proj(combined).squeeze(0)

    @staticmethod
    def _leaf_category(node) -> str:
        """Returns the category label recorded for a leaf node.

        Monologue nodes are always labelled "None". For other nodes the label is
        read from ``metadata["classification"]["Predicted_category"]``, falling
        back to "None" when the classification entry is missing or empty.
        """
        if node.node_type == "monologue":
            return "None"
        # `or {}` also covers a classification key that is present but null.
        classification = node.metadata.get("classification") or {}
        return classification.get("Predicted_category", "None")

    def visualize(self, plots: dict = None):
        """Visualizes the results of the embedding process depending on selected plots.

        Args:
            plots (dict): Flags for which plots to generate. Recognised keys are
                "tree_structure", "plot", "silhouette", "umap" and
                "attention_tree"; missing keys count as disabled.

        Raises:
            RuntimeError: If no conference tree is available because
                :meth:`generate_embedding` has not been called yet.
        """
        plots = plots or {}

        # getattr keeps the guard working even on an instance whose __init__
        # did not run.
        if getattr(self, "root", None) is None:
            raise RuntimeError(
                "No conference tree available: call generate_embedding() before visualize()."
            )

        visualizer = ConferenceTreeVisualizer(self.root)

        if plots.get("tree_structure"):
            visualizer.show_text_tree()

        if plots.get("plot"):
            visualizer.show_networkx_tree()

        if any(plots.get(k, False) for k in ("silhouette", "umap")):
            embedding_visualizer = NodeEmbeddingVisualizer(
                embeddings=self._node_embeddings,
                node_names=self._node_names,
                node_types=self._node_types,
                categories_10k=self._categories_10k
            )
            if plots.get("silhouette"):
                embedding_visualizer.show_metrics()
            if plots.get("umap"):
                embedding_visualizer.show_umap()

        if plots.get("attention_tree"):
            attn_weights = getattr(self, "_attn_weights", None)
            if attn_weights is None:
                # Report the skipped plot instead of dropping the request silently.
                logger.warning(
                    "Attention plot requested but no attention weights are stored; "
                    "call generate_embedding(..., return_attn=True) first."
                )
            else:
                attention_viz = TreeAttentionVisualizer(
                    root=self.root,
                    node_names=self._node_names,
                    attn_weights=attn_weights
                )
                attention_viz.show()