Source code for pyshred.models.decoder_models.abstract_decoder
from abc import ABC, abstractmethod
import torch.nn as nn
[docs]
class AbstractDecoder(ABC, nn.Module):
    """
    Abstract base class for all decoder models.

    Decoders are initialized lazily: the constructor only sets up the
    underlying ``nn.Module`` and marks the decoder as uninitialized.
    :meth:`initialize` must be called (typically once ``input_size`` is
    known) before :meth:`forward` is used.

    Subclasses must implement :meth:`initialize`, :meth:`forward` and the
    :attr:`model_name` property.
    """

    def __init__(self):
        """
        Create an uninitialized decoder.

        Layer construction is deferred to :meth:`initialize` because
        ``input_size`` is typically only provided after model initialization.
        """
        super().__init__()  # initialize nn.Module before any attribute assignment
        self.is_initialized = False  # lazy initialization flag; set by `initialize`

    @abstractmethod
    def initialize(self, input_size):
        """
        Initialize the decoder model with its input size.

        The base implementation stores ``input_size`` and marks the decoder
        as initialized. Subclasses can call ``super().initialize(input_size)``
        to reuse this bookkeeping. This base implementation neither takes nor
        stores an output size. Concrete decoders that need one must accept it
        in their own ``initialize`` signature.

        Parameters:
        -----------
        input_size : int
            Size of the input features.
        """
        self.input_size = input_size
        self.is_initialized = True

    @abstractmethod
    def forward(self, x):
        """
        Forward pass through the decoder model.

        The base implementation only verifies that :meth:`initialize` has
        been called and returns ``None``. Subclasses can call
        ``super().forward(x)`` to reuse this check before doing their own
        computation.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_size).

        Returns:
        --------
        torch.Tensor
            Output tensor of shape (batch_size, output_size).

        Raises:
        -------
        RuntimeError
            If the decoder has not been initialized yet.
        """
        if not self.is_initialized:
            raise RuntimeError("The decoder model is not initialized. Call `initialize` first.")

    @property
    @abstractmethod
    def model_name(self):
        """
        Name of the decoder model.

        Returns:
        --------
        str
            The name of the model.
        """