
orcalib.torch_layers.gather_top_k

GatherTopK

GatherTopK(k)

Bases: Module

Parameters:

  • k (int) –

    number of top elements to select

last_indices instance-attribute

last_indices = None

Indices of the top k elements selected by the most recent call to forward
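A minimal sketch of how stored indices could be reused to select a further property aligned with the same memories. It assumes `last_indices` holds a `batch_size x k` int64 tensor of positions along the `num_total` axis; that dtype and shape are assumptions, not confirmed by this page. The `indices` and `labels` tensors are hypothetical stand-ins.

```python
import torch

# Hypothetical stand-in for selector.last_indices after a forward call.
# Assumed: shape batch_size x k, dtype int64, positions along num_total.
indices = torch.tensor([[1, 2], [0, 1]])

# Another per-memory property, shape batch_size x num_total.
labels = torch.tensor([[10, 20, 30], [40, 50, 60]])

# Pick entries along dim 1 (the num_total axis) at the stored positions.
selected_labels = torch.gather(labels, 1, indices)
print(selected_labels.tolist())  # [[20, 30], [40, 50]]
```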

forward

forward(weights, *other_props)

Select the k highest-weighted memories and return their properties

Parameters:

  • weights (Tensor) –

    weights to sort selection by, float tensor of shape batch_size x num_total

  • other_props (Tensor, default: ()) –

    other properties to select, each of shape batch_size x num_total (x optional_dim)
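When a property carries the trailing `optional_dim` (for example, one embedding per memory), `torch.gather` needs an index with the same number of dimensions as the property. A minimal sketch of that step, assuming the selected positions are a `batch_size x num_top` int64 tensor; the helper name `gather_along_memories` is hypothetical and not part of orcalib.

```python
import torch


def gather_along_memories(prop: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: select entries of `prop` along dim 1 (num_total).

    prop: batch_size x num_total, or batch_size x num_total x optional_dim
    indices: batch_size x num_top, int64
    """
    trailing = prop.shape[2:]  # empty for 2-D props, (optional_dim,) for 3-D
    # Reshape to batch_size x num_top x 1 ..., then broadcast over trailing dims
    # so every feature of a selected memory is copied.
    idx = indices.reshape(*indices.shape, *([1] * len(trailing)))
    idx = idx.expand(*indices.shape, *trailing)
    return torch.gather(prop, 1, idx)


embeddings = torch.arange(12.0).reshape(2, 3, 2)  # batch 2, 3 memories, dim 2
indices = torch.tensor([[1, 2], [0, 1]])
picked = gather_along_memories(embeddings, indices)
print(picked.shape)  # torch.Size([2, 2, 2])
```

The same function handles the plain `batch_size x num_total` case, because the trailing shape is then empty and the index is used unchanged.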

Returns:

  • tuple[Tensor, ...]

    tuple of the selected properties, always with the selected weights as the first element; each has shape batch_size x num_top (x optional_dim), where num_top = k

Examples:

>>> import torch
>>> from orcalib.torch_layers.gather_top_k import GatherTopK
>>> selector = GatherTopK(2)
>>> selector(
...     torch.tensor([[0.1, 0.2, 0.3], [0.3, 0.2, 0.1]]),
...     torch.tensor([[1, 2, 3], [3, 2, 1]]),
...     torch.tensor([[4, 5, 6], [6, 5, 4]]),
... )
(tensor([[0.2, 0.3], [0.3, 0.2]]), tensor([[2, 3], [3, 2]]), tensor([[5, 6], [6, 5]]))
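The behavior described above can be approximated with `torch.topk` and `torch.gather`. This is a minimal sketch, not orcalib's actual implementation. The class name `GatherTopKSketch` is hypothetical. Re-sorting the chosen indices so that they keep their original order along `num_total` is an assumption inferred from the example output: the first row returns `[0.2, 0.3]` while the second returns `[0.3, 0.2]`, which matches positional order rather than a descending sort.

```python
from typing import Optional

import torch
from torch import nn


class GatherTopKSketch(nn.Module):
    """Hypothetical re-creation of GatherTopK's documented behavior."""

    def __init__(self, k: int):
        super().__init__()
        self.k = k
        self.last_indices: Optional[torch.Tensor] = None

    def forward(self, weights: torch.Tensor, *other_props: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # Positions of the k largest weights in each row: batch_size x k.
        _, indices = torch.topk(weights, self.k, dim=1)
        # Assumption: restore the original num_total order of the selection.
        indices, _ = torch.sort(indices, dim=1)
        self.last_indices = indices
        results = []
        for prop in (weights, *other_props):
            trailing = prop.shape[2:]  # (optional_dim,) or empty
            idx = indices.reshape(*indices.shape, *([1] * len(trailing)))
            results.append(torch.gather(prop, 1, idx.expand(*indices.shape, *trailing)))
        return tuple(results)


selector = GatherTopKSketch(2)
w, a, b = selector(
    torch.tensor([[0.1, 0.2, 0.3], [0.3, 0.2, 0.1]]),
    torch.tensor([[1, 2, 3], [3, 2, 1]]),
    torch.tensor([[4, 5, 6], [6, 5, 4]]),
)
print(a.tolist(), b.tolist())  # [[2, 3], [3, 2]] [[5, 6], [6, 5]]
```

Sorting the indices after `torch.topk` trades a small extra cost for a stable, position-ordered output. If only the top-k set matters and not its order, that step could be dropped.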