810 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			810 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import tensorflow as tf
 | |
| from tensorflow import keras
 | |
| from tensorflow.keras import layers
 | |
| import numpy as np
 | |
| 
 | |
| 
 | |
def gradient_reverse(x, lambd=1.0):
    """
    Gradient Reversal Layer (GRL) for TensorFlow

    Forward: identity
    Backward: multiply incoming gradient by -lambd

    Args:
        x: Tensor passed through unchanged in the forward pass.
        lambd: Reversal strength. It is treated as a constant; no gradient
            is propagated to it.

    Returns:
        A tensor equal to ``x`` whose upstream gradient is scaled by
        ``-lambd`` during backpropagation.
    """
    # tf.custom_gradient expects the grad function to return one gradient per
    # argument of the decorated function. Decorating a one-argument inner
    # function and closing over `lambd` keeps that count at exactly one, so a
    # single returned gradient is valid.
    @tf.custom_gradient
    def _reverse(inputs):
        def grad(dy):
            # Cast so the product also works when dy is not float32
            # (e.g. under a mixed-precision policy).
            return -tf.cast(lambd, dy.dtype) * dy

        return tf.identity(inputs), grad

    return _reverse(x)
 | |
| 
 | |
| 
 | |
class NoiseModel(keras.Model):
    """
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data
    TensorFlow/Keras implementation optimized for TPU v5e-8

    The input goes through a day-specific affine layer (one weight matrix and
    one bias per day, selected with ``day_idx``), a softsign activation,
    optional input dropout and optional patching. Both GRU layers use
    ``input_size`` units, so the output keeps the feature width of the
    (patched) input.

    Args:
        neural_dim: Number of input features per time step.
        n_units: Stored on the instance only; the GRU width is ``input_size``.
        n_days: Number of day-specific weight/bias pairs to create.
        rnn_dropout: Input dropout rate of both GRU layers.
        input_dropout: Dropout rate applied after the day layer.
        patch_size: Number of time steps per patch; values <= 0 disable
            patching.
        patch_stride: Step between patches; must be positive when patching
            is enabled.

    Raises:
        ValueError: If ``patch_size > 0`` and ``patch_stride <= 0``.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(NoiseModel, self).__init__(**kwargs)

        # Fail early with a readable message instead of an opaque error from
        # the patch extraction op at call time.
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                "patch_stride must be positive when patch_size > 0 "
                f"(got patch_size={patch_size}, patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers - use Variables for TPU compatibility
        self.day_layer_activation = layers.Activation('softsign')

        # One (weight, bias) pair per day; weights start as identity and
        # biases as zero, so the affine part is initially a no-op.
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)

        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Feature width seen by the GRUs after optional patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation (separate layers, no recurrent
        # dropout)
        self.gru1 = self._build_gru('noise_gru1')
        self.gru2 = self._build_gru('noise_gru2')

        # Learnable initial hidden states, tiled over the batch in call()
        self.h0_1 = self._build_initial_state('h0_1')
        self.h0_2 = self._build_initial_state('h0_2')

    def _build_gru(self, name):
        """Create one GRU layer of width ``input_size`` with fresh initializers."""
        return layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,  # Avoid recurrent dropout on TPU
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name=name
        )

    def _build_initial_state(self, name):
        """Create one trainable initial hidden state of shape (1, input_size)."""
        return self.add_weight(
            name=name,
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def _apply_day_layer(self, x, day_idx):
        """Apply the per-sample day-specific affine transform and softsign."""
        # Stack all day weights and biases for efficient gathering
        all_day_weights = tf.stack(self.day_weights, axis=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = tf.stack(self.day_biases, axis=0)    # [n_days, neural_dim]

        # Gather day-specific parameters
        day_weights = tf.gather(all_day_weights, day_idx)  # [batch_size, neural_dim, neural_dim]
        day_biases = tf.gather(all_day_biases, day_idx)    # [batch_size, neural_dim]

        # Add time dimension to biases for broadcasting
        day_biases = tf.expand_dims(day_biases, axis=1)  # [batch_size, 1, neural_dim]

        # Batched matrix multiplication followed by the activation
        x = tf.linalg.matmul(x, day_weights) + day_biases
        return self.day_layer_activation(x)

    def call(self, x, day_idx, states=None, training=None):
        """
        Forward pass optimized for TPU compilation

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim]
            day_idx: Day indices [batch_size]
            states: Optional pair of initial hidden states ``(h1, h2)``;
                when None the learned initial states are tiled over the batch
            training: Training mode flag

        Returns:
            Tuple ``(output, [h1_final, h2_final])``: the second GRU's
            output sequence and the final hidden state of each GRU.
        """
        batch_size = tf.shape(x)[0]

        x = self._apply_day_layer(x, day_idx)

        # Apply input dropout
        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])  # [batch_size, input_size]
            h2_init = tf.tile(self.h0_2, [batch_size, 1])  # [batch_size, input_size]
        else:
            h1_init, h2_init = states

        # Two-layer GRU forward pass
        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)

        return output, [h1_final, h2_final]

    def _apply_patch_processing(self, x):
        """Apply patch processing using TensorFlow operations"""
        batch_size = tf.shape(x)[0]

        # Add a singleton axis so the tensor is 4-D for extract_patches
        x = tf.expand_dims(x, axis=2)  # [batch_size, time_steps, 1, neural_dim]

        # Slide a window of patch_size steps along the time axis
        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Collapse everything after the new time axis into one feature axis
        new_time_steps = tf.shape(patch_x)[1]
        patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])

        return patch_x
 | |
| 
 | |
| 
 | |
class CleanSpeechModel(keras.Model):
    """
    Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
    TensorFlow/Keras implementation optimized for TPU v5e-8

    The input goes through a day-specific affine layer (one weight matrix and
    one bias per day, selected with ``day_idx``), a softsign activation,
    optional input dropout and optional patching, then three GRU layers of
    ``n_units`` units and a dense layer producing ``n_classes`` logits per
    time step.

    Args:
        neural_dim: Number of input features per time step.
        n_units: Width of each GRU layer.
        n_days: Number of day-specific weight/bias pairs to create.
        n_classes: Size of the output layer.
        rnn_dropout: Input dropout rate of the GRU layers.
        input_dropout: Dropout rate applied after the day layer.
        patch_size: Number of time steps per patch; values <= 0 disable
            patching.
        patch_stride: Step between patches; must be positive when patching
            is enabled.

    Raises:
        ValueError: If ``patch_size > 0`` and ``patch_stride <= 0``.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(CleanSpeechModel, self).__init__(**kwargs)

        # Fail early with a readable message instead of an opaque error from
        # the patch extraction op at call time.
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                "patch_stride must be positive when patch_size > 0 "
                f"(got patch_size={patch_size}, patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = layers.Activation('softsign')

        # One (weight, bias) pair per day; weights start as identity and
        # biases as zero, so the affine part is initially a no-op.
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)

        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Feature width after optional patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru1 = self._build_gru('clean_gru1')
        self.gru2 = self._build_gru('clean_gru2')
        self.gru3 = self._build_gru('clean_gru3')

        # Output classification layer
        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='clean_output'
        )

        # Learnable initial hidden states, tiled over the batch in call()
        self.h0_1 = self._build_initial_state('h0_1')
        self.h0_2 = self._build_initial_state('h0_2')
        self.h0_3 = self._build_initial_state('h0_3')

    def _build_gru(self, name):
        """Create one GRU layer of width ``n_units`` with fresh initializers."""
        return layers.GRU(
            units=self.n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name=name
        )

    def _build_initial_state(self, name):
        """Create one trainable initial hidden state of shape (1, n_units)."""
        return self.add_weight(
            name=name,
            shape=(1, self.n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def _apply_day_layer(self, x, day_idx):
        """Apply the per-sample day-specific affine transform and softsign."""
        # Stack all day weights and biases for efficient gathering
        all_day_weights = tf.stack(self.day_weights, axis=0)
        all_day_biases = tf.stack(self.day_biases, axis=0)

        # Gather day-specific parameters; the bias gets a time axis so it
        # broadcasts over every step
        day_weights = tf.gather(all_day_weights, day_idx)
        day_biases = tf.gather(all_day_biases, day_idx)
        day_biases = tf.expand_dims(day_biases, axis=1)

        x = tf.linalg.matmul(x, day_weights) + day_biases
        return self.day_layer_activation(x)

    def call(self, x, day_idx, states=None, return_state=False, training=None):
        """
        Forward pass optimized for TPU compilation

        Args:
            x: Input tensor; the day layer multiplies it by a
                (neural_dim, neural_dim) matrix per sample
            day_idx: Day indices used to gather the per-sample day parameters
            states: Optional triple of initial hidden states
                ``(h1, h2, h3)``; when None the learned initial states are
                tiled over the batch
            return_state: Whether to also return the final hidden states
            training: Training mode flag

        Returns:
            ``logits``, or ``(logits, [h1_final, h2_final, h3_final])`` when
            ``return_state`` is true.
        """
        batch_size = tf.shape(x)[0]

        x = self._apply_day_layer(x, day_idx)

        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        # Initialize hidden states if not provided
        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.h0_2, [batch_size, 1])
            h3_init = tf.tile(self.h0_3, [batch_size, 1])
        else:
            h1_init, h2_init, h3_init = states

        # Three-layer GRU forward pass
        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output2, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
        output, h3_final = self.gru3(output2, initial_state=h3_init, training=training)

        # Classification
        logits = self.output_layer(output)

        if return_state:
            return logits, [h1_final, h2_final, h3_final]
        return logits

    def _apply_patch_processing(self, x):
        """Apply patch processing using TensorFlow operations"""
        batch_size = tf.shape(x)[0]

        # Add a singleton axis so the tensor is 4-D for extract_patches
        x = tf.expand_dims(x, axis=2)

        # Slide a window of patch_size steps along axis 1
        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Collapse everything after the new time axis into one feature axis
        new_time_steps = tf.shape(patch_x)[1]
        patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])

        return patch_x
 | |
| 
 | |
| 
 | |
class NoisySpeechModel(keras.Model):
    """
    Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
    TensorFlow/Keras implementation optimized for TPU v5e-8

    Unlike the other sub-models it has no day-specific input layer: the input
    is fed straight into two GRU layers of ``n_units`` units followed by a
    dense layer producing ``n_classes`` logits per time step. ``n_days``,
    ``input_dropout`` and the patch settings are stored on the instance but
    are not used by ``call``.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(NoisySpeechModel, self).__init__(**kwargs)

        # Keep the construction arguments around as plain attributes
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Feature width after optional patching
        if self.patch_size > 0:
            self.input_size = self.neural_dim * self.patch_size
        else:
            self.input_size = self.neural_dim

        def make_gru(layer_name):
            # Each layer gets its own initializer instances
            return layers.GRU(
                units=n_units,
                return_sequences=True,
                return_state=True,
                dropout=self.rnn_dropout,
                recurrent_dropout=0.0,
                kernel_initializer=tf.keras.initializers.GlorotUniform(),
                recurrent_initializer=tf.keras.initializers.Orthogonal(),
                name=layer_name
            )

        def make_initial_state(weight_name):
            # Trainable (1, n_units) row, tiled over the batch at call time
            return self.add_weight(
                name=weight_name,
                shape=(1, n_units),
                initializer=tf.keras.initializers.GlorotUniform(),
                trainable=True
            )

        # 2-layer GRU for noisy speech recognition
        self.gru1 = make_gru('noisy_gru1')
        self.gru2 = make_gru('noisy_gru2')

        # Output classification layer
        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='noisy_output'
        )

        # Learnable initial hidden states
        self.h0_1 = make_initial_state('h0_1')
        self.h0_2 = make_initial_state('h0_2')

    def call(self, x, states=None, return_state=False, training=None):
        """
        Forward pass - no day-specific layers for noise processing

        Args:
            x: Input sequence fed directly to the first GRU
            states: Optional pair of initial hidden states ``(h1, h2)``;
                when None the learned initial states are tiled over the batch
            return_state: Whether to also return the final hidden states
            training: Training mode flag

        Returns:
            ``logits``, or ``(logits, [h1_final, h2_final])`` when
            ``return_state`` is true.
        """
        n_samples = tf.shape(x)[0]

        # Fall back to the learned initial states when none are supplied
        if states is not None:
            first_h0, second_h0 = states
        else:
            first_h0 = tf.tile(self.h0_1, [n_samples, 1])
            second_h0 = tf.tile(self.h0_2, [n_samples, 1])

        # Two-layer GRU forward pass
        seq1, h1_final = self.gru1(x, initial_state=first_h0, training=training)
        seq2, h2_final = self.gru2(seq1, initial_state=second_h0, training=training)

        # Classification
        logits = self.output_layer(seq2)

        if not return_state:
            return logits
        return logits, [h1_final, h2_final]
 | |
| 
 | |
| 
 | |
class TripleGRUDecoder(keras.Model):
    """
    Three-model adversarial architecture for neural speech decoding
    TensorFlow/Keras implementation optimized for TPU v5e-8

    Combines:
    - NoiseModel: estimates noise in neural data
    - CleanSpeechModel: processes denoised signal for recognition
    - NoisySpeechModel: processes noise signal for recognition

    Note: ``call`` selects its behaviour from its own ``mode`` argument; the
    ``training_mode`` attribute written by ``set_mode`` is stored but not read
    by ``call``.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(TripleGRUDecoder, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Arguments every sub-model receives
        shared_args = dict(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride,
        )

        # Create the three models (the noise model has no classifier head)
        self.noise_model = NoiseModel(name='noise_model', **shared_args)
        self.clean_speech_model = CleanSpeechModel(
            n_classes=n_classes, name='clean_speech_model', **shared_args
        )
        self.noisy_speech_model = NoisySpeechModel(
            n_classes=n_classes, name='noisy_speech_model', **shared_args
        )

        # Training mode flag
        self.training_mode = 'full'  # 'full', 'inference'

    def _apply_preprocessing(self, x, day_idx):
        """Apply preprocessing using clean speech model's day layers"""
        clean = self.clean_speech_model

        # Pick each sample's day weight/bias out of the stacked parameters;
        # the bias gets a time axis so it broadcasts over every step
        weights_per_sample = tf.gather(tf.stack(clean.day_weights, axis=0), day_idx)
        biases_per_sample = tf.expand_dims(
            tf.gather(tf.stack(clean.day_biases, axis=0), day_idx), axis=1
        )

        # Affine transform followed by the clean model's activation
        x_processed = clean.day_layer_activation(
            tf.linalg.matmul(x, weights_per_sample) + biases_per_sample
        )

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x_processed = clean._apply_patch_processing(x_processed)

        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, training=None):
        """Forward pass for CleanSpeechModel with already processed input

        ``day_idx`` is accepted for signature symmetry but not used here: the
        day layer has already been applied to ``x_processed``. Only the logits
        are returned; the GRUs' final states are discarded.
        """
        clean = self.clean_speech_model

        # Fall back to the clean model's learned initial states
        if states is not None:
            h0_a, h0_b, h0_c = states
        else:
            n_samples = tf.shape(x_processed)[0]
            h0_a = tf.tile(clean.h0_1, [n_samples, 1])
            h0_b = tf.tile(clean.h0_2, [n_samples, 1])
            h0_c = tf.tile(clean.h0_3, [n_samples, 1])

        # GRU forward pass (skip preprocessing since input is already processed)
        seq, _ = clean.gru1(x_processed, initial_state=h0_a, training=training)
        seq, _ = clean.gru2(seq, initial_state=h0_b, training=training)
        seq, _ = clean.gru3(seq, initial_state=h0_c, training=training)

        # Classification
        return clean.output_layer(seq)

    def _noisy_forward_with_processed_input(self, x_processed, states=None, training=None):
        """Forward pass for NoisySpeechModel with already processed input

        Only the logits are returned; the GRUs' final states are discarded.
        """
        noisy = self.noisy_speech_model

        # Fall back to the noisy model's learned initial states
        if states is not None:
            h0_a, h0_b = states
        else:
            n_samples = tf.shape(x_processed)[0]
            h0_a = tf.tile(noisy.h0_1, [n_samples, 1])
            h0_b = tf.tile(noisy.h0_2, [n_samples, 1])

        # GRU forward pass
        seq, _ = noisy.gru1(x_processed, initial_state=h0_a, training=training)
        seq, _ = noisy.gru2(seq, initial_state=h0_b, training=training)

        # Classification
        return noisy.output_layer(seq)

    def call(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda=0.0, training=None):
        """
        Three-model adversarial forward pass optimized for TPU compilation

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim]
            day_idx: Day indices [batch_size]
            states: Dictionary with 'noise', 'clean', 'noisy' states or None
            return_state: Whether to return hidden states (only the noise
                model's final states are returned)
            mode: 'full' for training (all three models), 'inference' for inference
            grl_lambda: Gradient reversal strength for adversarial training;
                reversal is applied only when it is greater than 0
            training: Training mode flag

        Returns:
            mode='inference': ``clean_logits``, or
                ``(clean_logits, noise_hidden)`` with ``return_state``.
            mode='full': ``(clean_logits, noisy_logits, noise_output)``, or
                ``((clean_logits, noisy_logits, noise_output), noise_hidden)``
                with ``return_state``.

        Raises:
            ValueError: If ``mode`` is neither 'full' nor 'inference'.
        """
        if mode not in ('full', 'inference'):
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

        # --- Path shared by both modes -----------------------------------
        # 1. Noise model estimates noise in the data
        noise_output, noise_hidden = self.noise_model(
            x, day_idx,
            states['noise'] if states else None,
            training=training
        )

        # 2. Bring x into the same space as noise_output, then subtract the
        #    noise estimate (residual connection)
        denoised_input = self._apply_preprocessing(x, day_idx) - noise_output

        # 3. Clean speech model processes the denoised signal
        clean_logits = self._clean_forward_with_processed_input(
            denoised_input, day_idx,
            states['clean'] if states else None,
            training=training
        )

        if mode == 'inference':
            # Inference mode: only noise model + clean speech model
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        # --- Training-only adversarial branch -----------------------------
        # 4. Noisy speech model processes the noise signal, optionally
        #    through the Gradient Reversal Layer
        noisy_input = noise_output
        if grl_lambda > 0.0:
            noisy_input = gradient_reverse(noise_output, grl_lambda)

        noisy_logits = self._noisy_forward_with_processed_input(
            noisy_input,
            states['noisy'] if states else None,
            training=training
        )

        if return_state:
            return (clean_logits, noisy_logits, noise_output), noise_hidden
        return clean_logits, noisy_logits, noise_output

    def set_mode(self, mode):
        """Set the operating mode"""
        self.training_mode = mode
 | |
| 
 | |
| 
 | |
| # TPU Strategy Helper Functions
 | |
def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8

    GPUs are hidden first, TPU-related environment variables are logged for
    debugging, and then the TPU system is initialised through a
    ``TPUClusterResolver`` built with no arguments. The address/name found in
    the environment are only printed; they are not passed to the resolver.

    Returns:
        A ``tf.distribute.TPUStrategy`` when initialisation succeeds.
        Otherwise the current default strategy from
        ``tf.distribute.get_strategy()``, with a CPU ``OneDeviceStrategy`` as
        a last resort. Never returns None.
    """
    import os

    print("🔍 Detecting TPU environment...")

    # Disable GPU to avoid CUDA conflicts in TPU environment
    try:
        print("🚫 Disabling GPU to prevent CUDA conflicts...")
        tf.config.set_visible_devices([], 'GPU')
        print("✅ GPU disabled successfully")
    except Exception as e:
        # Best effort only: carry on even if the GPUs cannot be hidden
        print(f"⚠️  Warning: Could not disable GPU: {e}")

    # Check common TPU environment variables (diagnostic output only)
    tpu_address = None
    tpu_name = None

    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
        print(f"📍 Found Colab TPU address: {tpu_address}")
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
        print(f"📍 Found TPU name: {tpu_name}")
    elif 'TPU_WORKER_ID' in os.environ:
        # Kaggle TPU environment
        worker_id = os.environ.get('TPU_WORKER_ID', '0')
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address
        print(f"📍 Kaggle TPU detected, worker ID: {worker_id}, address: {tpu_address}")

    # Print all TPU-related environment variables for debugging
    print("🔧 TPU environment variables:")
    for key, value in os.environ.items():
        if 'TPU' in key or 'COLAB' in key:
            print(f"   {key}={value}")

    try:
        # Use official TPU initialization pattern (simplified and reliable)
        print("🚀 Using official TensorFlow TPU initialization...")

        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        # This is the TPU initialization code that has to be at the beginning.
        tf.tpu.experimental.initialize_tpu_system(resolver)

        # Verify TPU devices (following official example)
        tpu_devices = tf.config.list_logical_devices('TPU')
        print("All devices: ", tpu_devices)

        if not tpu_devices:
            # Caught by the handler below, which falls back to CPU/GPU
            raise RuntimeError("No TPU devices found!")

        print(f"✅ Found {len(tpu_devices)} TPU devices")

        # Create TPU strategy
        print("🎯 Creating TPU strategy...")
        strategy = tf.distribute.TPUStrategy(resolver)

        print("✅ TPU initialized successfully!")
        print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")

        return strategy

    except Exception as e:
        # Any failure during TPU setup leads to the non-TPU fallback
        print(f"❌ Failed to initialize TPU: {e}")
        print(f"🔍 Error type: {type(e).__name__}")

        # Enhanced error reporting
        if "Please provide a TPU Name" in str(e):
            print("💡 Hint: TPU name/address not found in environment variables")
            print("   Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")

        print("🔄 Falling back to default strategy (CPU/GPU)")
        fallback_strategy = tf.distribute.get_strategy()

        # Ensure we never return None. This guard runs before the strategy's
        # attributes are read so that it can actually take effect.
        if fallback_strategy is None:
            print("⚠️  Warning: Default strategy is None, creating OneDeviceStrategy")
            fallback_strategy = tf.distribute.OneDeviceStrategy("/CPU:0")

        print(f"🎯 Fallback strategy created: {type(fallback_strategy).__name__}")
        print(f"📊 Fallback strategy replicas: {fallback_strategy.num_replicas_in_sync}")

        return fallback_strategy
 | |
| 
 | |
| 
 | |
def build_model_for_tpu(config):
    """
    Build TripleGRUDecoder model optimized for TPU v5e-8

    Args:
        config: Configuration dictionary containing model parameters. Reads
            ``config['model']`` (``n_input_features``, ``n_units``,
            ``rnn_dropout``, ``input_network.input_layer_dropout``,
            ``patch_size``, ``patch_stride``) and ``config['dataset']``
            (``sessions``, ``n_classes``). The number of days is the length
            of ``sessions``.

    Returns:
        A freshly constructed TripleGRUDecoder. The model is not compiled
        here.
    """
    model_cfg = config['model']

    # Keyword arguments are gathered in the constructor's parameter order
    decoder_kwargs = {
        'neural_dim': model_cfg['n_input_features'],
        'n_units': model_cfg['n_units'],
        'n_days': len(config['dataset']['sessions']),
        'n_classes': config['dataset']['n_classes'],
        'rnn_dropout': model_cfg['rnn_dropout'],
        'input_dropout': model_cfg['input_network']['input_layer_dropout'],
        'patch_size': model_cfg['patch_size'],
        'patch_stride': model_cfg['patch_stride'],
    }

    return TripleGRUDecoder(**decoder_kwargs)
 | |
| 
 | |
| 
 | |
| # Mixed Precision Configuration for TPU v5e-8
 | |
def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance

    Installs the 'mixed_bfloat16' policy as the global Keras policy, prints
    its name and returns the policy object.
    """
    mixed_precision = keras.mixed_precision

    bf16_policy = mixed_precision.Policy('mixed_bfloat16')
    mixed_precision.set_global_policy(bf16_policy)

    print(f"Mixed precision policy set to: {bf16_policy.name}")
    return bf16_policy
