Files
b2txt25/model_training_nnn_tpu/rnn_model_tf.py
2025-10-15 23:37:24 +08:00

871 lines
30 KiB
Python

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
def gradient_reverse(x, lambd=1.0):
    """
    Gradient Reversal Layer (GRL) for TensorFlow.

    Forward: identity.
    Backward: multiply the incoming gradient by ``-lambd``.

    Args:
        x: Tensor passed through unchanged in the forward pass.
        lambd: Reversal strength (Python number or scalar tensor). It is
            treated as a constant: no gradient flows into it.

    Returns:
        A tensor equal to ``x`` whose gradient is reversed and scaled.
    """
    # Only ``x`` is an input of the custom-gradient function; ``lambd`` is
    # captured by closure. tf.custom_gradient requires ``grad`` to return
    # exactly one gradient per positional input, so routing ``lambd`` through
    # the decorated signature while returning a single gradient made the
    # backward pass fail.
    @tf.custom_gradient
    def _reverse(x_in):
        def grad(dy):
            # Cast so a float32 (or Python float) lambda can scale bfloat16
            # gradients under a mixed-precision policy.
            return -tf.cast(lambd, dy.dtype) * dy

        return tf.identity(x_in), grad

    return _reverse(x)
class NoiseModel(keras.Model):
    """
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data.

    TensorFlow/Keras implementation intended for TPU v5e-8.

    Pipeline: per-day affine layer -> softsign -> (optional) input dropout ->
    (optional) temporal patching -> two stacked GRUs.

    Both GRUs are ``input_size`` units wide (``neural_dim * patch_size`` when
    patching, else ``neural_dim``), not ``n_units``. The output therefore has
    the same width as the preprocessed input, so it can be subtracted from it.

    Args:
        neural_dim: Number of input features per time step.
        n_units: Stored on the instance; not used to size the GRUs here.
        n_days: Number of days; one ``(neural_dim, neural_dim)`` weight and one
            ``(neural_dim,)`` bias are created per day.
        rnn_dropout: Input dropout rate passed to each GRU.
        input_dropout: Dropout rate applied after the day-specific layer.
        patch_size: Time steps concatenated per patch (0 disables patching).
        patch_stride: Step between patch starts; must be > 0 when patching.

    Raises:
        ValueError: If ``patch_size > 0`` and ``patch_stride <= 0``.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(NoiseModel, self).__init__(**kwargs)
        # Fail early: a non-positive stride would only error later inside
        # tf.image.extract_patches.
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be positive when patch_size > 0 "
                f"(got patch_size={patch_size}, patch_stride={patch_stride})")
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers - use Variables for TPU compatibility
        self.day_layer_activation = layers.Activation('softsign')
        # Per-day weights start as identity and biases as zero, so every day
        # initially applies the same (identity) affine map.
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)
        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Feature width seen by the GRUs after (optional) patching.
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation; separate layers so each can take
        # its own learnable initial state.
        self.gru1 = layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,  # Avoid recurrent dropout on TPU
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noise_gru1'
        )
        self.gru2 = layers.GRU(
            units=self.input_size,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='noise_gru2'
        )

        # Learnable initial hidden states, shape (1, input_size); tiled over
        # the batch in call().
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, self.input_size),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, day_idx, states=None, training=None):
        """
        Forward pass.

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim].
            day_idx: Integer day index per example [batch_size].
            states: Optional ``[h1, h2]`` initial GRU states; when None the
                learnable ``h0_1``/``h0_2`` are tiled over the batch.
            training: Training mode flag (controls dropout).

        Returns:
            ``(output, [h1_final, h2_final])`` where ``output`` is
            [batch_size, time_steps', input_size].
        """
        batch_size = tf.shape(x)[0]

        # Stack all day parameters and gather one set per example.
        all_day_weights = tf.stack(self.day_weights, axis=0)  # [n_days, D, D]
        all_day_biases = tf.stack(self.day_biases, axis=0)    # [n_days, D]
        day_weights = tf.gather(all_day_weights, day_idx)     # [B, D, D]
        day_biases = tf.gather(all_day_biases, day_idx)       # [B, D]
        day_biases = tf.expand_dims(day_biases, axis=1)       # [B, 1, D]

        # Batched per-example affine transform, then softsign.
        x = tf.linalg.matmul(x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])  # [B, input_size]
            h2_init = tf.tile(self.h0_2, [batch_size, 1])  # [B, input_size]
        else:
            h1_init, h2_init = states

        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
        return output, [h1_final, h2_final]

    def _apply_patch_processing(self, x):
        """
        Concatenate windows of ``patch_size`` time steps taken every
        ``patch_stride`` steps (VALID padding, so partial windows are dropped).

        Args:
            x: Tensor [batch_size, time_steps, neural_dim].

        Returns:
            Tensor [batch_size, new_time_steps, patch_size * neural_dim].
        """
        batch_size = tf.shape(x)[0]
        # tf.image.extract_patches works on 4-D "images":
        # [batch_size, time_steps, 1, neural_dim]
        x = tf.expand_dims(x, axis=2)
        # Intended as the equivalent of PyTorch's unfold over the time axis.
        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        new_time_steps = tf.shape(patch_x)[1]
        # Use the known feature width instead of -1 so the last dimension
        # stays statically defined when traced in graph mode.
        return tf.reshape(patch_x, [batch_size, new_time_steps, self.input_size])
class CleanSpeechModel(keras.Model):
    """
    Clean Speech Model: 3-layer GRU that processes the denoised signal for
    speech recognition.

    TensorFlow/Keras implementation intended for TPU v5e-8.

    Pipeline: per-day affine layer -> softsign -> (optional) input dropout ->
    (optional) temporal patching -> three stacked GRUs (``n_units`` wide) ->
    Dense classifier producing ``n_classes`` logits per time step.

    Args:
        neural_dim: Number of input features per time step.
        n_units: Width of each GRU layer.
        n_days: Number of days; one ``(neural_dim, neural_dim)`` weight and one
            ``(neural_dim,)`` bias are created per day.
        n_classes: Number of output classes.
        rnn_dropout: Input dropout rate passed to each GRU.
        input_dropout: Dropout rate applied after the day-specific layer.
        patch_size: Time steps concatenated per patch (0 disables patching).
        patch_stride: Step between patch starts; must be > 0 when patching.

    Raises:
        ValueError: If ``patch_size > 0`` and ``patch_stride <= 0``.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(CleanSpeechModel, self).__init__(**kwargs)
        # Fail early: a non-positive stride would only error later inside
        # tf.image.extract_patches.
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be positive when patch_size > 0 "
                f"(got patch_size={patch_size}, patch_stride={patch_stride})")
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = layers.Activation('softsign')
        # Per-day weights start as identity and biases as zero.
        self.day_weights = []
        self.day_biases = []
        for i in range(n_days):
            weight = self.add_weight(
                name=f'day_weight_{i}',
                shape=(neural_dim, neural_dim),
                initializer=tf.keras.initializers.Identity(),
                trainable=True
            )
            bias = self.add_weight(
                name=f'day_bias_{i}',
                shape=(neural_dim,),
                initializer=tf.keras.initializers.Zeros(),
                trainable=True
            )
            self.day_weights.append(weight)
            self.day_biases.append(bias)
        self.day_layer_dropout = layers.Dropout(input_dropout)

        # Feature width after (optional) patching.
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru1 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru1'
        )
        self.gru2 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru2'
        )
        self.gru3 = layers.GRU(
            units=n_units,
            return_sequences=True,
            return_state=True,
            dropout=self.rnn_dropout,
            recurrent_dropout=0.0,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            recurrent_initializer=tf.keras.initializers.Orthogonal(),
            name='clean_gru3'
        )

        # Output classification layer (raw logits, no activation)
        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='clean_output'
        )

        # Learnable initial hidden states, shape (1, n_units); tiled over the
        # batch in call().
        self.h0_1 = self.add_weight(
            name='h0_1',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_2 = self.add_weight(
            name='h0_2',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )
        self.h0_3 = self.add_weight(
            name='h0_3',
            shape=(1, n_units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, x, day_idx, states=None, return_state=False, training=None):
        """
        Forward pass.

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim].
            day_idx: Integer day index per example [batch_size].
            states: Optional ``[h1, h2, h3]`` initial GRU states; when None the
                learnable ``h0_*`` are tiled over the batch.
            return_state: If True also return the final GRU states.
            training: Training mode flag (controls dropout).

        Returns:
            ``logits`` [batch_size, time_steps', n_classes], or
            ``(logits, [h1_final, h2_final, h3_final])`` if ``return_state``.
        """
        batch_size = tf.shape(x)[0]

        # Stack all day parameters and gather one set per example.
        all_day_weights = tf.stack(self.day_weights, axis=0)
        all_day_biases = tf.stack(self.day_biases, axis=0)
        day_weights = tf.gather(all_day_weights, day_idx)
        day_biases = tf.gather(all_day_biases, day_idx)
        day_biases = tf.expand_dims(day_biases, axis=1)

        # Batched per-example affine transform, then softsign.
        x = tf.linalg.matmul(x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        if training and self.input_dropout > 0:
            x = self.day_layer_dropout(x, training=training)

        if self.patch_size > 0:
            x = self._apply_patch_processing(x)

        if states is None:
            h1_init = tf.tile(self.h0_1, [batch_size, 1])
            h2_init = tf.tile(self.h0_2, [batch_size, 1])
            h3_init = tf.tile(self.h0_3, [batch_size, 1])
        else:
            h1_init, h2_init, h3_init = states

        output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
        output2, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
        output, h3_final = self.gru3(output2, initial_state=h3_init, training=training)

        logits = self.output_layer(output)

        if return_state:
            return logits, [h1_final, h2_final, h3_final]
        return logits

    def _apply_patch_processing(self, x):
        """
        Concatenate windows of ``patch_size`` time steps taken every
        ``patch_stride`` steps (VALID padding, so partial windows are dropped).

        Args:
            x: Tensor [batch_size, time_steps, neural_dim].

        Returns:
            Tensor [batch_size, new_time_steps, patch_size * neural_dim].
        """
        batch_size = tf.shape(x)[0]
        # tf.image.extract_patches works on 4-D "images":
        # [batch_size, time_steps, 1, neural_dim]
        x = tf.expand_dims(x, axis=2)
        patch_x = tf.image.extract_patches(
            x,
            sizes=[1, self.patch_size, 1, 1],
            strides=[1, self.patch_stride, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        new_time_steps = tf.shape(patch_x)[1]
        # Use the known feature width instead of -1 so the last dimension
        # stays statically defined when traced in graph mode.
        return tf.reshape(patch_x, [batch_size, new_time_steps, self.input_size])
class NoisySpeechModel(keras.Model):
    """
    Noisy Speech Model: 2-layer GRU that decodes the noise signal.

    TensorFlow/Keras implementation intended for TPU v5e-8. Unlike the other
    sub-models it has no day-specific input layer and no patching step: its
    input is consumed directly by two stacked ``n_units``-wide GRUs followed
    by a Dense classifier.

    ``neural_dim``, ``n_days``, ``input_dropout``, ``patch_size`` and
    ``patch_stride`` are stored on the instance (and ``input_size`` derived),
    but no layer in this class is sized from them.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(NoisySpeechModel, self).__init__(**kwargs)
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Width the input would have after patching (kept for parity).
        self.input_size = self.neural_dim * self.patch_size if self.patch_size > 0 else self.neural_dim

        def make_gru(layer_name):
            # Every GRU here shares the same configuration except its name.
            return layers.GRU(
                units=n_units,
                return_sequences=True,
                return_state=True,
                dropout=self.rnn_dropout,
                recurrent_dropout=0.0,
                kernel_initializer=tf.keras.initializers.GlorotUniform(),
                recurrent_initializer=tf.keras.initializers.Orthogonal(),
                name=layer_name,
            )

        def make_initial_state(weight_name):
            # Learnable (1, n_units) start state, tiled over the batch in call().
            return self.add_weight(
                name=weight_name,
                shape=(1, n_units),
                initializer=tf.keras.initializers.GlorotUniform(),
                trainable=True,
            )

        self.gru1 = make_gru('noisy_gru1')
        self.gru2 = make_gru('noisy_gru2')

        self.output_layer = layers.Dense(
            n_classes,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='noisy_output',
        )

        self.h0_1 = make_initial_state('h0_1')
        self.h0_2 = make_initial_state('h0_2')

    def call(self, x, states=None, return_state=False, training=None):
        """
        Run both GRUs and the classifier (no day-specific layers).

        Args:
            x: Input sequence tensor (batch first).
            states: Optional ``[h1, h2]`` initial states; defaults to the
                learnable ``h0_1``/``h0_2`` tiled over the batch.
            return_state: If True also return the final GRU states.
            training: Training mode flag (controls GRU dropout).

        Returns:
            ``logits``, or ``(logits, [h1_final, h2_final])`` when
            ``return_state`` is True.
        """
        if states is not None:
            start_1, start_2 = states
        else:
            n_batch = tf.shape(x)[0]
            start_1 = tf.tile(self.h0_1, [n_batch, 1])
            start_2 = tf.tile(self.h0_2, [n_batch, 1])

        seq, last_1 = self.gru1(x, initial_state=start_1, training=training)
        seq, last_2 = self.gru2(seq, initial_state=start_2, training=training)
        logits = self.output_layer(seq)

        return (logits, [last_1, last_2]) if return_state else logits
class TripleGRUDecoder(keras.Model):
    """
    Three-model adversarial architecture for neural speech decoding
    (TensorFlow/Keras, intended for TPU v5e-8).

    Sub-models:
      - ``noise_model`` (NoiseModel): estimates a noise component of the input.
      - ``clean_speech_model`` (CleanSpeechModel): decodes the preprocessed
        input minus that noise estimate.
      - ``noisy_speech_model`` (NoisySpeechModel): decodes the noise estimate
        itself, optionally through a gradient reversal layer.

    ``call(mode='full')`` runs all three; ``call(mode='inference')`` runs only
    the noise and clean paths.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 **kwargs):
        super(TripleGRUDecoder, self).__init__(**kwargs)
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Hyper-parameters common to all three sub-models.
        common = dict(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride,
        )
        self.noise_model = NoiseModel(name='noise_model', **common)
        self.clean_speech_model = CleanSpeechModel(
            n_classes=n_classes, name='clean_speech_model', **common)
        self.noisy_speech_model = NoisySpeechModel(
            n_classes=n_classes, name='noisy_speech_model', **common)

        # Set by set_mode(); call() does not read this attribute and instead
        # branches on its own ``mode`` argument.
        self.training_mode = 'full'  # 'full', 'inference'

    def _apply_preprocessing(self, x, day_idx):
        """
        Apply the clean model's day-specific affine map, softsign and
        (if enabled) patching to ``x``. Input dropout is not applied here.
        """
        speech = self.clean_speech_model
        per_day_w = tf.stack(speech.day_weights, axis=0)
        per_day_b = tf.stack(speech.day_biases, axis=0)

        w = tf.gather(per_day_w, day_idx)
        b = tf.expand_dims(tf.gather(per_day_b, day_idx), axis=1)

        projected = speech.day_layer_activation(tf.linalg.matmul(x, w) + b)
        if self.patch_size > 0:
            projected = speech._apply_patch_processing(projected)
        return projected

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, training=None):
        """
        Run the clean model's three GRUs and classifier on input that has
        already been preprocessed. ``day_idx`` is accepted but unused, and the
        final GRU states are not returned.
        """
        speech = self.clean_speech_model
        if states is None:
            n_batch = tf.shape(x_processed)[0]
            states = [tf.tile(h0, [n_batch, 1])
                      for h0 in (speech.h0_1, speech.h0_2, speech.h0_3)]
        h1_init, h2_init, h3_init = states

        seq = x_processed
        for gru, start in zip((speech.gru1, speech.gru2, speech.gru3),
                              (h1_init, h2_init, h3_init)):
            seq, _ = gru(seq, initial_state=start, training=training)
        return speech.output_layer(seq)

    def _noisy_forward_with_processed_input(self, x_processed, states=None, training=None):
        """
        Run the noisy model's two GRUs and classifier on ``x_processed``.
        The final GRU states are not returned.
        """
        speech = self.noisy_speech_model
        if states is None:
            n_batch = tf.shape(x_processed)[0]
            states = [tf.tile(h0, [n_batch, 1])
                      for h0 in (speech.h0_1, speech.h0_2)]
        h1_init, h2_init = states

        seq = x_processed
        for gru, start in zip((speech.gru1, speech.gru2), (h1_init, h2_init)):
            seq, _ = gru(seq, initial_state=start, training=training)
        return speech.output_layer(seq)

    def call(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda=0.0, training=None):
        """
        Three-model adversarial forward pass.

        Args:
            x: Input tensor [batch_size, time_steps, neural_dim].
            day_idx: Day indices [batch_size].
            states: Dict with 'noise', 'clean', 'noisy' states, or None.
            return_state: If True also return the noise model's final states.
            mode: 'full' (all three models) or 'inference' (noise + clean).
            grl_lambda: Gradient reversal strength; the reversal is applied
                to the noisy branch only when ``grl_lambda > 0.0``.
            training: Training mode flag.

        Returns:
            'full': ``(clean_logits, noisy_logits, noise_output)``, or
                ``((clean_logits, noisy_logits, noise_output), noise_hidden)``.
            'inference': ``clean_logits``, or ``(clean_logits, noise_hidden)``.

        Raises:
            ValueError: If ``mode`` is neither 'full' nor 'inference'.
        """
        if mode != 'full' and mode != 'inference':
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

        def state_for(key):
            return states[key] if states else None

        # Shared path: estimate noise, then decode (preprocessed x - noise).
        noise_output, noise_hidden = self.noise_model(
            x, day_idx, state_for('noise'), training=training)
        denoised_input = self._apply_preprocessing(x, day_idx) - noise_output
        clean_logits = self._clean_forward_with_processed_input(
            denoised_input, day_idx, state_for('clean'), training=training)

        if mode == 'inference':
            return (clean_logits, noise_hidden) if return_state else clean_logits

        # Full mode: the noisy model decodes the noise estimate, optionally
        # through the gradient reversal layer.
        noisy_input = (gradient_reverse(noise_output, grl_lambda)
                       if grl_lambda > 0.0 else noise_output)
        noisy_logits = self._noisy_forward_with_processed_input(
            noisy_input, state_for('noisy'), training=training)

        outputs = (clean_logits, noisy_logits, noise_output)
        return (outputs, noise_hidden) if return_state else outputs

    def set_mode(self, mode):
        """Store ``mode`` on ``self.training_mode``."""
        self.training_mode = mode
# Custom CTC Loss for TensorFlow TPU
class CTCLoss(keras.losses.Loss):
    """
    CTC loss intended for TPU v5e-8.

    ``y_true`` is a dict with:
      - ``'labels'``: padded label ids [batch_size, max_label_len]; entries
        beyond each row's ``label_lengths`` value are ignored.
      - ``'input_lengths'``: valid logit frames per example [batch_size].
      - ``'label_lengths'``: valid labels per example [batch_size].
    ``y_pred`` is a logits tensor [batch_size, time_steps, num_classes].

    Labels are passed to ``tf.nn.ctc_loss`` in dense padded form. That is the
    only form supported on TPU; ``SparseTensor`` labels are CPU/GPU only. It
    also avoids the data-dependent shapes of a dense->sparse conversion.
    """

    def __init__(self, blank_index=0, reduction='none', **kwargs):
        super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths'.
            y_pred: Logits tensor [batch_size, time_steps, num_classes].

        Returns:
            Per-example CTC loss [batch_size] in float32.
        """
        labels = tf.cast(y_true['labels'], tf.int32)
        input_lengths = tf.cast(y_true['input_lengths'], tf.int32)
        label_lengths = tf.cast(y_true['label_lengths'], tf.int32)

        # Compute in float32 even under a mixed_bfloat16 policy: the CTC
        # forward-backward recursion is numerically fragile in low precision.
        log_probs = tf.nn.log_softmax(tf.cast(y_pred, tf.float32), axis=-1)

        # Time-major for CTC: [time_steps, batch_size, num_classes]
        log_probs = tf.transpose(log_probs, [1, 0, 2])

        return tf.nn.ctc_loss(
            labels=labels,
            logits=log_probs,
            label_length=label_lengths,
            logit_length=input_lengths,
            logits_time_major=True,
            blank_index=self.blank_index,
        )
# TPU Strategy Helper Functions
def create_tpu_strategy():
    """
    Build a ``tf.distribute`` strategy for TPU v5e-8, falling back gracefully.

    Steps:
      1. Hide GPUs from TensorFlow (best effort; failures are only printed).
      2. Look for a TPU address/name in the environment: ``COLAB_TPU_ADDR``,
         then ``TPU_NAME``, then ``TPU_WORKER_ID`` (treated as Kaggle).
      3. Resolve, connect to and initialise the TPU system.

    Returns:
        A ``tf.distribute.TPUStrategy`` on success; on any exception during
        step 3, the default strategy from ``tf.distribute.get_strategy()``.
    """
    import os

    env = os.environ
    print("🔍 Detecting TPU environment...")

    # Hide GPUs so TensorFlow does not initialise CUDA alongside the TPU.
    try:
        print("🚫 Disabling GPU to prevent CUDA conflicts...")
        tf.config.set_visible_devices([], 'GPU')
        print("✅ GPU disabled successfully")
    except Exception as gpu_err:
        print(f"⚠️ Warning: Could not disable GPU: {gpu_err}")

    tpu_address, tpu_name = None, None
    if 'COLAB_TPU_ADDR' in env:
        tpu_address = env['COLAB_TPU_ADDR']
        print(f"📍 Found Colab TPU address: {tpu_address}")
    elif 'TPU_NAME' in env:
        tpu_name = env['TPU_NAME']
        print(f"📍 Found TPU name: {tpu_name}")
    elif 'TPU_WORKER_ID' in env:
        # Kaggle TPU environment.
        # NOTE(review): the address below is hard-coded; confirm it matches
        # the runtime (TPU VMs may need auto-detection instead).
        worker_id = env.get('TPU_WORKER_ID', '0')
        tpu_address = 'grpc://10.0.0.2:8470'
        print(f"📍 Kaggle TPU detected, worker ID: {worker_id}, address: {tpu_address}")

    # Dump every TPU/COLAB-related variable to help debug detection.
    print("🔧 TPU environment variables:")
    for var_name, var_value in env.items():
        if 'TPU' in var_name or 'COLAB' in var_name:
            print(f" {var_name}={var_value}")

    try:
        if tpu_address:
            print(f"🚀 Attempting TPU connection with address: {tpu_address}")
            resolver_args = {'tpu': tpu_address}
        elif tpu_name:
            print(f"🚀 Attempting TPU connection with name: {tpu_name}")
            resolver_args = {'tpu': tpu_name}
        else:
            print("🚀 Attempting TPU auto-detection...")
            resolver_args = {}
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(**resolver_args)

        print("⚡ Connecting to TPU cluster...")
        tf.config.experimental_connect_to_cluster(resolver)
        print("🔧 Initializing TPU system...")
        tf.tpu.experimental.initialize_tpu_system(resolver)

        print("🎯 Creating TPU strategy...")
        strategy = tf.distribute.TPUStrategy(resolver)
        print("✅ TPU initialized successfully!")
        print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
        print(f"🏃 TPU cluster: {resolver.cluster_spec()}")
        return strategy
    except Exception as tpu_err:
        print(f"❌ Failed to initialize TPU: {tpu_err}")
        print(f"🔍 Error type: {type(tpu_err).__name__}")
        if "Please provide a TPU Name" in str(tpu_err):
            print("💡 Hint: TPU name/address not found in environment variables")
            print(" Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")

    # Only reached when the try block above raised.
    print("🔄 Falling back to default strategy (CPU/GPU)")
    return tf.distribute.get_strategy()
def build_model_for_tpu(config):
    """
    Construct a TripleGRUDecoder from a nested training configuration.

    Args:
        config: Mapping providing
            ``config['model']``: ``n_input_features``, ``n_units``,
            ``rnn_dropout``, ``patch_size``, ``patch_stride`` and
            ``input_network['input_layer_dropout']``;
            ``config['dataset']``: ``sessions`` (its length sets ``n_days``)
            and ``n_classes``.

    Returns:
        A TripleGRUDecoder instance. It is only constructed here, not
        compiled.
    """
    model_cfg = config['model']
    return TripleGRUDecoder(
        neural_dim=model_cfg['n_input_features'],
        n_units=model_cfg['n_units'],
        n_days=len(config['dataset']['sessions']),
        n_classes=config['dataset']['n_classes'],
        rnn_dropout=model_cfg['rnn_dropout'],
        input_dropout=model_cfg['input_network']['input_layer_dropout'],
        patch_size=model_cfg['patch_size'],
        patch_stride=model_cfg['patch_stride'],
    )
# Mixed Precision Configuration for TPU v5e-8
def configure_mixed_precision():
    """
    Install ``mixed_bfloat16`` as Keras' global mixed-precision policy.

    Returns:
        The ``keras.mixed_precision.Policy`` that was installed.
    """
    bf16_policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(bf16_policy)
    print(f"Mixed precision policy set to: {bf16_policy.name}")
    return bf16_policy