Source code for partial_tagger.recognizer
from __future__ import annotations
from typing import TYPE_CHECKING
from torch.utils.data import DataLoader
if TYPE_CHECKING:
from collections.abc import Sequence
import torch
from sequence_label import LabelAlignment, LabelSet, SequenceLabel
from partial_tagger.data.collators import BaseCollator, Batch
from partial_tagger.tagger import SequenceTagger
class Recognizer:
    """Turns raw texts into character-based tags using a trained tagger.

    The recognizer batches the given texts, runs the tagger on every batch
    and decodes the predicted tag indices with the label set.

    Args:
        tagger: A trained SequenceTagger used for prediction.
        collator: An instance of a BaseCollator subclass; it is handed to the
            data loader as ``collate_fn`` to encode texts into tensors.
        label_set: A LabelSet used to decode predicted tag indices.
    """

    def __init__(
        self,
        tagger: SequenceTagger,
        collator: BaseCollator,
        label_set: LabelSet,
    ):
        # Double-underscore names are mangled to _Recognizer__*, keeping
        # these collaborators private to the class.
        self.__tagger = tagger
        self.__collator = collator
        self.__label_set = label_set

    def __call__(
        self, texts: tuple[str, ...], batch_size: int, device: torch.device
    ) -> tuple[SequenceLabel, ...]:
        """Predicts character-based tags for every given text.

        Calling the recognizer switches the wrapped tagger to evaluation
        mode and moves it to ``device``.

        Args:
            texts: A tuple of input texts.
            batch_size: The number of texts processed per batch.
            device: The device on which prediction is performed.

        Returns:
            A tuple of decoded predictions, one per input text, in the same
            order as ``texts``.
        """
        # shuffle=False keeps batches in input order, so the collected
        # predictions line up with ``texts``.
        loader: Sequence[tuple[Batch, tuple[LabelAlignment, ...]]] = DataLoader(
            texts,  # type: ignore
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.__collator,  # type: ignore
        )

        model = self.__tagger.eval().to(device)
        label_set = self.__label_set

        results: list[SequenceLabel] = []
        for encoded, alignments in loader:
            encoded = encoded.to(device)
            indices = model.predict(encoded.tagger_inputs, encoded.mask)
            decoded = label_set.decode(
                tag_indices=indices.tolist(), alignments=alignments
            )
            results.extend(decoded)

        return tuple(results)