orcalib.rac.head_models
RACHeadProtocol
Bases: Protocol
forward
Forward pass of the model.

Args:
    input_embeddings (Tensor): The input embeddings.
    memory_embeddings (list[Tensor]): The embeddings of the memories.
    memory_labels (list[list[int]]): The labels of the memories.
    original_input (list[InputType]): The original (non-embedded) input examples.
    original_memories (list[list[InputType]]): The original (non-embedded) memories.

Returns:
    Tensor: The resulting logits.
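Because RACHeadProtocol is a Protocol, any object whose forward method accepts these five arguments and returns a logits tensor satisfies it structurally, without inheriting from it. The sketch below illustrates this under several assumptions that are not confirmed by this page. It re-declares the protocol locally with InputType approximated as Any. It assumes memory_embeddings[i] is a (num_memories, dim) tensor holding the memories retrieved for the i-th input, with memory_labels[i] giving their integer labels. MemoryVoteHead and its num_classes parameter are hypothetical and not part of orcalib. The head returns the log of a cosine-similarity-weighted vote over memory labels as its logits.

```python
from typing import Any, Protocol, runtime_checkable

import torch
import torch.nn.functional as F
from torch import Tensor


@runtime_checkable
class RACHeadProtocol(Protocol):
    # Local stand-in for orcalib.rac.head_models.RACHeadProtocol.
    # Assumption: InputType is approximated here as Any.
    def forward(
        self,
        input_embeddings: Tensor,
        memory_embeddings: list[Tensor],
        memory_labels: list[list[int]],
        original_input: list[Any],
        original_memories: list[list[Any]],
    ) -> Tensor: ...


class MemoryVoteHead(torch.nn.Module):
    """Hypothetical head: similarity-weighted vote over memory labels."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

    def forward(
        self,
        input_embeddings: Tensor,
        memory_embeddings: list[Tensor],
        memory_labels: list[list[int]],
        original_input: list[Any],
        original_memories: list[list[Any]],
    ) -> Tensor:
        rows = []
        for query, mem_emb, labels in zip(input_embeddings, memory_embeddings, memory_labels):
            # Cosine similarity between this input and each of its memories: shape (num_memories,).
            sims = F.cosine_similarity(mem_emb, query.unsqueeze(0).expand_as(mem_emb), dim=-1)
            weights = torch.softmax(sims, dim=0)
            # Add each memory's weight to the slot of its label.
            votes = torch.zeros(self.num_classes, dtype=weights.dtype, device=weights.device)
            index = torch.tensor(labels, dtype=torch.long, device=weights.device)
            votes.index_add_(0, index, weights)
            # Log of the vote distribution; clamp avoids log(0) for classes with no votes.
            rows.append(torch.log(votes.clamp_min(1e-9)))
        # Shape (batch, num_classes).
        return torch.stack(rows)


# Usage: two inputs, each with two retrieved memories.
head = MemoryVoteHead(num_classes=3)
x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
mems = [torch.tensor([[1.0, 0.0], [0.0, 1.0]]), torch.tensor([[0.0, 1.0], [0.0, 1.0]])]
labels = [[0, 1], [2, 2]]
logits = head(
    x,
    mems,
    labels,
    original_input=[None, None],
    original_memories=[[None, None], [None, None]],
)
```

The runtime_checkable isinstance check only verifies that a forward attribute exists, not that its signature matches. A static type checker is what enforces the full argument and return types.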