Source code for multimodal_fin.embeddings.builder.conference_encoder

import torch
import torch.nn as nn
from typing import Optional, Tuple

from multimodal_fin.embeddings.builder.transformer_encoder import TransformerEncoderLayer
from multimodal_fin.utils.logging import get_logger

logger = get_logger(__name__)


[docs] class ConferenceEncoder(nn.Module): """Encoder that aggregates node-level embeddings into a single conference-level embedding using a Transformer encoder with a [CLS] token and learned positional encodings.""" def __init__( self, device: str = "cpu", input_dim: int = 512, hidden_dim: int = 256, n_heads: int = 4, d_output: int = 512, max_nodes: int = 1000, weights_path: Optional[str] = None, ): """ Args: device: Device to run the model on. input_dim: Dimension of input node embeddings. hidden_dim: Hidden dimension of the Transformer. n_heads: Number of attention heads. d_output: Dimension of the output conference embedding. max_nodes: Max number of nodes to consider in a conference. weights_path: Optional path to a pretrained model checkpoint. """ super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, input_dim)) self.pos_embedding = nn.Embedding(max_nodes + 1, input_dim) self.encoder_layer = TransformerEncoderLayer( d_model=input_dim, nhead=n_heads, dim_feedforward=hidden_dim * 2 ) self.proj = nn.Linear(input_dim, d_output) if weights_path: try: state_dict = torch.load(weights_path, map_location=device) # Remove incompatible positional embeddings if 'pos_embedding.weight' in state_dict: del state_dict['pos_embedding.weight'] # logger.warning("⚠️ Skipped 'pos_embedding.weight' due to size mismatch.") self.load_state_dict(state_dict, strict=False) logger.info(f"✅ Weights loaded from {weights_path}") except Exception as e: logger.error(f"❌ Failed to load weights from {weights_path}: {e}") else: logger.warning("⚠️ No pretrained weights path provided for ConferenceEncoder")
[docs] def forward(self, node_embeddings: torch.Tensor, return_attn: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: node_embeddings: Tensor of shape [n_nodes, input_dim] return_attn: Whether to return attention weights from [CLS] token. Returns: Conference embedding of shape [1, d_output] Optionally, attention weights from [CLS] to all other nodes. """ n_nodes = node_embeddings.size(0) # Insert [CLS] token cls = self.cls_token.expand(1, -1, -1) # [1, 1, input_dim] input_seq = torch.cat([cls, node_embeddings.unsqueeze(0)], dim=1) # [1, n+1, input_dim] # Positional encoding pos_ids = torch.arange(n_nodes + 1, device=input_seq.device).unsqueeze(0) # [1, n+1] pos_emb = self.pos_embedding(pos_ids) input_seq = input_seq + pos_emb # [1, n+1, input_dim] # Transformer out = self.encoder_layer(input_seq) # [1, n+1, input_dim] if return_attn: attn_weights = self.encoder_layer.attn_weights # [1, n_heads, T, T] attn_from_cls = attn_weights[0, 0, 0, 1:].detach().cpu().numpy() # [n_nodes] return self.proj(out[:, 0, :]), attn_from_cls return self.proj(out[:, 0, :])