orcalib.rac.head_models

RACHeadProtocol

RACHeadProtocol(config)

Bases: Protocol

forward

forward(
    input_embeddings,
    memory_embeddings,
    memory_labels,
    original_input,
    original_memories,
)

Forward pass of the model.

Args:
    input_embeddings (Tensor): The input embeddings.
    memory_embeddings (list[Tensor]): The embeddings of the memories.
    memory_labels (list[list[int]]): The labels of the memories.
    original_input (list[InputType]): The original (non-embedded) input examples.
    original_memories (list[list[InputType]]): The original (non-embedded) memories.

Returns:
    Tensor: The result logits.
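
As a rough illustration of the call signature documented above, the following sketch shows one way an object could satisfy it. This is not orcalib code: `HeadLike` and `SimilarityVoteHead` are hypothetical names, the `num_classes` constructor argument is an assumption (the documented constructor takes `config`), and plain nested Python lists stand in for `Tensor` so the example runs without any dependencies. It also assumes one list of memories per input example, which the documented types suggest but do not state outright.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HeadLike(Protocol):
    """Hypothetical structural type mirroring the documented forward signature."""

    def forward(
        self,
        input_embeddings,
        memory_embeddings,
        memory_labels,
        original_input,
        original_memories,
    ): ...


class SimilarityVoteHead:
    """Illustrative head: each memory votes for its label, weighted by dot product."""

    def __init__(self, num_classes: int) -> None:
        # Assumption: the number of classes is known up front.
        self.num_classes = num_classes

    def forward(
        self,
        input_embeddings,
        memory_embeddings,
        memory_labels,
        original_input,
        original_memories,
    ):
        # The original (non-embedded) inputs and memories are accepted to match
        # the signature but are not used by this simple head.
        logits = []
        for query, memories, labels in zip(
            input_embeddings, memory_embeddings, memory_labels
        ):
            scores = [0.0] * self.num_classes
            for memory, label in zip(memories, labels):
                # Dot-product similarity between the input and one memory.
                scores[label] += sum(q * m for q, m in zip(query, memory))
            logits.append(scores)
        return logits


head = SimilarityVoteHead(num_classes=2)
logits = head.forward(
    input_embeddings=[[1.0, 0.0]],
    memory_embeddings=[[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]],
    memory_labels=[[0, 1, 0]],
    original_input=["query text"],
    original_memories=[["memory a", "memory b", "memory c"]],
)
print(logits)  # [[1.5, 0.0]]
```

Because the documented class is a `Protocol`, conformance is structural: any object with a matching `forward` method qualifies, with no inheritance needed. A real head would presumably operate on tensors and return a tensor of logits instead of nested lists.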