Version: 3.x

rasa.engine.training.hooks

TrainingHook Objects

class TrainingHook(GraphNodeHook)

Caches fingerprints and outputs of nodes during model training.

__init__

def __init__(cache: TrainingCache, model_storage: ModelStorage,
             pruned_schema: GraphSchema) -> None

Initializes a TrainingHook.

Arguments:

  • cache - Cache used to store fingerprints and outputs.
  • model_storage - Model storage used to cache Resource objects.
  • pruned_schema - The pruned training schema.

on_before_node

def on_before_node(node_name: Text, execution_context: ExecutionContext,
                   config: Dict[Text, Any],
                   received_inputs: Dict[Text, Any]) -> Dict

Calculates the run fingerprint of the node and returns it so that on_after_node can use it (it arrives there as input_hook_data).

on_after_node

def on_after_node(node_name: Text, execution_context: ExecutionContext,
                  config: Dict[Text, Any], output: Any,
                  input_hook_data: Dict) -> None

Stores the fingerprints and caches the output of the node.
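
The before/after pairing described above can be illustrated with a small, self-contained sketch. It does not import Rasa and is not Rasa's implementation: SketchCachingHook, run_node_with_hook, the "fingerprint_key" entry, and the use of a plain dict as the cache are all hypothetical stand-ins, and the hashing scheme is an assumption chosen only to make the idea concrete.

```python
import hashlib
import json
from typing import Any, Callable, Dict, Text


class SketchCachingHook:
    """Hypothetical stand-in that mimics the before/after hook pattern."""

    def __init__(self) -> None:
        # A plain dict stands in for a real cache: fingerprint -> output.
        self.cache: Dict[Text, Any] = {}

    def on_before_node(
        self,
        node_name: Text,
        execution_context: Any,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        # Hash everything that can influence the node's result.
        payload = json.dumps(
            {"node": node_name, "config": config, "inputs": received_inputs},
            sort_keys=True,
            default=str,
        )
        fingerprint = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        # The returned dict is handed to on_after_node by the caller.
        return {"fingerprint_key": fingerprint}

    def on_after_node(
        self,
        node_name: Text,
        execution_context: Any,
        config: Dict[Text, Any],
        output: Any,
        input_hook_data: Dict,
    ) -> None:
        # Store the node output under the fingerprint computed earlier.
        self.cache[input_hook_data["fingerprint_key"]] = output


def run_node_with_hook(
    hook: SketchCachingHook,
    node_name: Text,
    config: Dict[Text, Any],
    inputs: Dict[Text, Any],
    fn: Callable[..., Any],
) -> Any:
    """Hypothetical runner: calls the hook before and after the node function."""
    hook_data = hook.on_before_node(node_name, None, config, inputs)
    output = fn(**inputs)
    hook.on_after_node(node_name, None, config, output, hook_data)
    return output


hook = SketchCachingHook()
result = run_node_with_hook(
    hook,
    "train_tokenizer",
    {"lowercase": True},
    {"text": "Hello"},
    lambda text: text.lower(),
)
```

In this sketch, running the same node again with the same config and inputs produces the same fingerprint key, so a runner could look the output up in the cache instead of recomputing it; changing the config yields a different key.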

LoggingHook Objects

class LoggingHook(GraphNodeHook)

Logs the training of components.

__init__

def __init__(pruned_schema: GraphSchema) -> None

Initializes a LoggingHook.

Arguments:

  • pruned_schema - The pruned training schema, which indicates whether each component is cached.

on_before_node

def on_before_node(node_name: Text, execution_context: ExecutionContext,
                   config: Dict[Text, Any],
                   received_inputs: Dict[Text, Any]) -> Dict

Logs the start of training for a graph node.

on_after_node

def on_after_node(node_name: Text, execution_context: ExecutionContext,
                  config: Dict[Text, Any], output: Any,
                  input_hook_data: Dict) -> None

Logs when a component has finished training.
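
As a rough illustration of a logging hook of this shape, the sketch below uses only the standard library. SketchLoggingHook, the cached_nodes mapping (a stand-in for the information the pruned schema provides), the logger name, and the log messages are all hypothetical; the actual messages and cache checks used by LoggingHook may differ.

```python
import logging
from typing import Any, Dict, Text

logger = logging.getLogger("hook_sketch")


class SketchLoggingHook:
    """Hypothetical stand-in that logs node training start and end."""

    def __init__(self, cached_nodes: Dict[Text, bool]) -> None:
        # Assumption: maps node name -> True if its result comes from the cache.
        self._cached_nodes = cached_nodes

    def _is_cached(self, node_name: Text) -> bool:
        return self._cached_nodes.get(node_name, False)

    def on_before_node(
        self,
        node_name: Text,
        execution_context: Any,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        # Only announce training for nodes that actually have to run.
        if not self._is_cached(node_name):
            logger.info("Starting to train node '%s'.", node_name)
        # Nothing needs to be passed on to on_after_node.
        return {}

    def on_after_node(
        self,
        node_name: Text,
        execution_context: Any,
        config: Dict[Text, Any],
        output: Any,
        input_hook_data: Dict,
    ) -> None:
        if self._is_cached(node_name):
            logger.info("Restored node '%s' from cache.", node_name)
        else:
            logger.info("Finished training node '%s'.", node_name)
```

Returning an empty dict from on_before_node keeps the hook compatible with the documented Dict return type even though this sketch has no data to hand over.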