import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np


def gradient_reverse(x, lambd=1.0):
    """
    Gradient Reversal Layer (GRL) for TensorFlow

    Forward: identity
    Backward: multiply incoming gradient by -lambda
    """
    # lambd is captured by the closure so the custom gradient only has to
    # return a gradient for x; a non-tensor lambd passed positionally would
    # otherwise be counted as a second input that also needs a gradient.
    @tf.custom_gradient
    def _reverse(inner_x):
        def grad(dy):
            return -lambd * dy

        return tf.identity(inner_x), grad

    return _reverse(x)
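
# A minimal, self-contained sanity check (illustrative only, not part of the
# original module): it assumes eager execution and verifies that
# gradient_reverse is the identity in the forward pass while flipping and
# scaling the gradient by -lambda in the backward pass.
def _demo_gradient_reverse():
    x = tf.constant([1.0, 2.0, 3.0])
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = gradient_reverse(x, 0.5)
        loss = tf.reduce_sum(y)
    grad = tape.gradient(loss, x)
    # Without the GRL the gradient of sum(y) w.r.t. x would be +1 everywhere;
    # with lambda = 0.5 it should come out as -0.5 everywhere.
    tf.debugging.assert_near(y, x)
    tf.debugging.assert_near(grad, tf.fill([3], -0.5))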

class NoiseModel(keras.Model):
    """
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data

    TensorFlow/Keras implementation optimized for TPU v5e-8
    """

    def __init__(self, neural_dim, n_units, n_days, rnn_dropout=0.0,
                 input_dropout=0.0, patch_size=0, patch_stride=0, **kwargs):
        super(NoiseModel, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers - use Variables for TPU compatibility
        self.day_layer_activation = layers.Activation('softsign')

        # Initialize day-specific weights and biases as Variables
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)

        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation
        # Use separate GRU layers for better TPU performance
        self.gru1 = layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,  # Avoid recurrent dropout on TPU
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noise_gru1'
        )
        self.gru2 = layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noise_gru2'
        )

        # Learnable initial hidden states
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, day_idx, states=None, training=None):
        """
        Forward pass optimized for TPU compilation

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim]
            day_idx: Day indices [batch_size]
            states: Optional initial states
            training: Training mode flag
        """
        batch_size = tf.shape(x)[0]

        # Stack all day weights and biases for efficient gathering
        all_day_weights = tf.stack(self.day_weights, axis=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = tf.stack(self.day_biases, axis=0)    # [n_days, neural_dim]

        # Gather day-specific parameters
        day_weights = tf.gather(all_day_weights, day_idx)     # [batch_size, neural_dim, neural_dim]
        day_biases = tf.gather(all_day_biases, day_idx)       # [batch_size, neural_dim]

        # Add time dimension to biases for broadcasting
        day_biases = tf.expand_dims(day_biases, axis=1)       # [batch_size, 1, neural_dim]

        # Apply day-specific transformation using efficient batch matrix multiplication
        x = tf.linalg.matmul(x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        # Apply input dropout
        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])  # [batch_size, input_size]
            h2_init = tf.tile(self.h0_2, [batch_size, 1])  # [batch_size, input_size]
            states = [h1_init, h2_init]
        else:
            h1_init, h2_init = states

        # Two-layer GRU forward pass
        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)

        return output, [h1_final, h2_final]

    def _apply_patch_processing(self, x):
        """Apply patch processing using TensorFlow operations"""
        batch_size = tf.shape(x)[0]

        # Add a channel dimension so extract_patches sees [batch_size, time_steps, 1, neural_dim]
        x = tf.expand_dims(x, axis=2)

        # Extract patches along the time axis using extract_patches
        # This is equivalent to PyTorch's unfold operation
        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Reshape to match expected output
        new_time_steps = tf.shape(patch_x)[1]
        patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])

        return patch_x
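
# A small shape sanity check (illustrative only; all dimensions below are made
# up). With patching disabled the NoiseModel output lives in the same space as
# its day-transformed input, which is what makes the residual subtraction
# x_processed - noise_output in TripleGRUDecoder well defined.
def _demo_noise_model_shapes():
    batch_size, time_steps, neural_dim = 4, 100, 64
    model = NoiseModel(neural_dim=neural_dim, n_units=128, n_days=3)
    x = tf.random.normal([batch_size, time_steps, neural_dim])
    day_idx = tf.constant([0, 1, 2, 0], dtype=tf.int32)
    noise_output, (h1, h2) = model(x, day_idx, training=False)
    # Expected: noise_output [4, 100, 64]; hidden states [4, 64] each
    print(noise_output.shape, h1.shape, h2.shape)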

class CleanSpeechModel(keras.Model):
    """
    Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition

    TensorFlow/Keras implementation optimized for TPU v5e-8
    """

    def __init__(self, neural_dim, n_units, n_days, n_classes, rnn_dropout=0.0,
                 input_dropout=0.0, patch_size=0, patch_stride=0, **kwargs):
        super(CleanSpeechModel, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = layers.Activation('softsign')

        # Initialize day-specific weights and biases
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)

        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru1 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru1'
        )
        self.gru2 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru2'
        )
        self.gru3 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru3'
        )

        # Output classification layer
        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='clean_output'
        )

        # Learnable initial hidden states
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_3 = self.add_weight(
            name='h0_3',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, day_idx, states=None, return_state=False, training=None):
        """Forward pass optimized for TPU compilation"""
        batch_size = tf.shape(x)[0]

        # Stack all day weights and biases for efficient gathering
        all_day_weights = tf.stack(self.day_weights, axis=0)
        all_day_biases = tf.stack(self.day_biases, axis=0)

        # Gather day-specific parameters
        day_weights = tf.gather(all_day_weights, day_idx)
        day_biases = tf.gather(all_day_biases, day_idx)
        day_biases = tf.expand_dims(day_biases, axis=1)

        # Apply day-specific transformation
        x = tf.linalg.matmul(x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.h0_2, [batch_size, 1])
            h3_init = tf.tile(self.h0_3, [batch_size, 1])
            states = [h1_init, h2_init, h3_init]
        else:
            h1_init, h2_init, h3_init = states

        # Three-layer GRU forward pass
        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output2, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
        output, h3_final = self.gru3(output2, initial_state=h3_init, training=training)

        # Classification
        logits = self.output_layer(output)

        if return_state:
            return logits, [h1_final, h2_final, h3_final]
        return logits

    def _apply_patch_processing(self, x):
        """Apply patch processing using TensorFlow operations"""
        batch_size = tf.shape(x)[0]

        # Add channel dimension
        x = tf.expand_dims(x, axis=2)

        # Extract patches
        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Reshape
        new_time_steps = tf.shape(patch_x)[1]
        patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])

        return patch_x
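
# An illustrative check (made-up dimensions) of how patching reshapes the
# sequence: with patch_size=4 and patch_stride=2 the time axis shrinks to
# floor((T - patch_size) / patch_stride) + 1 steps and the feature axis grows
# to patch_size * neural_dim, matching the input_size computed in __init__.
def _demo_patch_processing_shapes():
    batch_size, time_steps, neural_dim, n_classes = 4, 100, 64, 41
    model = CleanSpeechModel(neural_dim=neural_dim, n_units=128, n_days=3,
                             n_classes=n_classes, patch_size=4, patch_stride=2)
    x = tf.random.normal([batch_size, time_steps, neural_dim])
    day_idx = tf.constant([0, 1, 2, 0], dtype=tf.int32)
    logits = model(x, day_idx, training=False)
    # Expected: floor((100 - 4) / 2) + 1 = 49 patched steps -> logits [4, 49, 41]
    print(logits.shape)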

class NoisySpeechModel(keras.Model):
    """
    Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition

    TensorFlow/Keras implementation optimized for TPU v5e-8
    """

    def __init__(self, neural_dim, n_units, n_days, n_classes, rnn_dropout=0.0,
                 input_dropout=0.0, patch_size=0, patch_stride=0, **kwargs):
        super(NoisySpeechModel, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noisy speech recognition
        self.gru1 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noisy_gru1'
        )
        self.gru2 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noisy_gru2'
        )

        # Output classification layer
        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='noisy_output'
        )

        # Learnable initial hidden states
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, states=None, return_state=False, training=None):
        """Forward pass - no day-specific layers for noise processing"""
        batch_size = tf.shape(x)[0]

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.h0_2, [batch_size, 1])
            states = [h1_init, h2_init]
        else:
            h1_init, h2_init = states

        # Two-layer GRU forward pass
        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)

        # Classification
        logits = self.output_layer(output)

        if return_state:
            return logits, [h1_final, h2_final]
        return logits

class TripleGRUDecoder(keras.Model):
    """
    Three-model adversarial architecture for neural speech decoding

    TensorFlow/Keras implementation optimized for TPU v5e-8

    Combines:
    - NoiseModel: estimates noise in the neural data
    - CleanSpeechModel: processes denoised signal for recognition
    - NoisySpeechModel: processes noise signal for recognition
    """

    def __init__(self, neural_dim, n_units, n_days, n_classes, rnn_dropout=0.0,
                 input_dropout=0.0, patch_size=0, patch_stride=0, **kwargs):
        super(TripleGRUDecoder, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Create the three models
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride,
            name='noise_model'
        )
        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride,
            name='clean_speech_model'
        )
        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride,
            name='noisy_speech_model'
        )

        # Training mode flag
        self.training_mode = 'full'  # 'full', 'inference'

    def _apply_preprocessing(self, x, day_idx):
        """Apply preprocessing using clean speech model's day layers"""
        # Stack all day weights and biases
        all_day_weights = tf.stack(self.clean_speech_model.day_weights, axis=0)
        all_day_biases = tf.stack(self.clean_speech_model.day_biases, axis=0)

        # Gather day-specific parameters
        day_weights = tf.gather(all_day_weights, day_idx)
        day_biases = tf.gather(all_day_biases, day_idx)
        day_biases = tf.expand_dims(day_biases, axis=1)

        # Apply transformation
        x_processed = tf.linalg.matmul(x, day_weights) + day_biases
        x_processed = self.clean_speech_model.day_layer_activation(x_processed)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x_processed = self.clean_speech_model._apply_patch_processing(x_processed)

        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, training=None):
        """Forward pass for CleanSpeechModel with already processed input"""
        batch_size = tf.shape(x_processed)[0]

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.clean_speech_model.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.clean_speech_model.h0_2, [batch_size, 1])
            h3_init = tf.tile(self.clean_speech_model.h0_3, [batch_size, 1])
            states = [h1_init, h2_init, h3_init]
        else:
            h1_init, h2_init, h3_init = states

        # GRU forward pass (skip preprocessing since input is already processed)
        output1, h1_final = self.clean_speech_model.gru1(x_processed, initial_state=h1_init, training=training)
        output2, h2_final = self.clean_speech_model.gru2(output1, initial_state=h2_init, training=training)
        output, h3_final = self.clean_speech_model.gru3(output2, initial_state=h3_init, training=training)

        # Classification
        logits = self.clean_speech_model.output_layer(output)
        return logits

    def _noisy_forward_with_processed_input(self, x_processed, states=None, training=None):
        """Forward pass for NoisySpeechModel with already processed input"""
        batch_size = tf.shape(x_processed)[0]

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.noisy_speech_model.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.noisy_speech_model.h0_2, [batch_size, 1])
            states = [h1_init, h2_init]
        else:
            h1_init, h2_init = states

        # GRU forward pass
        output1, h1_final = self.noisy_speech_model.gru1(x_processed, initial_state=h1_init, training=training)
        output, h2_final = self.noisy_speech_model.gru2(output1, initial_state=h2_init, training=training)

        # Classification
        logits = self.noisy_speech_model.output_layer(output)
        return logits

    def call(self, x, day_idx, states=None, return_state=False, mode='inference',
             grl_lambda=0.0, training=None):
        """
        Three-model adversarial forward pass optimized for TPU compilation

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim]
            day_idx: Day indices [batch_size]
            states: Dictionary with 'noise', 'clean', 'noisy' states or None
            return_state: Whether to return hidden states
            mode: 'full' for training (all three models), 'inference' for inference
            grl_lambda: Gradient reversal strength for adversarial training
            training: Training mode flag
        """
        if mode == 'full':
            # Training mode: run all three models

            # 1. Noise model estimates noise in the data
            noise_output, noise_hidden = self.noise_model(
                x, day_idx, states['noise'] if states else None, training=training
            )

            # 2. Apply preprocessing to get x in the same space as noise_output
            x_processed = self._apply_preprocessing(x, day_idx)

            # 3. Clean speech model processes denoised signal
            denoised_input = x_processed - noise_output  # Residual connection
            clean_logits = self._clean_forward_with_processed_input(
                denoised_input, day_idx, states['clean'] if states else None, training=training
            )

            # 4. Noisy speech model processes noise signal
            # Apply Gradient Reversal Layer if specified
            if grl_lambda > 0.0:
                noisy_input = gradient_reverse(noise_output, grl_lambda)
            else:
                noisy_input = noise_output

            noisy_logits = self._noisy_forward_with_processed_input(
                noisy_input, states['noisy'] if states else None, training=training
            )

            # Return results
            if return_state:
                return (clean_logits, noisy_logits, noise_output), noise_hidden
            return clean_logits, noisy_logits, noise_output

        elif mode == 'inference':
            # Inference mode: only noise model + clean speech model

            # 1. Estimate noise
            noise_output, noise_hidden = self.noise_model(
                x, day_idx, states['noise'] if states else None, training=training
            )

            # 2. Apply preprocessing for residual connection
            x_processed = self._apply_preprocessing(x, day_idx)
            denoised_input = x_processed - noise_output

            clean_logits = self._clean_forward_with_processed_input(
                denoised_input, day_idx, states['clean'] if states else None, training=training
            )

            # Return results
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        else:
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

    def set_mode(self, mode):
        """Set the operating mode"""
        self.training_mode = mode
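
# A sketch of one adversarial training step using the 'full' forward pass. The
# CTC objective on both branches, the 0.2 weight on the noisy branch, and the
# structure of `batch` (a dict of padded tensors) are illustrative assumptions,
# not prescriptions from the original module. Note that batch['input_lengths']
# must describe the time dimension of the logits, i.e. the patched length when
# patching is enabled.
def _demo_adversarial_train_step(model, ctc_loss, optimizer, batch, grl_lambda=0.5):
    with tf.GradientTape() as tape:
        clean_logits, noisy_logits, noise_output = model(
            batch['x'], batch['day_idx'],
            mode='full', grl_lambda=grl_lambda, training=True
        )
        y_true = {
            'labels': batch['labels'],
            'input_lengths': batch['input_lengths'],
            'label_lengths': batch['label_lengths'],
        }
        clean_loss = tf.reduce_mean(ctc_loss(y_true, clean_logits))
        noisy_loss = tf.reduce_mean(ctc_loss(y_true, noisy_logits))
        # The GRL already reverses gradients flowing from the noisy branch into
        # the noise model, so the two terms can simply be added here.
        loss = clean_loss + 0.2 * noisy_loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss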

# Custom CTC Loss for TensorFlow TPU
class CTCLoss(keras.losses.Loss):
    """
    Custom CTC Loss optimized for TPU v5e-8
    """

    def __init__(self, blank_index=0, reduction='none', **kwargs):
        super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths'
            y_pred: Logits tensor [batch_size, time_steps, num_classes]
        """
        labels = y_true['labels']
        input_lengths = y_true['input_lengths']
        label_lengths = y_true['label_lengths']

        # Ensure correct data types
        labels = tf.cast(labels, tf.int32)
        input_lengths = tf.cast(input_lengths, tf.int32)
        label_lengths = tf.cast(label_lengths, tf.int32)

        # Convert logits to log probabilities
        log_probs = tf.nn.log_softmax(y_pred, axis=-1)

        # Transpose for CTC: [time_steps, batch_size, num_classes]
        log_probs = tf.transpose(log_probs, [1, 0, 2])

        # Convert dense labels to sparse format for CTC using TensorFlow operations
        def dense_to_sparse(dense_tensor, sequence_lengths):
            """Convert dense tensor to sparse tensor for CTC"""
            batch_size = tf.shape(dense_tensor)[0]
            max_len = tf.shape(dense_tensor)[1]

            # Create mask for non-zero elements
            mask = tf.not_equal(dense_tensor, 0)

            # Get indices of non-zero elements
            indices = tf.where(mask)

            # Get values at those indices
            values = tf.gather_nd(dense_tensor, indices)

            # Create sparse tensor
            dense_shape = tf.cast([batch_size, max_len], tf.int64)
            return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)

        # Convert labels to sparse format
        sparse_labels = dense_to_sparse(labels, label_lengths)

        # Compute CTC loss
        loss = tf.nn.ctc_loss(
            labels=sparse_labels,
            logits=log_probs,
            label_length=None,  # Not needed for sparse format
            logit_length=input_lengths,
            blank_index=self.blank_index,
            logits_time_major=True
        )

        return loss
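
# An illustrative use of CTCLoss with zero-padded dense labels (all numbers
# made up). Label id 0 is reserved for the CTC blank / padding, which is why
# dense_to_sparse above drops zeros, and input_lengths must describe the time
# dimension of the logits actually passed in.
def _demo_ctc_loss():
    batch_size, time_steps, n_classes = 2, 50, 41
    logits = tf.random.normal([batch_size, time_steps, n_classes])
    y_true = {
        'labels': tf.constant([[5, 12, 7, 0, 0],
                               [3, 3, 9, 14, 0]], dtype=tf.int32),  # zero-padded
        'input_lengths': tf.constant([50, 48], dtype=tf.int32),
        'label_lengths': tf.constant([3, 4], dtype=tf.int32),
    }
    loss_fn = CTCLoss(blank_index=0, reduction='none')
    per_example_loss = loss_fn(y_true, logits)  # shape [batch_size]
    print(per_example_loss)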

# TPU Strategy Helper Functions
def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8"""
    import os

    print("🔍 Detecting TPU environment...")

    # Disable GPU to avoid CUDA conflicts in TPU environment
    try:
        print("🚫 Disabling GPU to prevent CUDA conflicts...")
        tf.config.set_visible_devices([], 'GPU')
        print("✅ GPU disabled successfully")
    except Exception as e:
        print(f"⚠️ Warning: Could not disable GPU: {e}")

    # Check for various TPU environment variables
    tpu_address = None
    tpu_name = None

    # Check common TPU environment variables
    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
        print(f"📍 Found Colab TPU address: {tpu_address}")
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
        print(f"📍 Found TPU name: {tpu_name}")
    elif 'TPU_WORKER_ID' in os.environ:
        # Kaggle TPU environment
        worker_id = os.environ.get('TPU_WORKER_ID', '0')
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address
        print(f"📍 Kaggle TPU detected, worker ID: {worker_id}, address: {tpu_address}")

    # Print all TPU-related environment variables for debugging
    print("🔧 TPU environment variables:")
    tpu_vars = {k: v for k, v in os.environ.items() if 'TPU' in k or 'COLAB' in k}
    for key, value in tpu_vars.items():
        print(f"   {key}={value}")

    try:
        # Use the official TPU initialization pattern (simplified and reliable)
        print("🚀 Using official TensorFlow TPU initialization...")

        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        # TPU system initialization has to happen at the beginning of the program.
        tf.tpu.experimental.initialize_tpu_system(resolver)

        # Verify TPU devices (following the official example)
        tpu_devices = tf.config.list_logical_devices('TPU')
        print("All devices: ", tpu_devices)

        if not tpu_devices:
            raise RuntimeError("No TPU devices found!")
        print(f"✅ Found {len(tpu_devices)} TPU devices")

        # Create TPU strategy
        print("🎯 Creating TPU strategy...")
        strategy = tf.distribute.TPUStrategy(resolver)

        print("✅ TPU initialized successfully!")
        print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
        return strategy

    except Exception as e:
        print(f"❌ Failed to initialize TPU: {e}")
        print(f"🔍 Error type: {type(e).__name__}")

        # Enhanced error reporting
        if "Please provide a TPU Name" in str(e):
            print("💡 Hint: TPU name/address not found in environment variables")
            print("   Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")

        print("🔄 Falling back to default strategy (CPU/GPU)")
        return tf.distribute.get_strategy()


def build_model_for_tpu(config):
    """
    Build TripleGRUDecoder model optimized for TPU v5e-8

    Args:
        config: Configuration dictionary containing model parameters

    Returns:
        Keras model ready for TPU training
    """
    model = TripleGRUDecoder(
        neural_dim=config['model']['n_input_features'],
        n_units=config['model']['n_units'],
        n_days=len(config['dataset']['sessions']),
        n_classes=config['dataset']['n_classes'],
        rnn_dropout=config['model']['rnn_dropout'],
        input_dropout=config['model']['input_network']['input_layer_dropout'],
        patch_size=config['model']['patch_size'],
        patch_stride=config['model']['patch_stride']
    )
    return model


# Mixed Precision Configuration for TPU v5e-8
def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance"""
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)
    print(f"Mixed precision policy set to: {policy.name}")
    return policy
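
# An end-to-end setup sketch. The config dictionary below only illustrates the
# keys that build_model_for_tpu reads; the actual values (feature count, class
# count, session list, dropout rates, learning rate) are placeholders, not
# recommendations, and the optimizer choice is an assumption.
def _demo_tpu_setup():
    strategy = create_tpu_strategy()
    configure_mixed_precision()
    config = {
        'model': {
            'n_input_features': 512,
            'n_units': 768,
            'rnn_dropout': 0.3,
            'input_network': {'input_layer_dropout': 0.2},
            'patch_size': 0,
            'patch_stride': 0,
        },
        'dataset': {
            'sessions': ['day_0', 'day_1', 'day_2'],
            'n_classes': 41,
        },
    }
    # Variables must be created under the strategy scope so they are replicated
    # across TPU cores.
    with strategy.scope():
        model = build_model_for_tpu(config)
        loss_fn = CTCLoss(blank_index=0, reduction='none')
        optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    return strategy, model, loss_fn, optimizer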