orcalib.torch_layers.gather_top_k
GatherTopK
Bases: Module
Parameters:
- k (int) – number of top elements to select
forward
Select the top k memories by weight and return their properties.
Parameters:
- weights (Tensor) – weights to sort the selection by, float tensor of shape batch_size x num_total
- other_props (Tensor, default: ()) – other properties to select, each of shape batch_size x num_total (x optional_dim)
Returns:
- tuple[Tensor, ...] – tuple of properties with the top elements selected, always including the weights as the first element, each of shape batch_size x num_top (x optional_dim)
Examples:
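As a rough illustration of the behavior described above, the following sketch reimplements the selection with `torch.topk` and `torch.gather`. It is not the orcalib source: the class name `GatherTopKSketch` is hypothetical, and passing `other_props` as variadic positional tensors is an assumption inferred from the `()` default. The sample tensors are made up for the demonstration.

```python
import torch
from torch import Tensor, nn


class GatherTopKSketch(nn.Module):
    """Hypothetical stand-in for GatherTopK; not the orcalib implementation."""

    def __init__(self, k: int):
        super().__init__()
        self.k = k

    def forward(self, weights: Tensor, *other_props: Tensor) -> tuple[Tensor, ...]:
        # Values and indices of the k largest weights per row: batch_size x num_top
        top_weights, top_idx = torch.topk(weights, self.k, dim=1)
        selected = [top_weights]
        for prop in other_props:
            idx = top_idx
            if prop.dim() == 3:
                # Repeat the indices across the optional trailing dimension
                idx = top_idx.unsqueeze(-1).expand(-1, -1, prop.shape[-1])
            selected.append(torch.gather(prop, 1, idx))
        return tuple(selected)


# Made-up inputs: batch_size = 2, num_total = 3, optional_dim = 2
weights = torch.tensor([[0.1, 0.9, 0.5], [0.7, 0.2, 0.3]])
labels = torch.tensor([[10, 11, 12], [20, 21, 22]])
embeddings = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)

layer = GatherTopKSketch(k=2)
top_weights, top_labels, top_embeddings = layer(weights, labels, embeddings)

print(top_labels.tolist())   # [[11, 12], [20, 22]]
print(top_embeddings.shape)  # torch.Size([2, 2, 2])
```

The weights come back first, followed by each extra property in the order it was passed, with num_total reduced to num_top in every tensor.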