Source code for multimodal_fin.embeddings.builder.transformer_encoder

import torch.nn.functional as F
from torch import Tensor
from torch.nn import (
    MultiheadAttention,
    LayerNorm,
    Dropout,
    Linear,
    Module
)
from typing import Optional, Tuple

from multimodal_fin.utils.logging import get_logger

logger = get_logger(__name__)


class TransformerEncoderLayer(Module):
    """
    Post-norm Transformer encoder block.

    The layer applies multi-head self-attention followed by a two-layer
    position-wise feedforward network. Each sub-layer is wrapped in a
    residual connection and followed by layer normalization. The per-head
    attention weights of the most recent forward pass are kept on the
    instance as ``attn_weights``.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """
        Build the sub-modules of the encoder layer.

        Args:
            d_model: Feature size of both the input and the output.
            nhead: Number of heads used by the self-attention module.
            dim_feedforward: Hidden width of the feedforward network.
            dropout: Dropout probability used inside attention, inside the
                feedforward network, and on both residual branches.
        """
        super().__init__()

        # Sub-modules are created in a fixed order so that parameter
        # initialization and state_dict key order stay stable.
        self.self_attn = MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True,
        )

        # Feedforward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        # One normalization and one dropout per residual branch.
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = F.gelu

        # Filled in by forward() with the latest per-head attention weights.
        self.attn_weights: Optional[Tensor] = None

        logger.info("✅ TransformerEncoderLayer initialized")

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Run the encoder layer on a batch of sequences.

        Args:
            src: Input tensor of shape [B, T, d_model].
            src_mask: Optional attention mask of shape [T, T] or
                [B * num_heads, T, T].
            src_key_padding_mask: Optional mask of shape [B, T] marking
                padding positions.

        Returns:
            Tensor of shape [B, T, d_model].
        """
        # Self-attention; weights are requested per head (not averaged) so
        # they can be inspected after the call.
        attended, per_head_weights = self.self_attn(
            src,
            src,
            src,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        self.attn_weights = per_head_weights  # [B, n_heads, T, T]

        # First residual branch followed by normalization.
        normed_attn = self.norm1(src + self.dropout1(attended))

        # Position-wise feedforward network.
        hidden = self.activation(self.linear1(normed_attn))
        projected = self.linear2(self.dropout(hidden))

        # Second residual branch followed by normalization.
        return self.norm2(normed_attn + self.dropout2(projected))