TPU
This commit is contained in:
781
model_training_nnn_tpu/rnn_model_tf.py
Normal file
781
model_training_nnn_tpu/rnn_model_tf.py
Normal file
@@ -0,0 +1,781 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
import numpy as np
|
||||
|
||||
|
||||
@tf.custom_gradient
|
||||
def gradient_reverse(x, lambd=1.0):
|
||||
"""
|
||||
Gradient Reversal Layer (GRL) for TensorFlow
|
||||
Forward: identity
|
||||
Backward: multiply incoming gradient by -lambda
|
||||
"""
|
||||
def grad(dy):
|
||||
return -lambd * dy, None
|
||||
|
||||
return tf.identity(x), grad
|
||||
|
||||
|
||||
class NoiseModel(keras.Model):
|
||||
"""
|
||||
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(NoiseModel, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_days = n_days
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Day-specific input layers - use Variables for TPU compatibility
|
||||
self.day_layer_activation = layers.Activation('softsign')
|
||||
|
||||
# Initialize day-specific weights and biases as Variables
|
||||
self.day_weights = []
|
||||
self.day_biases = []
|
||||
for i in range(n_days):
|
||||
weight = self.add_weight(
|
||||
name=f'day_weight_{i}',
|
||||
shape=(neural_dim, neural_dim),
|
||||
initializer='identity',
|
||||
trainable=True
|
||||
)
|
||||
bias = self.add_weight(
|
||||
name=f'day_bias_{i}',
|
||||
shape=(neural_dim,),
|
||||
initializer='zeros',
|
||||
trainable=True
|
||||
)
|
||||
self.day_weights.append(weight)
|
||||
self.day_biases.append(bias)
|
||||
|
||||
self.day_layer_dropout = layers.Dropout(input_dropout)
|
||||
|
||||
# Calculate input size after patching
|
||||
self.input_size = self.neural_dim
|
||||
if self.patch_size > 0:
|
||||
self.input_size *= self.patch_size
|
||||
|
||||
# 2-layer GRU for noise estimation
|
||||
# Use separate GRU layers for better TPU performance
|
||||
self.gru1 = layers.GRU(
|
||||
units=self.input_size,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0, # Avoid recurrent dropout on TPU
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noise_gru1'
|
||||
)
|
||||
|
||||
self.gru2 = layers.GRU(
|
||||
units=self.input_size,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noise_gru2'
|
||||
)
|
||||
|
||||
# Learnable initial hidden states
|
||||
self.h0_1 = self.add_weight(
|
||||
name='h0_1',
|
||||
shape=(1, self.input_size),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_2 = self.add_weight(
|
||||
name='h0_2',
|
||||
shape=(1, self.input_size),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
|
||||
def call(self, x, day_idx, states=None, training=None):
|
||||
"""
|
||||
Forward pass optimized for TPU compilation
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, time_steps, neural_dim]
|
||||
day_idx: Day indices [batch_size]
|
||||
states: Optional initial states
|
||||
training: Training mode flag
|
||||
"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Stack all day weights and biases for efficient gathering
|
||||
all_day_weights = tf.stack(self.day_weights, axis=0) # [n_days, neural_dim, neural_dim]
|
||||
all_day_biases = tf.stack(self.day_biases, axis=0) # [n_days, neural_dim]
|
||||
|
||||
# Gather day-specific parameters
|
||||
day_weights = tf.gather(all_day_weights, day_idx) # [batch_size, neural_dim, neural_dim]
|
||||
day_biases = tf.gather(all_day_biases, day_idx) # [batch_size, neural_dim]
|
||||
|
||||
# Add time dimension to biases for broadcasting
|
||||
day_biases = tf.expand_dims(day_biases, axis=1) # [batch_size, 1, neural_dim]
|
||||
|
||||
# Apply day-specific transformation using efficient batch matrix multiplication
|
||||
x = tf.linalg.matmul(x, day_weights) + day_biases
|
||||
x = self.day_layer_activation(x)
|
||||
|
||||
# Apply input dropout
|
||||
if training and self.input_dropout > 0:
|
||||
x = self.day_layer_dropout(x, training=training)
|
||||
|
||||
# Apply patch processing if enabled
|
||||
if self.patch_size > 0:
|
||||
x = self._apply_patch_processing(x)
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.h0_1, [batch_size, 1]) # [batch_size, input_size]
|
||||
h2_init = tf.tile(self.h0_2, [batch_size, 1]) # [batch_size, input_size]
|
||||
states = [h1_init, h2_init]
|
||||
else:
|
||||
h1_init, h2_init = states
|
||||
|
||||
# Two-layer GRU forward pass
|
||||
output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
|
||||
output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
|
||||
|
||||
return output, [h1_final, h2_final]
|
||||
|
||||
def _apply_patch_processing(self, x):
|
||||
"""Apply patch processing using TensorFlow operations"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
time_steps = tf.shape(x)[1]
|
||||
|
||||
# Add channel dimension for conv1d operations
|
||||
x = tf.expand_dims(x, axis=2) # [batch_size, time_steps, 1, neural_dim]
|
||||
|
||||
# Extract patches using extract_patches
|
||||
# This is equivalent to PyTorch's unfold operation
|
||||
patch_x = tf.image.extract_patches(
|
||||
x,
|
||||
sizes=[1, self.patch_size, 1, 1],
|
||||
strides=[1, self.patch_stride, 1, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='VALID'
|
||||
)
|
||||
|
||||
# Reshape to match expected output
|
||||
new_time_steps = tf.shape(patch_x)[1]
|
||||
patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])
|
||||
|
||||
return patch_x
|
||||
|
||||
|
||||
class CleanSpeechModel(keras.Model):
|
||||
"""
|
||||
Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(CleanSpeechModel, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_days = n_days
|
||||
self.n_classes = n_classes
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Day-specific input layers
|
||||
self.day_layer_activation = layers.Activation('softsign')
|
||||
|
||||
# Initialize day-specific weights and biases
|
||||
self.day_weights = []
|
||||
self.day_biases = []
|
||||
for i in range(n_days):
|
||||
weight = self.add_weight(
|
||||
name=f'day_weight_{i}',
|
||||
shape=(neural_dim, neural_dim),
|
||||
initializer='identity',
|
||||
trainable=True
|
||||
)
|
||||
bias = self.add_weight(
|
||||
name=f'day_bias_{i}',
|
||||
shape=(neural_dim,),
|
||||
initializer='zeros',
|
||||
trainable=True
|
||||
)
|
||||
self.day_weights.append(weight)
|
||||
self.day_biases.append(bias)
|
||||
|
||||
self.day_layer_dropout = layers.Dropout(input_dropout)
|
||||
|
||||
# Calculate input size after patching
|
||||
self.input_size = self.neural_dim
|
||||
if self.patch_size > 0:
|
||||
self.input_size *= self.patch_size
|
||||
|
||||
# 3-layer GRU for clean speech recognition
|
||||
self.gru1 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='clean_gru1'
|
||||
)
|
||||
|
||||
self.gru2 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='clean_gru2'
|
||||
)
|
||||
|
||||
self.gru3 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='clean_gru3'
|
||||
)
|
||||
|
||||
# Output classification layer
|
||||
self.output_layer = layers.Dense(
|
||||
n_classes,
|
||||
kernel_initializer='glorot_uniform',
|
||||
name='clean_output'
|
||||
)
|
||||
|
||||
# Learnable initial hidden states
|
||||
self.h0_1 = self.add_weight(
|
||||
name='h0_1',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_2 = self.add_weight(
|
||||
name='h0_2',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_3 = self.add_weight(
|
||||
name='h0_3',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
|
||||
def call(self, x, day_idx, states=None, return_state=False, training=None):
|
||||
"""Forward pass optimized for TPU compilation"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Stack all day weights and biases for efficient gathering
|
||||
all_day_weights = tf.stack(self.day_weights, axis=0)
|
||||
all_day_biases = tf.stack(self.day_biases, axis=0)
|
||||
|
||||
# Gather day-specific parameters
|
||||
day_weights = tf.gather(all_day_weights, day_idx)
|
||||
day_biases = tf.gather(all_day_biases, day_idx)
|
||||
day_biases = tf.expand_dims(day_biases, axis=1)
|
||||
|
||||
# Apply day-specific transformation
|
||||
x = tf.linalg.matmul(x, day_weights) + day_biases
|
||||
x = self.day_layer_activation(x)
|
||||
|
||||
if training and self.input_dropout > 0:
|
||||
x = self.day_layer_dropout(x, training=training)
|
||||
|
||||
# Apply patch processing if enabled
|
||||
if self.patch_size > 0:
|
||||
x = self._apply_patch_processing(x)
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.h0_2, [batch_size, 1])
|
||||
h3_init = tf.tile(self.h0_3, [batch_size, 1])
|
||||
states = [h1_init, h2_init, h3_init]
|
||||
else:
|
||||
h1_init, h2_init, h3_init = states
|
||||
|
||||
# Three-layer GRU forward pass
|
||||
output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
|
||||
output2, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
|
||||
output, h3_final = self.gru3(output2, initial_state=h3_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.output_layer(output)
|
||||
|
||||
if return_state:
|
||||
return logits, [h1_final, h2_final, h3_final]
|
||||
return logits
|
||||
|
||||
def _apply_patch_processing(self, x):
|
||||
"""Apply patch processing using TensorFlow operations"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Add channel dimension
|
||||
x = tf.expand_dims(x, axis=2)
|
||||
|
||||
# Extract patches
|
||||
patch_x = tf.image.extract_patches(
|
||||
x,
|
||||
sizes=[1, self.patch_size, 1, 1],
|
||||
strides=[1, self.patch_stride, 1, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='VALID'
|
||||
)
|
||||
|
||||
# Reshape
|
||||
new_time_steps = tf.shape(patch_x)[1]
|
||||
patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])
|
||||
|
||||
return patch_x
|
||||
|
||||
|
||||
class NoisySpeechModel(keras.Model):
|
||||
"""
|
||||
Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(NoisySpeechModel, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_days = n_days
|
||||
self.n_classes = n_classes
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Calculate input size after patching
|
||||
self.input_size = self.neural_dim
|
||||
if self.patch_size > 0:
|
||||
self.input_size *= self.patch_size
|
||||
|
||||
# 2-layer GRU for noisy speech recognition
|
||||
self.gru1 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noisy_gru1'
|
||||
)
|
||||
|
||||
self.gru2 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noisy_gru2'
|
||||
)
|
||||
|
||||
# Output classification layer
|
||||
self.output_layer = layers.Dense(
|
||||
n_classes,
|
||||
kernel_initializer='glorot_uniform',
|
||||
name='noisy_output'
|
||||
)
|
||||
|
||||
# Learnable initial hidden states
|
||||
self.h0_1 = self.add_weight(
|
||||
name='h0_1',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_2 = self.add_weight(
|
||||
name='h0_2',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
|
||||
def call(self, x, states=None, return_state=False, training=None):
|
||||
"""Forward pass - no day-specific layers for noise processing"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.h0_2, [batch_size, 1])
|
||||
states = [h1_init, h2_init]
|
||||
else:
|
||||
h1_init, h2_init = states
|
||||
|
||||
# Two-layer GRU forward pass
|
||||
output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
|
||||
output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.output_layer(output)
|
||||
|
||||
if return_state:
|
||||
return logits, [h1_final, h2_final]
|
||||
return logits
|
||||
|
||||
|
||||
class TripleGRUDecoder(keras.Model):
|
||||
"""
|
||||
Three-model adversarial architecture for neural speech decoding
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
|
||||
Combines:
|
||||
- NoiseModel: estimates noise in neural data
|
||||
- CleanSpeechModel: processes denoised signal for recognition
|
||||
- NoisySpeechModel: processes noise signal for recognition
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(TripleGRUDecoder, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_classes = n_classes
|
||||
self.n_days = n_days
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Create the three models
|
||||
self.noise_model = NoiseModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride,
|
||||
name='noise_model'
|
||||
)
|
||||
|
||||
self.clean_speech_model = CleanSpeechModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride,
|
||||
name='clean_speech_model'
|
||||
)
|
||||
|
||||
self.noisy_speech_model = NoisySpeechModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride,
|
||||
name='noisy_speech_model'
|
||||
)
|
||||
|
||||
# Training mode flag
|
||||
self.training_mode = 'full' # 'full', 'inference'
|
||||
|
||||
def _apply_preprocessing(self, x, day_idx):
|
||||
"""Apply preprocessing using clean speech model's day layers"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Stack all day weights and biases
|
||||
all_day_weights = tf.stack(self.clean_speech_model.day_weights, axis=0)
|
||||
all_day_biases = tf.stack(self.clean_speech_model.day_biases, axis=0)
|
||||
|
||||
# Gather day-specific parameters
|
||||
day_weights = tf.gather(all_day_weights, day_idx)
|
||||
day_biases = tf.gather(all_day_biases, day_idx)
|
||||
day_biases = tf.expand_dims(day_biases, axis=1)
|
||||
|
||||
# Apply transformation
|
||||
x_processed = tf.linalg.matmul(x, day_weights) + day_biases
|
||||
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
|
||||
|
||||
# Apply patch processing if enabled
|
||||
if self.patch_size > 0:
|
||||
x_processed = self.clean_speech_model._apply_patch_processing(x_processed)
|
||||
|
||||
return x_processed
|
||||
|
||||
def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, training=None):
|
||||
"""Forward pass for CleanSpeechModel with already processed input"""
|
||||
batch_size = tf.shape(x_processed)[0]
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.clean_speech_model.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.clean_speech_model.h0_2, [batch_size, 1])
|
||||
h3_init = tf.tile(self.clean_speech_model.h0_3, [batch_size, 1])
|
||||
states = [h1_init, h2_init, h3_init]
|
||||
else:
|
||||
h1_init, h2_init, h3_init = states
|
||||
|
||||
# GRU forward pass (skip preprocessing since input is already processed)
|
||||
output1, h1_final = self.clean_speech_model.gru1(x_processed, initial_state=h1_init, training=training)
|
||||
output2, h2_final = self.clean_speech_model.gru2(output1, initial_state=h2_init, training=training)
|
||||
output, h3_final = self.clean_speech_model.gru3(output2, initial_state=h3_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.clean_speech_model.output_layer(output)
|
||||
return logits
|
||||
|
||||
def _noisy_forward_with_processed_input(self, x_processed, states=None, training=None):
|
||||
"""Forward pass for NoisySpeechModel with already processed input"""
|
||||
batch_size = tf.shape(x_processed)[0]
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.noisy_speech_model.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.noisy_speech_model.h0_2, [batch_size, 1])
|
||||
states = [h1_init, h2_init]
|
||||
else:
|
||||
h1_init, h2_init = states
|
||||
|
||||
# GRU forward pass
|
||||
output1, h1_final = self.noisy_speech_model.gru1(x_processed, initial_state=h1_init, training=training)
|
||||
output, h2_final = self.noisy_speech_model.gru2(output1, initial_state=h2_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.noisy_speech_model.output_layer(output)
|
||||
return logits
|
||||
|
||||
def call(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda=0.0, training=None):
|
||||
"""
|
||||
Three-model adversarial forward pass optimized for TPU compilation
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, time_steps, neural_dim]
|
||||
day_idx: Day indices [batch_size]
|
||||
states: Dictionary with 'noise', 'clean', 'noisy' states or None
|
||||
return_state: Whether to return hidden states
|
||||
mode: 'full' for training (all three models), 'inference' for inference
|
||||
grl_lambda: Gradient reversal strength for adversarial training
|
||||
training: Training mode flag
|
||||
"""
|
||||
|
||||
if mode == 'full':
|
||||
# Training mode: run all three models
|
||||
|
||||
# 1. Noise model estimates noise in the data
|
||||
noise_output, noise_hidden = self.noise_model(
|
||||
x, day_idx,
|
||||
states['noise'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# 2. Apply preprocessing to get x in the same space as noise_output
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
|
||||
# 3. Clean speech model processes denoised signal
|
||||
denoised_input = x_processed - noise_output # Residual connection
|
||||
clean_logits = self._clean_forward_with_processed_input(
|
||||
denoised_input, day_idx,
|
||||
states['clean'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# 4. Noisy speech model processes noise signal
|
||||
# Apply Gradient Reversal Layer if specified
|
||||
if grl_lambda > 0.0:
|
||||
noisy_input = gradient_reverse(noise_output, grl_lambda)
|
||||
else:
|
||||
noisy_input = noise_output
|
||||
|
||||
noisy_logits = self._noisy_forward_with_processed_input(
|
||||
noisy_input,
|
||||
states['noisy'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# Return results
|
||||
if return_state:
|
||||
return (clean_logits, noisy_logits, noise_output), noise_hidden
|
||||
return clean_logits, noisy_logits, noise_output
|
||||
|
||||
elif mode == 'inference':
|
||||
# Inference mode: only noise model + clean speech model
|
||||
|
||||
# 1. Estimate noise
|
||||
noise_output, noise_hidden = self.noise_model(
|
||||
x, day_idx,
|
||||
states['noise'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# 2. Apply preprocessing for residual connection
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
denoised_input = x_processed - noise_output
|
||||
clean_logits = self._clean_forward_with_processed_input(
|
||||
denoised_input, day_idx,
|
||||
states['clean'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# Return results
|
||||
if return_state:
|
||||
return clean_logits, noise_hidden
|
||||
return clean_logits
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the operating mode"""
|
||||
self.training_mode = mode
|
||||
|
||||
|
||||
# Custom CTC Loss for TensorFlow TPU
|
||||
class CTCLoss(keras.losses.Loss):
|
||||
"""
|
||||
Custom CTC Loss optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self, blank_index=0, reduction='none', **kwargs):
|
||||
super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
|
||||
self.blank_index = blank_index
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
"""
|
||||
Args:
|
||||
y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths'
|
||||
y_pred: Logits tensor [batch_size, time_steps, num_classes]
|
||||
"""
|
||||
labels = y_true['labels']
|
||||
input_lengths = y_true['input_lengths']
|
||||
label_lengths = y_true['label_lengths']
|
||||
|
||||
# Convert logits to log probabilities
|
||||
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
|
||||
|
||||
# Transpose for CTC: [time_steps, batch_size, num_classes]
|
||||
log_probs = tf.transpose(log_probs, [1, 0, 2])
|
||||
|
||||
# Compute CTC loss
|
||||
loss = tf.nn.ctc_loss(
|
||||
labels=labels,
|
||||
logits=log_probs,
|
||||
label_length=label_lengths,
|
||||
logit_length=input_lengths,
|
||||
blank_index=self.blank_index,
|
||||
logits_time_major=True
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# TPU Strategy Helper Functions
|
||||
def create_tpu_strategy():
|
||||
"""Create TPU strategy for distributed training on TPU v5e-8"""
|
||||
try:
|
||||
# Initialize TPU
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
|
||||
# Create TPU strategy
|
||||
strategy = tf.distribute.TPUStrategy(resolver)
|
||||
print(f"TPU initialized successfully. Number of replicas: {strategy.num_replicas_in_sync}")
|
||||
return strategy
|
||||
|
||||
except ValueError as e:
|
||||
print(f"Failed to initialize TPU: {e}")
|
||||
print("Falling back to default strategy")
|
||||
return tf.distribute.get_strategy()
|
||||
|
||||
|
||||
def build_model_for_tpu(config):
|
||||
"""
|
||||
Build TripleGRUDecoder model optimized for TPU v5e-8
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing model parameters
|
||||
|
||||
Returns:
|
||||
Compiled Keras model ready for TPU training
|
||||
"""
|
||||
model = TripleGRUDecoder(
|
||||
neural_dim=config['model']['n_input_features'],
|
||||
n_units=config['model']['n_units'],
|
||||
n_days=len(config['dataset']['sessions']),
|
||||
n_classes=config['dataset']['n_classes'],
|
||||
rnn_dropout=config['model']['rnn_dropout'],
|
||||
input_dropout=config['model']['input_network']['input_layer_dropout'],
|
||||
patch_size=config['model']['patch_size'],
|
||||
patch_stride=config['model']['patch_stride']
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# Mixed Precision Configuration for TPU v5e-8
|
||||
def configure_mixed_precision():
|
||||
"""Configure mixed precision for optimal TPU v5e-8 performance"""
|
||||
policy = keras.mixed_precision.Policy('mixed_bfloat16')
|
||||
keras.mixed_precision.set_global_policy(policy)
|
||||
print(f"Mixed precision policy set to: {policy.name}")
|
||||
return policy
|
||||
Reference in New Issue
Block a user