# Source code for partial_tagger.encoders.base
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING
from torch.nn import Module
if TYPE_CHECKING:
import torch
from sequence_label import LabelSet
class BaseEncoder(Module, metaclass=ABCMeta):
    """Abstract base class for all encoders.

    Concrete encoders must override both :meth:`get_hidden_size` and
    :meth:`forward`; because both are declared with ``@abstractmethod``,
    a subclass missing either one cannot be instantiated.
    """

    @abstractmethod
    def get_hidden_size(self) -> int:
        """Report the dimension size of the tensor produced by this encoder.

        Returns:
            The dimension size of the output tensor, i.e. the trailing
            ``hidden_size`` axis of the tensor returned by :meth:`forward`.
        """
        raise NotImplementedError()

    @abstractmethod
    def forward(self, inputs: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode ``inputs`` into a tensor representation.

        Args:
            inputs: A dictionary mapping string keys to tensor values.

        Returns:
            A ``[batch_size, sequence_length, hidden_size]`` float tensor.
        """
        raise NotImplementedError()
class BaseEncoderFactory(metaclass=ABCMeta):
    """Abstract base class for all encoder factories.

    A factory defers encoder construction until a label set is available.
    Subclasses must override :meth:`create`.
    """

    @abstractmethod
    def create(self, label_set: LabelSet) -> BaseEncoder:
        """Build an encoder for the given label set.

        Args:
            label_set: An instance of ``LabelSet``.

        Returns:
            An encoder that transforms input into a tensor representation.
        """
        raise NotImplementedError()