871 lines
30 KiB
Python
871 lines
30 KiB
Python
import tensorflow as tf
|
|
from tensorflow import keras
|
|
from tensorflow.keras import layers
|
|
import numpy as np
|
|
|
|
|
|
def gradient_reverse(x, lambd=1.0):
    """
    Gradient Reversal Layer (GRL) for TensorFlow.

    Forward pass: identity.
    Backward pass: the incoming gradient is multiplied by ``-lambd``.

    ``lambd`` is captured by closure instead of being an argument of the
    ``tf.custom_gradient`` function. ``tf.custom_gradient`` treats every
    positional argument as a differentiable input, so decorating
    ``gradient_reverse(x, lambd)`` directly required ``grad`` to return one
    gradient per argument. Calling it as ``gradient_reverse(t, 0.5)`` therefore
    failed on the backward pass.

    Args:
        x: Tensor passed through unchanged.
        lambd: Reversal strength (Python float or scalar tensor).

    Returns:
        A tensor equal to ``x`` whose gradient is ``-lambd * dy``.
    """

    @tf.custom_gradient
    def _reverse(inner):
        def grad(dy):
            # Cast so a float32 lambda also works with bfloat16/float16 gradients
            # (e.g. under a mixed-precision policy).
            return -tf.cast(lambd, dy.dtype) * dy

        return tf.identity(inner), grad

    return _reverse(x)
|
|
|
|
|
|
class NoiseModel(keras.Model):
    """
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data.

    TensorFlow/Keras implementation targeting TPU v5e-8.

    Per call: day-specific affine transform + softsign, optional input dropout,
    optional temporal patching, then two stacked GRUs. Both GRUs are
    ``input_size`` wide rather than ``n_units`` wide. ``TripleGRUDecoder``
    subtracts this model's output from its own preprocessed input, which needs
    matching widths. ``n_units`` is stored but not used for any layer size here.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        """
        Args:
            neural_dim: Number of input features per time step.
            n_units: Stored only; the GRUs here are ``input_size`` wide.
            n_days: Number of days; one affine transform is created per day.
            rnn_dropout: Value passed as ``dropout`` to each GRU.
            input_dropout: Dropout rate applied after the day transform.
            patch_size: Frames per patch; 0 disables patching.
            patch_stride: Stride between patches (used when ``patch_size > 0``).
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
        """
        super(NoiseModel, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers - use Variables for TPU compatibility
        self.day_layer_activation = layers.Activation('softsign')

        # One (neural_dim x neural_dim) weight and (neural_dim,) bias per day.
        # Identity/zero initialization makes the affine part start as a no-op.
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)

        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Patching concatenates patch_size consecutive frames into one vector.
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation (separate layers, no recurrent dropout)
        self.gru1 = layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,  # Avoid recurrent dropout on TPU
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noise_gru1'
        )

        self.gru2 = layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noise_gru2'
        )

        # Learnable initial hidden states, shape (1, input_size), tiled per batch.
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, day_idx, states=None, training=None):
        """
        Estimate the noise sequence for a batch.

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim].
            day_idx: Integer day index per example, [batch_size].
            states: Optional ``[h1, h2]`` initial GRU states, each
                [batch_size, input_size]; the learned ``h0_*`` are used if None.
            training: Training mode flag (controls dropout).

        Returns:
            ``(output, [h1_final, h2_final])``. ``output`` is
            [batch_size, time_steps', input_size], where time_steps' is reduced
            by patching when it is enabled.
        """
        batch_size = tf.shape(x)[0]

        # Stack all day parameters, then gather one set per example.
        all_day_weights = tf.stack(self.day_weights, axis=0)  # [n_days, D, D]
        all_day_biases = tf.stack(self.day_biases, axis=0)    # [n_days, D]
        day_weights = tf.gather(all_day_weights, day_idx)     # [B, D, D]
        day_biases = tf.gather(all_day_biases, day_idx)       # [B, D]
        day_biases = tf.expand_dims(day_biases, axis=1)       # [B, 1, D] for time broadcast

        # Per-example batched matmul applies each example's own day transform.
        x = tf.linalg.matmul(x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])  # [B, input_size]
            h2_init = tf.tile(self.h0_2, [batch_size, 1])  # [B, input_size]
        else:
            h1_init, h2_init = states

        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)

        return output, [h1_final, h2_final]

    def _apply_patch_processing(self, x):
        """
        Group consecutive frames into patches (equivalent to PyTorch ``unfold``).

        [batch, time, dim] -> [batch, new_time, patch_size * dim]. Within each
        patch, features are ordered frame by frame.
        """
        # Treat the sequence as a 1-pixel-wide image: [batch, time, 1, dim],
        # so extract_patches slides along the time axis only.
        x = tf.expand_dims(x, axis=2)

        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # [batch, new_time, 1, patch_size * dim] -> [batch, new_time, patch_size * dim].
        # Squeezing keeps the static feature width, which the GRU needs to build
        # its kernel. A reshape to [dynamic_batch, dynamic_time, -1] can lose it
        # in graph mode.
        return tf.squeeze(patch_x, axis=2)
|
|
|
|
|
|
class CleanSpeechModel(keras.Model):
    """
    Clean Speech Model: 3-layer GRU that processes the denoised signal for
    speech recognition.

    TensorFlow/Keras implementation targeting TPU v5e-8.

    Per call: day-specific affine transform + softsign, optional input dropout,
    optional temporal patching, three stacked ``n_units``-wide GRUs, then a
    Dense layer producing ``n_classes`` logits per time step.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        """
        Args:
            neural_dim: Number of input features per time step.
            n_units: Width of each GRU layer.
            n_days: Number of days; one affine transform is created per day.
            n_classes: Number of output logits per time step.
            rnn_dropout: Value passed as ``dropout`` to each GRU.
            input_dropout: Dropout rate applied after the day transform.
            patch_size: Frames per patch; 0 disables patching.
            patch_stride: Stride between patches (used when ``patch_size > 0``).
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
        """
        super(CleanSpeechModel, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = layers.Activation('softsign')

        # One (neural_dim x neural_dim) weight and (neural_dim,) bias per day.
        # Identity/zero initialization makes the affine part start as a no-op.
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)

        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Patching concatenates patch_size consecutive frames into one vector.
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru1 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru1'
        )

        self.gru2 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru2'
        )

        self.gru3 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru3'
        )

        # Output classification layer
        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='clean_output'
        )

        # Learnable initial hidden states, shape (1, n_units), tiled per batch.
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_3 = self.add_weight(
            name='h0_3',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, day_idx, states=None, return_state=False, training=None):
        """
        Compute per-time-step logits.

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim].
            day_idx: Integer day index per example, [batch_size].
            states: Optional ``[h1, h2, h3]`` initial GRU states, each
                [batch_size, n_units]; the learned ``h0_*`` are used if None.
            return_state: If True, also return the final GRU states.
            training: Training mode flag (controls dropout).

        Returns:
            ``logits`` of shape [batch_size, time_steps', n_classes], or
            ``(logits, [h1_final, h2_final, h3_final])`` if ``return_state``.
        """
        batch_size = tf.shape(x)[0]

        # Stack all day parameters, then gather one set per example.
        all_day_weights = tf.stack(self.day_weights, axis=0)
        all_day_biases = tf.stack(self.day_biases, axis=0)
        day_weights = tf.gather(all_day_weights, day_idx)
        day_biases = tf.gather(all_day_biases, day_idx)
        day_biases = tf.expand_dims(day_biases, axis=1)  # broadcast over time

        # Per-example batched matmul applies each example's own day transform.
        x = tf.linalg.matmul(x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.h0_2, [batch_size, 1])
            h3_init = tf.tile(self.h0_3, [batch_size, 1])
        else:
            h1_init, h2_init, h3_init = states

        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output2, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
        output, h3_final = self.gru3(output2, initial_state=h3_init, training=training)

        logits = self.output_layer(output)

        if return_state:
            return logits, [h1_final, h2_final, h3_final]
        return logits

    def _apply_patch_processing(self, x):
        """
        Group consecutive frames into patches (equivalent to PyTorch ``unfold``).

        [batch, time, dim] -> [batch, new_time, patch_size * dim]. Within each
        patch, features are ordered frame by frame.
        """
        # Treat the sequence as a 1-pixel-wide image: [batch, time, 1, dim].
        x = tf.expand_dims(x, axis=2)

        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Drop the singleton width axis. Unlike a reshape to
        # [dynamic_batch, dynamic_time, -1], squeeze keeps the static feature
        # width that the GRU needs to build its kernel.
        return tf.squeeze(patch_x, axis=2)
|
|
|
|
|
|
class NoisySpeechModel(keras.Model):
    """
    Noisy Speech Model: two stacked GRUs that classify the noise signal.

    TensorFlow/Keras implementation targeting TPU v5e-8.

    Unlike the clean/noise models, this model has no day-specific input
    transform and no patching in ``call``. It consumes an already-prepared
    sequence directly. ``n_days`` and ``input_size`` are stored but not used by
    its layers.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        """
        Args:
            neural_dim: Feature count of the raw neural data (stored).
            n_units: Width of each GRU layer.
            n_days: Stored only; no per-day layers are built here.
            n_classes: Number of output logits per time step.
            rnn_dropout: Value passed as ``dropout`` to each GRU.
            input_dropout: Stored only.
            patch_size: Stored; also used to derive ``input_size``.
            patch_stride: Stored only.
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
        """
        super(NoisySpeechModel, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Width a patched sequence would have (kept for API parity).
        if self.patch_size > 0:
            self.input_size = self.neural_dim * self.patch_size
        else:
            self.input_size = self.neural_dim

        def make_gru(layer_name):
            # Identical GRU configuration for both layers; only the name differs.
            return layers.GRU(
                units=n_units,
                return_sequences=True,
                return_state=True,
                dropout=self.rnn_dropout,
                recurrent_dropout=0.0,
                kernel_initializer=tf.keras.initializers.GlorotUniform(),
                recurrent_initializer=tf.keras.initializers.Orthogonal(),
                name=layer_name,
            )

        def make_initial_state(weight_name):
            # Learnable (1, n_units) start state, tiled to the batch in call().
            return self.add_weight(
                name=weight_name,
                shape=(1, n_units),
                initializer=tf.keras.initializers.GlorotUniform(),
                trainable=True,
            )

        self.gru1 = make_gru('noisy_gru1')
        self.gru2 = make_gru('noisy_gru2')

        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='noisy_output',
        )

        self.h0_1 = make_initial_state('h0_1')
        self.h0_2 = make_initial_state('h0_2')

    def call(self, x, states=None, return_state=False, training=None):
        """Forward pass - no day-specific layers for noise processing.

        Args:
            x: Sequence tensor; presumably [batch, time, features] -- TODO
                confirm against the caller.
            states: Optional ``[h1, h2]`` initial GRU states; the learned
                ``h0_*`` (tiled to the batch size) are used if None.
            return_state: If True, also return the final GRU states.
            training: Training mode flag (controls GRU dropout).

        Returns:
            Per-time-step logits, or ``(logits, [h1_final, h2_final])`` when
            ``return_state`` is True.
        """
        if states is not None:
            h1_init, h2_init = states
        else:
            n_batch = tf.shape(x)[0]
            h1_init, h2_init = [tf.tile(h0, [n_batch, 1]) for h0 in (self.h0_1, self.h0_2)]

        hidden_seq, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        hidden_seq, h2_final = self.gru2(hidden_seq, initial_state=h2_init, training=training)
        logits = self.output_layer(hidden_seq)

        if not return_state:
            return logits
        return logits, [h1_final, h2_final]
|
|
|
|
|
|
class TripleGRUDecoder(keras.Model):
    """
    Three-model adversarial architecture for neural speech decoding.

    TensorFlow/Keras implementation targeting TPU v5e-8.

    Components:
      - NoiseModel: estimates noise in the neural data.
      - CleanSpeechModel: classifies (preprocessed input - noise estimate).
      - NoisySpeechModel: classifies the noise estimate, optionally through a
        gradient reversal layer.

    The preprocessing applied before subtraction reuses the clean speech
    model's day-specific weights and its patching.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        """Create the three sub-models with the same hyper-parameters."""
        super(TripleGRUDecoder, self).__init__(**kwargs)

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Hyper-parameters shared by all three sub-models (local, not tracked).
        shared = dict(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride,
        )
        self.noise_model = NoiseModel(name='noise_model', **shared)
        self.clean_speech_model = CleanSpeechModel(
            n_classes=n_classes, name='clean_speech_model', **shared)
        self.noisy_speech_model = NoisySpeechModel(
            n_classes=n_classes, name='noisy_speech_model', **shared)

        # Stored mode label ('full' / 'inference'); call() uses its own
        # ``mode`` argument and does not read this attribute.
        self.training_mode = 'full'

    @staticmethod
    def _pick_state(states, key):
        """Return ``states[key]`` when ``states`` is truthy, else None."""
        return states[key] if states else None

    def _apply_preprocessing(self, x, day_idx):
        """Apply the clean model's day transform + softsign (+ patching)."""
        clean = self.clean_speech_model

        stacked_w = tf.stack(clean.day_weights, axis=0)
        stacked_b = tf.stack(clean.day_biases, axis=0)
        per_example_w = tf.gather(stacked_w, day_idx)
        per_example_b = tf.expand_dims(tf.gather(stacked_b, day_idx), axis=1)

        transformed = clean.day_layer_activation(
            tf.linalg.matmul(x, per_example_w) + per_example_b)

        if self.patch_size > 0:
            transformed = clean._apply_patch_processing(transformed)
        return transformed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, training=None):
        """Run the clean model's GRU stack + output layer on prepared input.

        ``day_idx`` is accepted for signature symmetry but not used.
        """
        clean = self.clean_speech_model

        if states is not None:
            h1_init, h2_init, h3_init = states
        else:
            n_batch = tf.shape(x_processed)[0]
            h1_init, h2_init, h3_init = [
                tf.tile(h0, [n_batch, 1]) for h0 in (clean.h0_1, clean.h0_2, clean.h0_3)
            ]

        seq, _ = clean.gru1(x_processed, initial_state=h1_init, training=training)
        seq, _ = clean.gru2(seq, initial_state=h2_init, training=training)
        seq, _ = clean.gru3(seq, initial_state=h3_init, training=training)
        return clean.output_layer(seq)

    def _noisy_forward_with_processed_input(self, x_processed, states=None, training=None):
        """Run the noisy model's GRU stack + output layer on prepared input."""
        noisy = self.noisy_speech_model

        if states is not None:
            h1_init, h2_init = states
        else:
            n_batch = tf.shape(x_processed)[0]
            h1_init, h2_init = [tf.tile(h0, [n_batch, 1]) for h0 in (noisy.h0_1, noisy.h0_2)]

        seq, _ = noisy.gru1(x_processed, initial_state=h1_init, training=training)
        seq, _ = noisy.gru2(seq, initial_state=h2_init, training=training)
        return noisy.output_layer(seq)

    def _denoise_and_classify(self, x, day_idx, states, training):
        """Shared path: estimate noise, subtract it, classify the residual.

        Returns ``(clean_logits, noise_output, noise_hidden)``.
        """
        noise_output, noise_hidden = self.noise_model(
            x, day_idx,
            self._pick_state(states, 'noise'),
            training=training
        )
        x_processed = self._apply_preprocessing(x, day_idx)
        clean_logits = self._clean_forward_with_processed_input(
            x_processed - noise_output,  # residual: input minus estimated noise
            day_idx,
            self._pick_state(states, 'clean'),
            training=training
        )
        return clean_logits, noise_output, noise_hidden

    def call(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda=0.0, training=None):
        """
        Adversarial forward pass.

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim].
            day_idx: Day indices [batch_size].
            states: Dict with 'noise', 'clean', 'noisy' initial states, or None.
            return_state: If True, also return the noise model's final states
                (clean/noisy final states are not returned).
            mode: 'full' runs all three models; 'inference' skips the noisy one.
            grl_lambda: Gradient reversal strength; the GRL is applied only
                when ``grl_lambda > 0.0``.
            training: Training mode flag.

        Returns:
            'full': ``(clean_logits, noisy_logits, noise_output)``, wrapped as
            ``(that_tuple, noise_hidden)`` when ``return_state``.
            'inference': ``clean_logits``, or ``(clean_logits, noise_hidden)``.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode not in ('full', 'inference'):
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

        clean_logits, noise_output, noise_hidden = self._denoise_and_classify(
            x, day_idx, states, training)

        if mode == 'inference':
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        # 'full': the noisy-speech branch classifies the noise estimate.
        if grl_lambda > 0.0:
            noisy_input = gradient_reverse(noise_output, grl_lambda)
        else:
            noisy_input = noise_output

        noisy_logits = self._noisy_forward_with_processed_input(
            noisy_input,
            self._pick_state(states, 'noisy'),
            training=training
        )

        outputs = (clean_logits, noisy_logits, noise_output)
        if return_state:
            return outputs, noise_hidden
        return outputs

    def set_mode(self, mode):
        """Store ``mode`` in ``training_mode`` (not consulted by ``call``)."""
        self.training_mode = mode
|
|
|
|
|
|
# Custom CTC Loss for TensorFlow TPU
|
|
class CTCLoss(keras.losses.Loss):
    """
    CTC loss over dense, padded label tensors, written to be XLA/TPU friendly.

    ``tf.nn.ctc_loss`` is called with dense labels plus explicit
    ``label_length``. That is the form supported on TPU. The previous
    dense-to-sparse conversion had three problems:
      - It used ``tf.where``, whose data-dependent output shape XLA cannot
        compile.
      - It routed to the legacy CTC kernel, which only supports float/double.
      - It dropped every label equal to 0 instead of honouring ``label_lengths``.
    """

    def __init__(self, blank_index=0, reduction='none', **kwargs):
        """
        Args:
            blank_index: Class index of the CTC blank symbol.
            reduction: Keras loss reduction; 'none' keeps one loss per example.
            **kwargs: Forwarded to ``keras.losses.Loss`` (e.g. ``name``).
        """
        super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: Dict with 'labels' (dense, padded [batch, max_label_len];
                entries past each example's length are ignored),
                'input_lengths' ([batch], valid logit frames per example) and
                'label_lengths' ([batch], valid labels per example).
            y_pred: Logits tensor [batch_size, time_steps, num_classes].

        Returns:
            Per-example negative log-likelihood, shape [batch_size].
        """
        labels = tf.cast(y_true['labels'], tf.int32)
        input_lengths = tf.cast(y_true['input_lengths'], tf.int32)
        label_lengths = tf.cast(y_true['label_lengths'], tf.int32)

        # CTC is computed in float32. Under a mixed_bfloat16 policy the logits
        # arrive as bfloat16.
        log_probs = tf.nn.log_softmax(tf.cast(y_pred, tf.float32), axis=-1)

        # Time-major layout for CTC: [time_steps, batch_size, num_classes]
        log_probs = tf.transpose(log_probs, [1, 0, 2])

        return tf.nn.ctc_loss(
            labels=labels,
            logits=log_probs,
            label_length=label_lengths,
            logit_length=input_lengths,
            logits_time_major=True,
            blank_index=self.blank_index,
        )
|
|
|
|
|
|
# TPU Strategy Helper Functions
|
|
def create_tpu_strategy():
    """
    Build a ``tf.distribute`` strategy, preferring a TPU.

    Steps:
      1. Hide GPUs from TensorFlow (best effort; a failure is only reported).
      2. Choose a TPU target from the environment: ``COLAB_TPU_ADDR`` (an
         address), else ``TPU_NAME`` (a name), else -- if ``TPU_WORKER_ID`` is
         set -- a fixed Kaggle address. With none of these, the resolver is
         asked to auto-detect.
      3. Connect to the cluster, initialize the TPU system, create the strategy.

    Progress and diagnostics are printed to stdout.

    Returns:
        A ``tf.distribute.TPUStrategy`` on success. If any step of the TPU
        setup raises, the default strategy from ``tf.distribute.get_strategy()``.
    """
    import os

    env = os.environ
    print("🔍 Detecting TPU environment...")

    try:
        print("🚫 Disabling GPU to prevent CUDA conflicts...")
        tf.config.set_visible_devices([], 'GPU')
        print("✅ GPU disabled successfully")
    except Exception as gpu_err:
        print(f"⚠️ Warning: Could not disable GPU: {gpu_err}")

    tpu_address, tpu_name = None, None

    if 'COLAB_TPU_ADDR' in env:
        tpu_address = env['COLAB_TPU_ADDR']
        print(f"📍 Found Colab TPU address: {tpu_address}")
    elif 'TPU_NAME' in env:
        tpu_name = env['TPU_NAME']
        print(f"📍 Found TPU name: {tpu_name}")
    elif 'TPU_WORKER_ID' in env:
        # NOTE(review): the address below is a hard-coded guess for Kaggle.
        # On TPU VMs, TPUClusterResolver(tpu='local') may be what is needed --
        # confirm on the target runtime.
        worker_id = env.get('TPU_WORKER_ID', '0')
        tpu_address = 'grpc://10.0.0.2:8470'
        print(f"📍 Kaggle TPU detected, worker ID: {worker_id}, address: {tpu_address}")

    # Dump TPU/Colab-related variables to help debug detection failures.
    print("🔧 TPU environment variables:")
    for key, value in env.items():
        if 'TPU' in key or 'COLAB' in key:
            print(f" {key}={value}")

    try:
        if tpu_address:
            print(f"🚀 Attempting TPU connection with address: {tpu_address}")
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
        elif tpu_name:
            print(f"🚀 Attempting TPU connection with name: {tpu_name}")
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
        else:
            print("🚀 Attempting TPU auto-detection...")
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()

        print("⚡ Connecting to TPU cluster...")
        tf.config.experimental_connect_to_cluster(resolver)

        print("🔧 Initializing TPU system...")
        tf.tpu.experimental.initialize_tpu_system(resolver)

        print("🎯 Creating TPU strategy...")
        strategy = tf.distribute.TPUStrategy(resolver)

        print("✅ TPU initialized successfully!")
        print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
        print(f"🏃 TPU cluster: {resolver.cluster_spec()}")
        return strategy

    except Exception as err:
        print(f"❌ Failed to initialize TPU: {err}")
        print(f"🔍 Error type: {type(err).__name__}")

        if "Please provide a TPU Name" in str(err):
            print("💡 Hint: TPU name/address not found in environment variables")
            print(" Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")

        print("🔄 Falling back to default strategy (CPU/GPU)")
        return tf.distribute.get_strategy()
|
|
|
|
|
|
def build_model_for_tpu(config):
    """
    Construct a TripleGRUDecoder from a nested configuration dictionary.

    Args:
        config: Mapping with 'model' and 'dataset' sections. Keys read:
            model.n_input_features, model.n_units, model.rnn_dropout,
            model.input_network.input_layer_dropout, model.patch_size,
            model.patch_stride, dataset.sessions (its length is n_days) and
            dataset.n_classes.

    Returns:
        A new TripleGRUDecoder. It is not compiled here: no ``compile`` call is
        made, so the caller must attach an optimizer/loss itself. For TPU
        training this should presumably be called inside ``strategy.scope()``
        so the variables are created on the TPU -- confirm at the call site.
    """
    model_cfg = config['model']
    return TripleGRUDecoder(
        neural_dim=model_cfg['n_input_features'],
        n_units=model_cfg['n_units'],
        n_days=len(config['dataset']['sessions']),
        n_classes=config['dataset']['n_classes'],
        rnn_dropout=model_cfg['rnn_dropout'],
        input_dropout=model_cfg['input_network']['input_layer_dropout'],
        patch_size=model_cfg['patch_size'],
        patch_stride=model_cfg['patch_stride'],
    )
|
|
|
|
|
|
# Mixed Precision Configuration for TPU v5e-8
|
|
def configure_mixed_precision():
    """
    Install 'mixed_bfloat16' as the global Keras dtype policy.

    The policy applies to layers created after this call, so call it before
    building the model.

    Returns:
        The ``keras.mixed_precision.Policy`` that was installed.
    """
    bf16_policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(bf16_policy)
    print(f"Mixed precision policy set to: {bf16_policy.name}")
    return bf16_policy