orcalib.orca_classification#

OrcaKnnClassifier #

OrcaKnnClassifier(
    num_classes,
    num_memories,
    label_column_name,
    weigh_memories=True,
    database=None,
    curate_enabled=False,
    memory_index_name=None,
    drop_exact_match=DropExactMatchOption.TRAINING_ONLY,
    exact_match_threshold=None,
)

Bases: OrcaLookupModule

A simple KNN layer that returns the average label of the K nearest memories to the input vector.

Examples:

import torch
from orcalib import OrcaModule, OrcaKnnClassifier

class MyModule(OrcaModule):
    def __init__(self):
        super().__init__()
        self.knn_head = OrcaKnnClassifier(
            memory_index_name="my_index",
            label_column_name="my_label",
            num_memories=10,
            num_classes=5,
        )

    def forward(self, x):
        logits = self.knn_head(x)
        return logits

Parameters:

  • num_classes (int) –

    The size of the output vector.

  • num_memories (int) –

    The number of memory vectors to be returned from the lookup.

  • weigh_memories (bool, default: True ) –

    Whether to weigh the memories by their scores.

  • database (OrcaDatabase | str | None, default: None ) –

    The OrcaDatabase instance to use for lookups and curate tracking.

  • curate_enabled (bool, default: False ) –

    Whether Curate tracking is enabled.

  • memory_index_name (str | None, default: None ) –

    The name of the index to use for lookups.

  • drop_exact_match (DropExactMatchOption, default: TRAINING_ONLY ) –

    Whether to drop the exact match (if found) always, never, only during training, or only during inference.

  • exact_match_threshold (float | None, default: None ) –

    The minimum similarity score for a memory to be considered the exact match.

  • label_column_name (ColumnName) –

    The name of the label column to return from the index.
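The interaction between drop_exact_match and exact_match_threshold can be sketched in plain Python. This is only an illustration under stated assumptions, not orcalib's implementation: DropPolicy is a hypothetical stand-in for DropExactMatchOption (only TRAINING_ONLY is confirmed by the signature above; the other member names are guessed from the description), and the sketch assumes the exact match is the top-scoring memory.

```python
from enum import Enum
from typing import List, Tuple


class DropPolicy(Enum):
    # Hypothetical stand-in for DropExactMatchOption. Only TRAINING_ONLY is
    # named in the documented signature; the other names are assumptions.
    ALWAYS = "always"
    NEVER = "never"
    TRAINING_ONLY = "training_only"
    INFERENCE_ONLY = "inference_only"


def should_drop(policy: DropPolicy, training: bool) -> bool:
    """Decide whether exact matches are dropped in the current mode."""
    if policy is DropPolicy.ALWAYS:
        return True
    if policy is DropPolicy.NEVER:
        return False
    if policy is DropPolicy.TRAINING_ONLY:
        return training
    return not training  # INFERENCE_ONLY


def filter_memories(
    memories: List[Tuple[int, float]],
    policy: DropPolicy,
    training: bool,
    threshold: float,
) -> List[Tuple[int, float]]:
    """Drop the top hit if it counts as an exact match.

    memories holds (label, score) pairs sorted by descending score.
    """
    if not memories or not should_drop(policy, training):
        return list(memories)
    _, top_score = memories[0]
    if top_score >= threshold:
        return list(memories[1:])
    return list(memories)


hits = [(2, 1.0), (2, 0.91), (0, 0.80)]
print(filter_memories(hits, DropPolicy.TRAINING_ONLY, training=True, threshold=0.999))
```

Dropping the exact match during training keeps a sample from simply retrieving its own label when the training data is also stored in the memory index.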

lookup_result_transforms property writable #

lookup_result_transforms

A list of transforms to apply to the lookup result. NOTE: This will be applied even when lookup_result_override is set.

extra_lookup_column_names property writable #

extra_lookup_column_names

While set, all lookups will include these additional columns. They may include columns on the indexed table as well as index-specific columns, e.g., $score, $embedding.

lookup_query_override property writable #

lookup_query_override

The query to use instead of performing a lookup. NOTE: This will be ignored if lookup_result_override is also set.

lookup_result_override property writable #

lookup_result_override

The lookup result to use instead of performing a lookup.
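The precedence among lookup_query_override, lookup_result_override, and lookup_result_transforms, as described in the notes above, can be sketched as follows. The function resolve_lookup and its do_lookup callback are hypothetical names used only for illustration; they are not part of orcalib.

```python
from typing import Any, Callable, Optional, Sequence


def resolve_lookup(
    x: Any,
    do_lookup: Callable[[Any], Any],
    query_override: Optional[Any] = None,
    result_override: Optional[Any] = None,
    transforms: Sequence[Callable[[Any], Any]] = (),
) -> Any:
    """Mimic the documented precedence of the override properties."""
    if result_override is not None:
        # A result override wins; any query override is ignored.
        result = result_override
    else:
        # Otherwise the query override (if set) replaces the input as the query.
        query = query_override if query_override is not None else x
        result = do_lookup(query)
    # Transforms run in both cases, including on an overridden result.
    for transform in transforms:
        result = transform(result)
    return result


def fake_lookup(query: int) -> list:
    # Toy lookup: returns two "memories" derived from the query.
    return [query, query + 1]


print(resolve_lookup(1, fake_lookup))
print(resolve_lookup(1, fake_lookup, query_override=10))
print(resolve_lookup(1, fake_lookup, query_override=10, result_override=[7]))
```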

lookup_database property writable #

lookup_database

The name of the database to use for looking up memories.

memory_index_name property writable #

memory_index_name

The name of the index to use for looking up memories.

lookup_column_names property writable #

lookup_column_names

The names of the columns to retrieve for each memory.

num_memories property writable #

num_memories

The number of memories to look up.

drop_exact_match property writable #

drop_exact_match

Whether to drop exact matches from the results.

exact_match_threshold property writable #

exact_match_threshold

The similarity threshold for exact matches.

shuffle_memories property writable #

shuffle_memories

Whether to shuffle the looked up memories.

curate_database property writable #

curate_database

The name of the database to use for saving curate tracking data.

curate_next_run_settings property writable #

curate_next_run_settings

The settings for the next curate model run.

curate_model_id property writable #

curate_model_id

The model id to associate with curated model runs.

curate_model_version property writable #

curate_model_version

The model version to associate with curated model runs.

curate_metadata property writable #

curate_metadata

The metadata to attach to curated model runs.

curate_tags property writable #

curate_tags

The tags to attach to the curated model runs.

curate_seq_id property writable #

curate_seq_id

The sequence id to associate with curated model runs.

curate_batch_size property writable #

curate_batch_size

The batch size of the model run to track curate data for, usually inferred automatically.

last_curate_run_ids property writable #

last_curate_run_ids

The run ids of the last model run for which curate tracking data was collected.

last_curate_run_settings property writable #

last_curate_run_settings

The settings of the last model run for which curate tracking data was collected.

get_effective_lookup_settings #

get_effective_lookup_settings()

Returns the effective lookup settings for this module, with any inherited settings applied.

Returns:

  • LookupSettings

    The effective lookup settings for this module. In practice, these are the lookup settings set on this module; for any setting that is not set on this module, the inherited setting is used instead.
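The fallback rule described above (own setting if present, inherited setting otherwise) can be illustrated with a small dataclass. The Settings class and its three fields are a hypothetical subset chosen for this sketch; the real LookupSettings type may differ.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass(frozen=True)
class Settings:
    # Hypothetical subset of lookup settings; None means "not set on this module".
    num_memories: Optional[int] = None
    memory_index_name: Optional[str] = None
    shuffle_memories: Optional[bool] = None


def effective(own: Settings, inherited: Settings) -> Settings:
    """Prefer the module's own value; fall back to the inherited one."""
    merged = {}
    for f in fields(Settings):
        own_value = getattr(own, f.name)
        # Compare against None so that falsy values such as False or 0 still win.
        merged[f.name] = own_value if own_value is not None else getattr(inherited, f.name)
    return Settings(**merged)


parent = Settings(num_memories=10, memory_index_name="my_index", shuffle_memories=True)
child = Settings(num_memories=5, shuffle_memories=False)
print(effective(child, parent))
```

Comparing against None rather than testing truthiness matters here: a child that explicitly sets shuffle_memories=False should not silently inherit True.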

get_lookup_database_instance #

get_lookup_database_instance()

Returns the OrcaDatabase instance to use for looking up memories.

get_orca_modules_recursively #

get_orca_modules_recursively(
    max_depth=None, include_self=True, filter_type=None
)

Recursively yields all children of this module that are instances of the specified filter type.

  • All parent nodes are processed before their children.
  • The search visits all children, even those that are not a subclass of filter_type, but yields only children that are a subclass of filter_type.

Parameters:

  • max_depth (int | None, default: None ) –

    The maximum depth to search.

    • Setting this to 0 will only include this module.
    • Setting this to 1 will include only this module and its children.
    • Setting it to None (the default) will search through all modules.
    • Modules that are not of filter_type or OrcaModule do not increment the depth.
  • include_self (bool, default: True ) –

    Whether to include the current OrcaModule in the results.

  • filter_type (type[ORCA_MODULE_TYPE] | None, default: None ) –

    The subtype of OrcaModule to filter for. If None, any subtypes of OrcaModule will be returned.

Yields:

  • ORCA_MODULE_TYPE

    modules of type filter_type that are used in the children of this module.
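The traversal rules above (parents before children, max_depth limits, filtering by type) can be sketched with toy classes. Node, OrcaNode, and LookupNode are hypothetical stand-ins for a generic module, OrcaModule, and a lookup module. The sketch assumes that only Orca-type nodes increment the depth, which is one reading of the last max_depth rule.

```python
from typing import Iterator, List, Optional, Type


class Node:
    """Stand-in for a generic (non-Orca) module."""

    def __init__(self, name: str, children: Optional[List["Node"]] = None):
        self.name = name
        self.children = children or []


class OrcaNode(Node):
    """Stand-in for an OrcaModule subclass."""


class LookupNode(OrcaNode):
    """Stand-in for a lookup-capable OrcaModule subclass."""


def walk(
    node: Node,
    max_depth: Optional[int] = None,
    include_self: bool = True,
    filter_type: Type[OrcaNode] = OrcaNode,
    _depth: int = 0,
) -> Iterator[Node]:
    """Pre-order traversal: a node is yielded before any of its children."""
    if include_self and isinstance(node, filter_type):
        yield node
    for child in node.children:
        # Assumption: only Orca-type nodes count towards the depth.
        child_depth = _depth + 1 if isinstance(child, OrcaNode) else _depth
        if max_depth is not None and child_depth > max_depth:
            continue
        yield from walk(child, max_depth, True, filter_type, child_depth)


root = OrcaNode("root", [
    Node("seq", [LookupNode("knn")]),          # plain container around a lookup node
    OrcaNode("block", [LookupNode("head")]),   # nested Orca nodes
])
print([n.name for n in walk(root)])
print([n.name for n in walk(root, filter_type=LookupNode)])
```

Note that the plain container "seq" is searched but never yielded, and it does not consume any depth budget in this sketch.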

enable_curate #

enable_curate(recursive=True)

Enable Curate tracking for the model and (if recursive is True) for all its descendants.

Parameters:

  • recursive (bool, default: True ) –

    Whether to enable Curate tracking recursively.

disable_curate #

disable_curate(recursive=True)

Disable Curate tracking for this module and (if recursive is True) for all its descendants.

Parameters:

  • recursive (bool, default: True ) –

    Whether to disable Curate tracking recursively.

update_curate_settings #

update_curate_settings(
    model_id=None,
    model_version=None,
    tags=None,
    extra_tags=None,
    metadata=None,
    extra_metadata=None,
    batch_size=None,
    seq_id=None,
    enabled=None,
    enable_recursive=True,
)

Update curate tracking settings for the module and all its children.

Parameters:

  • model_id (str | None, default: None ) –

    The ID of the model.

  • model_version (str | None, default: None ) –

    The version of the model.

  • tags (Iterable[str] | None, default: None ) –

    The new tags to be added to the model.

  • extra_tags (Iterable[str] | None, default: None ) –

    The extra tags to be added to the model.

  • metadata (OrcaMetadataDict | None, default: None ) –

    The new metadata to be added to the model.

  • extra_metadata (OrcaMetadataDict | None, default: None ) –

    The extra metadata to be added to the model.

  • batch_size (int | None, default: None ) –

    The batch size to be used for the model.

  • seq_id (UUID | None, default: None ) –

    The sequence ID to be used for the model.

record_next_model_memory_lookups #

record_next_model_memory_lookups(
    tags=None, metadata=None, batch_size=None, seq_id=None
)

Sets up curate tracking for the memory lookups during the next forward pass only.

Parameters:

  • tags (Iterable[str] | None, default: None ) –

    Additional tags to be recorded on the next model run.

  • metadata (OrcaMetadataDict | None, default: None ) –

    Additional metadata to be recorded on the next model run.

  • batch_size (int | None, default: None ) –

    The batch size to be used for the next model run.

  • seq_id (UUID | None, default: None ) –

    The sequence ID to be used for the next model run.

record_model_feedback #

record_model_feedback(
    val, name="default", kind=FeedbackKind.CONTINUOUS
)

Records feedback for the last model runs for which memory lookups were recorded by curate.

Parameters:

  • val (list[float] | float | int | list[int]) –

    The feedback to be recorded.

  • name (str, default: 'default' ) –

    The name of the feedback.

  • kind (FeedbackKind, default: CONTINUOUS ) –

    The kind of feedback.

record_model_input_output #

record_model_input_output(inputs, outputs)

Records the inputs and outputs of the last model runs for which memory lookups were recorded by curate.

Parameters:

  • inputs (list[Any] | Any) –

    The inputs to be recorded.

  • outputs (list[Any] | Any) –

    The outputs to be recorded.
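Taken together, the recording methods above suggest a call order: arm tracking, run the forward pass, then attach inputs, outputs, and feedback to the recorded runs. The helper below is a sketch of that order under the assumption that this sequence is valid; tracked_forward and RecordingStub are hypothetical names, and the stub only logs calls so the example runs without a database.

```python
from typing import Any, List, Optional


def tracked_forward(module: Any, x: Any, tags: Optional[List[str]] = None,
                    metadata: Optional[dict] = None, feedback: Optional[float] = None) -> Any:
    """Run one forward pass with curate tracking, using the documented method names."""
    # 1. Arm tracking for the next forward pass only.
    module.record_next_model_memory_lookups(tags=tags, metadata=metadata)
    # 2. Run the forward pass; memory lookups are recorded during this call.
    outputs = module(x)
    # 3. Attach inputs and outputs to the runs that were just recorded.
    module.record_model_input_output(x, outputs)
    # 4. Optionally attach feedback to the same runs.
    if feedback is not None:
        module.record_model_feedback(feedback, name="default")
    return outputs


class RecordingStub:
    """Hypothetical stand-in for an Orca module; it only logs the call order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def record_next_model_memory_lookups(self, tags=None, metadata=None,
                                         batch_size=None, seq_id=None) -> None:
        self.calls.append("record_next_model_memory_lookups")

    def __call__(self, x: List[float]) -> List[float]:
        self.calls.append("forward")
        return [v * 2 for v in x]

    def record_model_input_output(self, inputs, outputs) -> None:
        self.calls.append("record_model_input_output")

    def record_model_feedback(self, val, name="default") -> None:
        self.calls.append("record_model_feedback")


stub = RecordingStub()
result = tracked_forward(stub, [1.0, 2.0], tags=["eval"], feedback=1.0)
print(stub.calls)
```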

get_lookup_setting_summary #

get_lookup_setting_summary()

Returns a summary of the lookup settings for each OrcaLookupLayer in this module and its descendants.

forward #

forward(x=None, ctx_labels=None, ctx_scores=None)

Generate logits based on the nearest neighbors of the input vector.

Parameters:

  • x (Tensor | None, default: None ) –

    The input tensor of shape (batch_size, embedding_dim). It can be omitted if labels and scores are provided directly.

  • ctx_labels (Tensor | None, default: None ) –

    The memory label tensor of shape (batch_size, num_memories), containing integer labels. If None, the labels are looked up from the index based on the input tensor.

  • ctx_scores (Tensor | None, default: None ) –

    The memory score tensor of shape (batch_size, num_memories), containing float scores. If None, the scores are looked up from the index based on the input tensor.

Returns:

  • Tensor

    The output tensor of shape (batch_size, num_classes). If neither x nor scores are provided, the dtype is float32; otherwise it matches the dtype of the scores or input tensor.
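To make the idea of averaging the labels of the K nearest memories concrete, here is a plain-Python sketch for a single input. It is one plausible reading of the description, not the library's code: knn_vote is a hypothetical helper, and normalizing the votes by the total weight is an assumption.

```python
from typing import List


def knn_vote(
    ctx_labels: List[int],
    ctx_scores: List[float],
    num_classes: int,
    weigh_memories: bool = True,
) -> List[float]:
    """Return one value per class from the labels of the retrieved memories.

    ctx_labels and ctx_scores both have length num_memories.
    """
    # With weighting, each memory votes with its similarity score;
    # without it, every memory counts equally.
    weights = ctx_scores if weigh_memories else [1.0] * len(ctx_labels)
    totals = [0.0] * num_classes
    for label, weight in zip(ctx_labels, weights):
        totals[label] += weight
    norm = sum(weights)
    # Assumption: normalize so the class values sum to 1.
    return [t / norm for t in totals] if norm else totals


labels = [0, 0, 1, 2]
scores = [0.5, 0.25, 0.125, 0.125]
print(knn_vote(labels, scores, num_classes=3))
print(knn_vote(labels, scores, num_classes=3, weigh_memories=False))
```

With weighting enabled, the two high-scoring memories of class 0 dominate the vote; without it, class 0 still wins but by a smaller margin.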

OrcaClassificationHead #

OrcaClassificationHead(
    model_dim,
    num_classes,
    num_memories,
    num_layers=1,
    num_heads=8,
    classification_mode=ClassificationMode.DIRECT,
    activation=F.relu,
    dropout=0.1,
    deep_residuals=True,
    split_retrieval_path=False,
    memory_guide_weight=0.0,
    single_lookup=True,
    database=None,
    curate_enabled=False,
    memory_index_name=None,
    drop_exact_match=None,
    exact_match_threshold=None,
    shuffle_memories=False,
    label_column_name=None,
)

Bases: OrcaLookupModule, LabelColumnNameMixin

A transformer decoder layer block that performs cross-attention with memory lookup.

Examples:

import torch
from orcalib.orca_torch import OrcaModule, OrcaClassificationHead

class MyModule(OrcaModule):
    def __init__(self):
        super().__init__()
        self.trunk = torch.nn.Linear(10, 10)
        self.classifier = OrcaClassificationHead(model_dim=10, num_classes=5, num_memories=10, memory_index_name="my_index", label_column_name="my_label")

    def forward(self, x):
        x = self.trunk(x) # N x 10
        x = self.classifier(x)
        return x # N x 5, e.g., where each row may become logits for a softmax

Parameters:

  • model_dim (int) –

    The dimension of the input vector and hidden layers.

  • num_classes (int) –

    The size of the output vector.

  • num_memories (int) –

    The number of memory vectors to be returned from the lookup.

  • num_layers (int, default: 1 ) –

    The number of attention blocks to use; each block is a copy of OrcaClassificationCrossAttentionLayer.

  • num_heads (int, default: 8 ) –

    The number of heads to be used in the multi-head attention layer.

  • classification_mode (ClassificationMode, default: DIRECT ) –

    The mode of classification to be used.

  • activation (Callable[[Tensor], Tensor], default: relu ) –

    The activation function.

  • dropout (float, default: 0.1 ) –

    The dropout rate.

  • deep_residuals (bool, default: True ) –

    Whether to use deep residuals.

  • split_retrieval_path (bool, default: False ) –

    Whether to split the retrieval path.

  • memory_guide_weight (float, default: 0.0 ) –

    The weight of the memory guide.

  • single_lookup (bool, default: True ) –

    Whether to use a single lookup.

  • database (OrcaDatabase | str | None, default: None ) –

    The OrcaDatabase instance to use for lookups and curate tracking.

  • curate_enabled (bool, default: False ) –

    Whether Curate tracking is enabled.

  • memory_index_name (str | None, default: None ) –

    The name of the index to use for lookups.

  • drop_exact_match (DropExactMatchOption | None, default: None ) –

    Whether to drop the exact match (if found) always, never, only during training, or only during inference.

  • exact_match_threshold (float | None, default: None ) –

    The minimum similarity score for a memory to be considered the exact match.

  • shuffle_memories (bool, default: False ) –

    Whether to shuffle the memories before returning them.

  • label_column_name (ColumnName | None, default: None ) –

    The name of the label column to return from the index.

lookup_result_transforms property writable #

lookup_result_transforms

A list of transforms to apply to the lookup result. NOTE: This will be applied even when lookup_result_override is set.

extra_lookup_column_names property writable #

extra_lookup_column_names

While set, all lookups will include these additional columns. They may include columns on the indexed table as well as index-specific columns, e.g., $score, $embedding.

lookup_query_override property writable #

lookup_query_override

The query to use instead of performing a lookup. NOTE: This will be ignored if lookup_result_override is also set.

lookup_result_override property writable #

lookup_result_override

The lookup result to use instead of performing a lookup.

lookup_database property writable #

lookup_database

The name of the database to use for looking up memories.

memory_index_name property writable #

memory_index_name

The name of the index to use for looking up memories.

lookup_column_names property writable #

lookup_column_names

The names of the columns to retrieve for each memory.

num_memories property writable #

num_memories

The number of memories to look up.

drop_exact_match property writable #

drop_exact_match

Whether to drop exact matches from the results.

exact_match_threshold property writable #

exact_match_threshold

The similarity threshold for exact matches.

shuffle_memories property writable #

shuffle_memories

Whether to shuffle the looked up memories.

curate_database property writable #

curate_database

The name of the database to use for saving curate tracking data.

curate_next_run_settings property writable #

curate_next_run_settings

The settings for the next curate model run.

curate_model_id property writable #

curate_model_id

The model id to associate with curated model runs.

curate_model_version property writable #

curate_model_version

The model version to associate with curated model runs.

curate_metadata property writable #

curate_metadata

The metadata to attach to curated model runs.

curate_tags property writable #

curate_tags

The tags to attach to the curated model runs.

curate_seq_id property writable #

curate_seq_id

The sequence id to associate with curated model runs.

curate_batch_size property writable #

curate_batch_size

The batch size of the model run to track curate data for, usually inferred automatically.

last_curate_run_ids property writable #

last_curate_run_ids

The run ids of the last model run for which curate tracking data was collected.

last_curate_run_settings property writable #

last_curate_run_settings

The settings of the last model run for which curate tracking data was collected.

get_effective_lookup_settings #

get_effective_lookup_settings()

Returns the effective lookup settings for this module, with any inherited settings applied.

Returns:

  • LookupSettings

    The effective lookup settings for this module. In practice, these are the lookup settings set on this module; for any setting that is not set on this module, the inherited setting is used instead.

get_lookup_database_instance #

get_lookup_database_instance()

Returns the OrcaDatabase instance to use for looking up memories.

get_orca_modules_recursively #

get_orca_modules_recursively(
    max_depth=None, include_self=True, filter_type=None
)

Recursively yields all children of this module that are instances of the specified filter type.

  • All parent nodes are processed before their children.
  • The search visits all children, even those that are not a subclass of filter_type, but yields only children that are a subclass of filter_type.

Parameters:

  • max_depth (int | None, default: None ) –

    The maximum depth to search.

    • Setting this to 0 will only include this module.
    • Setting this to 1 will include only this module and its children.
    • Setting it to None (the default) will search through all modules.
    • Modules that are not of filter_type or OrcaModule do not increment the depth.
  • include_self (bool, default: True ) –

    Whether to include the current OrcaModule in the results.

  • filter_type (type[ORCA_MODULE_TYPE] | None, default: None ) –

    The subtype of OrcaModule to filter for. If None, any subtypes of OrcaModule will be returned.

Yields:

  • ORCA_MODULE_TYPE

    modules of type filter_type that are used in the children of this module.

enable_curate #

enable_curate(recursive=True)

Enable Curate tracking for the model and (if recursive is True) for all its descendants.

Parameters:

  • recursive (bool, default: True ) –

    Whether to enable Curate tracking recursively.

disable_curate #

disable_curate(recursive=True)

Disable Curate tracking for this module and (if recursive is True) for all its descendants.

Parameters:

  • recursive (bool, default: True ) –

    Whether to disable Curate tracking recursively.

update_curate_settings #

update_curate_settings(
    model_id=None,
    model_version=None,
    tags=None,
    extra_tags=None,
    metadata=None,
    extra_metadata=None,
    batch_size=None,
    seq_id=None,
    enabled=None,
    enable_recursive=True,
)

Update curate tracking settings for the module and all its children.

Parameters:

  • model_id (str | None, default: None ) –

    The ID of the model.

  • model_version (str | None, default: None ) –

    The version of the model.

  • tags (Iterable[str] | None, default: None ) –

    The new tags to be added to the model.

  • extra_tags (Iterable[str] | None, default: None ) –

    The extra tags to be added to the model.

  • metadata (OrcaMetadataDict | None, default: None ) –

    The new metadata to be added to the model.

  • extra_metadata (OrcaMetadataDict | None, default: None ) –

    The extra metadata to be added to the model.

  • batch_size (int | None, default: None ) –

    The batch size to be used for the model.

  • seq_id (UUID | None, default: None ) –

    The sequence ID to be used for the model.

record_next_model_memory_lookups #

record_next_model_memory_lookups(
    tags=None, metadata=None, batch_size=None, seq_id=None
)

Sets up curate tracking for the memory lookups during the next forward pass only.

Parameters:

  • tags (Iterable[str] | None, default: None ) –

    Additional tags to be recorded on the next model run.

  • metadata (OrcaMetadataDict | None, default: None ) –

    Additional metadata to be recorded on the next model run.

  • batch_size (int | None, default: None ) –

    The batch size to be used for the next model run.

  • seq_id (UUID | None, default: None ) –

    The sequence ID to be used for the next model run.

record_model_feedback #

record_model_feedback(
    val, name="default", kind=FeedbackKind.CONTINUOUS
)

Records feedback for the last model runs for which memory lookups were recorded by curate.

Parameters:

  • val (list[float] | float | int | list[int]) –

    The feedback to be recorded.

  • name (str, default: 'default' ) –

    The name of the feedback.

  • kind (FeedbackKind, default: CONTINUOUS ) –

    The kind of feedback.

record_model_input_output #

record_model_input_output(inputs, outputs)

Records the inputs and outputs of the last model runs for which memory lookups were recorded by curate.

Parameters:

  • inputs (list[Any] | Any) –

    The inputs to be recorded.

  • outputs (list[Any] | Any) –

    The outputs to be recorded.

get_lookup_setting_summary #

get_lookup_setting_summary()

Returns a summary of the lookup settings for each OrcaLookupLayer in this module and its descendants.

forward #

forward(x, ctx=None, labels=None, memory_key=None)

Generate logits based on the input vector and memory context.

Parameters:

  • x (Tensor) –

    The input tensor of shape (batch_size, embedding_dim).

  • ctx (Tensor | None, default: None ) –

    The memory context tensor of shape (batch_size, num_memories, embedding_dim). If None, the memory context will be looked up based on the memory_key or input tensor.

  • labels (Tensor | None, default: None ) –

    The memory label tensor of shape (batch_size, num_memories) containing integer labels. If None, the labels will be looked up from the index based on the memory_key or input tensor.

  • memory_key (Tensor | None, default: None ) –

    The memory key tensor of shape (batch_size, embedding_dim) to use for lookup. If None, the input tensor will be used.

Returns:

  • Tensor

    The logits tensor of shape (batch_size, num_classes).
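The core idea of attending from an input vector to its retrieved memory context can be sketched with single-head scaled dot-product attention in plain Python. This is a generic illustration for one input, not the architecture of OrcaClassificationHead; cross_attend is a hypothetical helper, and the real layer presumably adds learned projections, multiple heads, and residual connections.

```python
import math
from typing import List, Tuple


def cross_attend(x: List[float], ctx: List[List[float]]) -> Tuple[List[float], List[float]]:
    """Attend from one input vector to its memory context.

    x has length embedding_dim; ctx has shape (num_memories, embedding_dim).
    Returns the attention weights and the weighted sum of the memories.
    """
    dim = len(x)
    # Scaled dot-product similarity between the input and each memory.
    scores = [sum(q * k for q, k in zip(x, mem)) / math.sqrt(dim) for mem in ctx]
    # Numerically stable softmax over the memories.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the memory vectors.
    attended = [sum(w * mem[j] for w, mem in zip(weights, ctx)) for j in range(dim)]
    return weights, attended


weights, attended = cross_attend([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
print(weights, attended)
```

The memory that points in the same direction as the input receives the larger weight, so the attended vector leans towards it.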

OrcaMoeClassificationHead #

OrcaMoeClassificationHead(
    model_dim,
    num_classes,
    num_memories,
    label_column_name,
    gate_layers=1,
    hidden_layers=0,
    activation=F.relu,
    dropout=0.1,
    database=None,
    curate_enabled=False,
    memory_index_name=None,
    drop_exact_match=DropExactMatchOption.TRAINING_ONLY,
    exact_match_threshold=None,
)

Bases: OrcaLookupModule

A mixture of experts classification head that combines a KNN classifier with a linear classifier.

Examples:

import torch
from orcalib import OrcaModel, OrcaMoeClassificationHead

class MyModule(OrcaModel):
    def __init__(self):
        super().__init__()
        self.trunk = torch.nn.Linear(10, 10)
        self.classifier = OrcaMoeClassificationHead(model_dim=10, num_classes=5, num_memories=10, label_column_name="my_label")

    def forward(self, x):
        x = self.trunk(x)
        x = self.classifier(x)
        return x

Parameters:

  • model_dim (int) –

    The dimension of the input vector and hidden layers.

  • num_classes (int) –

    The size of the output vector.

  • num_memories (int) –

    The number of memory vectors to be returned from the lookup.

  • label_column_name (ColumnName) –

    The name of the label column to return from the index.

  • gate_layers (int, default: 1 ) –

    The number of layers to use in the gating network.

  • hidden_layers (int, default: 0 ) –

    The number of hidden layers to use in the linear classifier.

  • activation (Callable[[Tensor], Tensor], default: relu ) –

    The activation function.

  • dropout (float, default: 0.1 ) –

    The dropout rate.

  • database (OrcaDatabase | str | None, default: None ) –

    The OrcaDatabase instance to use for lookups and curate tracking.

  • curate_enabled (bool, default: False ) –

    Whether Curate tracking is enabled.

  • memory_index_name (str | None, default: None ) –

    The name of the index to use for lookups.

  • drop_exact_match (DropExactMatchOption, default: TRAINING_ONLY ) –

    Whether to drop the exact match (if found) always, never, only during training, or only during inference.

  • exact_match_threshold (float | None, default: None ) –

    The minimum similarity score for a memory to be considered the exact match.

lookup_result_transforms property writable #

lookup_result_transforms

A list of transforms to apply to the lookup result. NOTE: This will be applied even when lookup_result_override is set.

extra_lookup_column_names property writable #

extra_lookup_column_names

While set, all lookups will include these additional columns. They may include columns on the indexed table as well as index-specific columns, e.g., $score, $embedding.

lookup_query_override property writable #

lookup_query_override

The query to use instead of performing a lookup. NOTE: This will be ignored if lookup_result_override is also set.

lookup_result_override property writable #

lookup_result_override

The lookup result to use instead of performing a lookup.

lookup_database property writable #

lookup_database

The name of the database to use for looking up memories.

memory_index_name property writable #

memory_index_name

The name of the index to use for looking up memories.

lookup_column_names property writable #

lookup_column_names

The names of the columns to retrieve for each memory.

num_memories property writable #

num_memories

The number of memories to look up.

drop_exact_match property writable #

drop_exact_match

Whether to drop exact matches from the results.

exact_match_threshold property writable #

exact_match_threshold

The similarity threshold for exact matches.

shuffle_memories property writable #

shuffle_memories

Whether to shuffle the looked up memories.

curate_database property writable #

curate_database

The name of the database to use for saving curate tracking data.

curate_next_run_settings property writable #

curate_next_run_settings

The settings for the next curate model run.

curate_model_id property writable #

curate_model_id

The model id to associate with curated model runs.

curate_model_version property writable #

curate_model_version

The model version to associate with curated model runs.

curate_metadata property writable #

curate_metadata

The metadata to attach to curated model runs.

curate_tags property writable #

curate_tags

The tags to attach to the curated model runs.

curate_seq_id property writable #

curate_seq_id

The sequence id to associate with curated model runs.

curate_batch_size property writable #

curate_batch_size

The batch size of the model run to track curate data for, usually inferred automatically.

last_curate_run_ids property writable #

last_curate_run_ids

The run ids of the last model run for which curate tracking data was collected.

last_curate_run_settings property writable #

last_curate_run_settings

The settings of the last model run for which curate tracking data was collected.

get_effective_lookup_settings #

get_effective_lookup_settings()

Returns the effective lookup settings for this module, with any inherited settings applied.

Returns:

  • LookupSettings

    The effective lookup settings for this module. In practice, these are the lookup settings set on this module; for any setting that is not set on this module, the inherited setting is used instead.

get_lookup_database_instance #

get_lookup_database_instance()

Returns the OrcaDatabase instance to use for looking up memories.

get_orca_modules_recursively #

get_orca_modules_recursively(
    max_depth=None, include_self=True, filter_type=None
)

Recursively yields all children of this module that are instances of the specified filter type.

  • All parent nodes are processed before their children.
  • The search visits all children, even those that are not a subclass of filter_type, but yields only children that are a subclass of filter_type.

Parameters:

  • max_depth (int | None, default: None ) –

    The maximum depth to search.

    • Setting this to 0 will only include this module.
    • Setting this to 1 will include only this module and its children.
    • Setting it to None (the default) will search through all modules.
    • Modules that are not of filter_type or OrcaModule do not increment the depth.
  • include_self (bool, default: True ) –

    Whether to include the current OrcaModule in the results.

  • filter_type (type[ORCA_MODULE_TYPE] | None, default: None ) –

    The subtype of OrcaModule to filter for. If None, any subtypes of OrcaModule will be returned.

Yields:

  • ORCA_MODULE_TYPE

    modules of type filter_type that are used in the children of this module.

enable_curate #

enable_curate(recursive=True)

Enable Curate tracking for the model and (if recursive is True) for all its descendants.

Parameters:

  • recursive (bool, default: True ) –

    Whether to enable Curate tracking recursively.

disable_curate #

disable_curate(recursive=True)

Disable Curate tracking for this module and (if recursive is True) for all its descendants.

Parameters:

  • recursive (bool, default: True ) –

    Whether to disable Curate tracking recursively.

update_curate_settings #

update_curate_settings(
    model_id=None,
    model_version=None,
    tags=None,
    extra_tags=None,
    metadata=None,
    extra_metadata=None,
    batch_size=None,
    seq_id=None,
    enabled=None,
    enable_recursive=True,
)

Update curate tracking settings for the module and all its children.

Parameters:

  • model_id (str | None, default: None ) –

    The ID of the model.

  • model_version (str | None, default: None ) –

    The version of the model.

  • tags (Iterable[str] | None, default: None ) –

    The new tags to be added to the model.

  • extra_tags (Iterable[str] | None, default: None ) –

    The extra tags to be added to the model.

  • metadata (OrcaMetadataDict | None, default: None ) –

    The new metadata to be added to the model.

  • extra_metadata (OrcaMetadataDict | None, default: None ) –

    The extra metadata to be added to the model.

  • batch_size (int | None, default: None ) –

    The batch size to be used for the model.

  • seq_id (UUID | None, default: None ) –

    The sequence ID to be used for the model.

record_next_model_memory_lookups #

record_next_model_memory_lookups(
    tags=None, metadata=None, batch_size=None, seq_id=None
)

Sets up curate tracking for the memory lookups during the next forward pass only.

Parameters:

  • tags (Iterable[str] | None, default: None ) –

    Additional tags to be recorded on the next model run.

  • metadata (OrcaMetadataDict | None, default: None ) –

    Additional metadata to be recorded on the next model run.

  • batch_size (int | None, default: None ) –

    The batch size to be used for the next model run.

  • seq_id (UUID | None, default: None ) –

    The sequence ID to be used for the next model run.

record_model_feedback #

record_model_feedback(
    val, name="default", kind=FeedbackKind.CONTINUOUS
)

Records feedback for the last model runs for which memory lookups were recorded by curate.

Parameters:

  • val (list[float] | float | int | list[int]) –

    The feedback to be recorded.

  • name (str, default: 'default' ) –

    The name of the feedback.

  • kind (FeedbackKind, default: CONTINUOUS ) –

    The kind of feedback.

record_model_input_output #

record_model_input_output(inputs, outputs)

Records the inputs and outputs of the last model runs for which memory lookups were recorded by curate.

Parameters:

  • inputs (list[Any] | Any) –

    The inputs to be recorded.

  • outputs (list[Any] | Any) –

    The outputs to be recorded.

get_lookup_setting_summary #

get_lookup_setting_summary()

Returns a summary of the lookup settings for each OrcaLookupLayer in this module and its descendants.

forward #

forward(x, ctx_scores=None, ctx_labels=None)

Generate logits based on the input vector and memory context.

Parameters:

  • x (Tensor) –

    The input tensor of shape (batch_size, embedding_dim).

  • ctx_scores (Tensor | None, default: None ) –

    The memory scores tensor of shape (batch_size, num_memories). If this is None, the scores will be looked up from the index based on the input tensor.

  • ctx_labels (Tensor | None, default: None ) –

    The memory labels tensor of shape (batch_size, num_memories). If this is None, the labels will be looked up from the index based on the input tensor.

Returns:

  • Tensor

    The logits tensor of shape (batch_size, num_classes).
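How a gating network might blend the KNN expert with the linear expert can be sketched for a single input. The documentation does not state the exact combination rule, so the convex blend with a sigmoid gate below is an assumption based on a common mixture-of-experts formulation; combine_experts and its arguments are hypothetical names.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Logistic function mapping a real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def combine_experts(
    knn_logits: List[float],
    linear_logits: List[float],
    gate_logit: float,
) -> List[float]:
    """Blend two experts' per-class outputs with a scalar gate.

    Assumption: gate g = sigmoid(gate_logit) weights the KNN expert and
    (1 - g) weights the linear expert.
    """
    g = sigmoid(gate_logit)
    return [g * k + (1.0 - g) * l for k, l in zip(knn_logits, linear_logits)]


knn_out = [1.0, 0.0]
linear_out = [0.0, 1.0]
print(combine_experts(knn_out, linear_out, gate_logit=0.0))
```

A gate logit of zero weights both experts equally; large positive values favor the KNN expert, and large negative values favor the linear expert.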