Source code for partial_tagger.encoders.base

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING

from torch.nn import Module

if TYPE_CHECKING:
    import torch
    from sequence_label import LabelSet


class BaseEncoder(Module, metaclass=ABCMeta):
    """Base class for all encoders.

    Subclasses must implement both :meth:`forward` and
    :meth:`get_hidden_size`. Both are declared with ``@abstractmethod``
    under ``ABCMeta``, so a subclass that omits either one cannot be
    instantiated.
    """

    @abstractmethod
    def forward(self, inputs: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encodes the given inputs to a tensor representation.

        Args:
            inputs: A dictionary that maps string keys to tensor values.

        Returns:
            A [batch_size, sequence_length, hidden_size] float tensor.
        """
        raise NotImplementedError

    @abstractmethod
    def get_hidden_size(self) -> int:
        """Returns the dimension size of the output tensor.

        Returns:
            The size of the last (``hidden_size``) dimension of the tensor
            returned by :meth:`forward`.
        """
        raise NotImplementedError
class BaseEncoderFactory(metaclass=ABCMeta):
    """Abstract base for every encoder factory.

    A concrete factory overrides :meth:`create`; the abstract declaration
    (via ``ABCMeta``) prevents instantiating a factory that does not.
    """

    @abstractmethod
    def create(self, label_set: LabelSet) -> BaseEncoder:
        """Builds an encoder from the supplied label set.

        Args:
            label_set: The LabelSet instance to build the encoder for.

        Returns:
            An encoder that maps inputs to a tensor representation.
        """
        # Concrete factories must override this; reaching the base
        # implementation is always an error.
        raise NotImplementedError()