Source code for pytagi.nn.attention
from typing import Optional

import cutagi

from pytagi.nn.base_layer import BaseLayer
class MultiheadAttention(BaseLayer):
    """
    Implements a **Multi-head Attention layer** with uncertainty quantification.

    This layer applies scaled dot-product attention with multiple attention
    heads, allowing the model to jointly attend to information from different
    representation subspaces. It inherits from BaseLayer.
    """
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        bias: bool = True,
        gain_weight: float = 1.0,
        gain_bias: float = 1.0,
        init_method: str = "Xavier",
    ):
"""
Initializes the MultiheadAttention layer.
Args:
embed_dim: The dimensionality of the input embeddings and output.
num_heads: The number of attention heads.
num_kv_heads: The number of key-value heads for grouped-query attention.
If None, defaults to num_heads (standard multi-head attention).
bias: If True, additive bias is included in the linear projections.
Defaults to True.
gain_weight: Scaling factor applied to initialized weights. Defaults to 1.0.
gain_bias: Scaling factor applied to initialized biases. Defaults to 1.0.
init_method: The method used for initializing weights and biases
(e.g., "Xavier", "He"). Defaults to "Xavier".
"""
        super().__init__()
        if num_kv_heads is None:
            num_kv_heads = num_heads

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.bias = bias
        self.gain_weight = gain_weight
        self.gain_bias = gain_bias
        self.init_method = init_method

        self._cpp_backend = cutagi.MultiheadAttention(
            embed_dim,
            num_heads,
            num_kv_heads,
            bias,
            gain_weight,
            gain_bias,
            init_method,
        )
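
        # Illustrative note (assumed values, not from the source): with
        # embed_dim=64, num_heads=8, and num_kv_heads=2, each query head has
        # dimension 64 // 8 = 8, and every 8 // 2 = 4 query heads share one
        # key/value head -- the grouped-query attention scheme that the
        # ``num_kv_heads`` argument enables.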
    def get_layer_info(self) -> str:
        """
        Retrieves a descriptive string containing information about the layer's
        configuration from the C++ backend.
        """
        return self._cpp_backend.get_layer_info()
    def get_layer_name(self) -> str:
        """
        Retrieves the name of the layer from the C++ backend.
        """
        return self._cpp_backend.get_layer_name()
    def init_weight_bias(self):
        """
        Initializes the layer's parameters for the query, key, and value
        projections using the specified initialization method and gain factors.
        This task is delegated to the C++ backend.
        """
        self._cpp_backend.init_weight_bias()
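

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal construction example using only the constructor and accessors
# defined above; the forward/backward API lives in ``BaseLayer`` and the C++
# backend and is not shown here. The dimensions are arbitrary examples.
if __name__ == "__main__":
    # Standard multi-head attention: one key/value head per query head.
    mha = MultiheadAttention(embed_dim=128, num_heads=8)

    # Grouped-query attention: 8 query heads sharing 2 key/value heads.
    gqa = MultiheadAttention(embed_dim=128, num_heads=8, num_kv_heads=2)

    # (Re)initialize parameters with the configured method and gain factors.
    gqa.init_weight_bias()

    print(mha.get_layer_name())
    print(gqa.get_layer_info())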