Source code for partial_tagger.tagger

from __future__ import annotations

from typing import TYPE_CHECKING, cast

import torch
from sequence_classifier.crf import BaseCrfDistribution, Crf
from torch.nn import Module, Parameter

if TYPE_CHECKING:
    from partial_tagger.encoders.base import BaseEncoder


[docs]class SequenceTagger(Module): """A sequence tagging model with a CRF layer. Args: encoder: An encoder module. decoder: A decoder module. Attributes: encoder: An encoder module. crf: A CRF layer. decoder: A decoder module. """ def __init__( self, encoder: BaseEncoder, padding_index: int, start_states: tuple[bool, ...] | None = None, end_states: tuple[bool, ...] | None = None, transitions: tuple[tuple[bool, ...], ...] | None = None, ): super().__init__() self.encoder = encoder self.crf = Crf(encoder.get_hidden_size(), padding_index=padding_index) self.start_constraints = ( Parameter(~torch.tensor(start_states), requires_grad=False) if start_states is not None else None ) self.end_constraints = ( Parameter(~torch.tensor(end_states), requires_grad=False) if end_states is not None else None ) self.transition_constraints = ( Parameter(~torch.tensor(transitions), requires_grad=False) if transitions is not None else None )
[docs] def forward( self, inputs: dict[str, torch.Tensor], mask: torch.Tensor, constrain: bool = False, ) -> BaseCrfDistribution: """Computes log potentials and tag sequence. Args: inputs: An inputs representing input data feeding into the encoder module. mask: A [batch_size, sequence_length] boolean tensor. Returns: A pair of a [batch_size, sequence_length, num_tags, num_tags] float tensor and a [batch_size, sequence_length] integer tensor. The float tensor representing log potentials and the integer tensor representing tag sequence. """ if constrain: dist = self.crf( logits=self.encoder(inputs), mask=mask, start_constraints=self.start_constraints, end_constraints=self.end_constraints, transition_constraints=self.transition_constraints, ) else: dist = self.crf(logits=self.encoder(inputs), mask=mask) return cast(BaseCrfDistribution, dist)
def predict( self, inputs: dict[str, torch.Tensor], mask: torch.Tensor ) -> torch.Tensor: dist = self(inputs=inputs, mask=mask, constrain=True) return cast(BaseCrfDistribution, dist).argmax