final version? maybe
@@ -56,28 +56,37 @@ class NoiseModel(nn.Module):
         self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
 
     def forward(self, x, day_idx, states=None):
-        # Apply day-specific transformation
-        day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
-        day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
+        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+        batch_size = x.size(0)
 
-        x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
+        # Stack all day weights and biases upfront for static indexing
+        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
+        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]
+
+        # XLA-friendly gather operation
+        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
+        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]
+
+        # Use bmm (batch matrix multiply) which is highly optimized in XLA
+        x = torch.bmm(x, day_weights) + day_biases
         x = self.day_layer_activation(x)
 
+        # XLA-friendly conditional dropout
         if self.input_dropout > 0:
             x = self.day_layer_dropout(x)
 
-        # Apply patch processing if enabled
+        # Apply patch processing if enabled (keep conditional for now, optimize later)
         if self.patch_size > 0:
             x = x.unsqueeze(1)
             x = x.permute(0, 3, 1, 2)
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
-            x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
+            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
 
-        # Initialize hidden states
+        # XLA-friendly hidden state initialization - avoid dynamic allocation
        if states is None:
-            states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()
+            states = self.h0.expand(2, batch_size, self.input_size).contiguous()
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
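
The core change in this hunk swaps the per-sample Python indexing (stack/cat over day_idx, then einsum) for a static stack of all day parameters, torch.index_select, and torch.bmm, which keeps the gather inside the compiled graph. As a sanity check, here is a minimal, self-contained sketch (toy shapes and plain tensor lists, not the repository's nn.ParameterList module) showing the two formulations agree:

import torch

# Minimal sketch (not the repository's code): checks that gathering pre-stacked
# per-day weights with index_select + bmm matches the original per-sample
# stack + einsum. All sizes below are assumed toy values.
n_days, batch, time, dim = 4, 3, 5, 8
day_w = [torch.randn(dim, dim) for _ in range(n_days)]
day_b = [torch.randn(1, dim) for _ in range(n_days)]
x = torch.randn(batch, time, dim)
day_idx = torch.tensor([0, 2, 1])

# Old pattern: Python-level list indexing that depends on the batch contents
w_dyn = torch.stack([day_w[i] for i in day_idx], dim=0)
b_dyn = torch.cat([day_b[i] for i in day_idx], dim=0).unsqueeze(1)
out_old = torch.einsum("btd,bdk->btk", x, w_dyn) + b_dyn

# New pattern: stack everything once, then gather statically and use bmm
all_w = torch.stack(day_w, dim=0)                              # [n_days, dim, dim]
all_b = torch.stack([b.squeeze(0) for b in day_b], dim=0)      # [n_days, dim]
w_gather = torch.index_select(all_w, 0, day_idx)               # [batch, dim, dim]
b_gather = torch.index_select(all_b, 0, day_idx).unsqueeze(1)  # [batch, 1, dim]
out_new = torch.bmm(x, w_gather) + b_gather

print(torch.allclose(out_old, out_new, atol=1e-6))  # expected: True

The list-comprehension version forces a Python loop whose result depends on the batch, while index_select with a tensor index stays a single traceable op.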
@@ -146,11 +155,19 @@ class CleanSpeechModel(nn.Module):
         self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
 
     def forward(self, x, day_idx, states=None, return_state=False):
-        # Apply day-specific transformation
-        day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
-        day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
+        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+        batch_size = x.size(0)
 
-        x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
+        # Stack all day weights and biases upfront for static indexing
+        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
+        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]
+
+        # XLA-friendly gather operation
+        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
+        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]
+
+        # Use bmm (batch matrix multiply) which is highly optimized in XLA
+        x = torch.bmm(x, day_weights) + day_biases
         x = self.day_layer_activation(x)
 
         if self.input_dropout > 0:
@@ -163,11 +180,11 @@ class CleanSpeechModel(nn.Module):
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
-            x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
+            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
 
-        # Initialize hidden states
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.h0.expand(3, x.shape[0], self.n_units).contiguous()
+            states = self.h0.expand(3, batch_size, self.n_units).contiguous()
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
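
The hidden-state change follows the same idea: the learned h0 stays [1, 1, n_units], but it is expanded to a statically shaped [num_layers, batch, n_units] from the precomputed batch_size rather than from x.shape[0] at call time. A small sketch under assumed sizes (3-layer GRU, hypothetical n_units and neural_dim values, standalone rather than the full model) of the expand + contiguous pattern feeding nn.GRU:

import torch
import torch.nn as nn

# Sketch only: assumed sizes, a bare GRU instead of CleanSpeechModel.
n_units, neural_dim, num_layers = 16, 8, 3
gru = nn.GRU(neural_dim, n_units, num_layers=num_layers, batch_first=True)

# Learned initial state stored as [1, 1, n_units], as in the diff's self.h0
h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, n_units)))

x = torch.randn(4, 10, neural_dim)  # [batch, time, features]
batch_size = x.size(0)

# expand produces a view over the size-1 dims; .contiguous() materializes it
# so the GRU receives a real [num_layers, batch, n_units] tensor.
states = h0.expand(num_layers, batch_size, n_units).contiguous()
output, hidden_states = gru(x, states)
print(output.shape, hidden_states.shape)  # [4, 10, 16] and [3, 4, 16]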
@@ -235,10 +252,11 @@ class NoisySpeechModel(nn.Module):
 
     def forward(self, x, states=None, return_state=False):
+        # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
+        batch_size = x.size(0)
 
-        # Initialize hidden states
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.h0.expand(2, x.shape[0], self.n_units).contiguous()
+            states = self.h0.expand(2, batch_size, self.n_units).contiguous()
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
@@ -329,30 +347,39 @@ class TripleGRUDecoder(nn.Module):
         self.training_mode = 'full'  # 'full', 'inference'
 
     def _apply_preprocessing(self, x, day_idx):
-        '''Apply day-specific transformation and patch processing to match what models expect'''
-        # Apply day-specific transformation (same as in each model)
-        day_weights = torch.stack([self.clean_speech_model.day_weights[i] for i in day_idx], dim=0)
-        day_biases = torch.cat([self.clean_speech_model.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
+        '''XLA-friendly preprocessing with static operations'''
+        batch_size = x.size(0)
 
-        x_processed = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
+        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+        all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
+        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)
+
+        # XLA-friendly gather operation
+        day_weights = torch.index_select(all_day_weights, 0, day_idx)
+        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
+
+        # Use bmm (batch matrix multiply) which is highly optimized in XLA
+        x_processed = torch.bmm(x, day_weights) + day_biases
         x_processed = self.clean_speech_model.day_layer_activation(x_processed)
 
-        # Apply patch processing if enabled (same as in each model)
+        # Apply patch processing if enabled
         if self.patch_size > 0:
             x_processed = x_processed.unsqueeze(1)
             x_processed = x_processed.permute(0, 3, 1, 2)
             x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
-            x_processed = x_unfold.reshape(x_processed.size(0), x_unfold.size(1), -1)
+            x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
 
         return x_processed
 
     def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
         '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
-        # Initialize hidden states
+        batch_size = x_processed.size(0)
+
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.clean_speech_model.h0.expand(3, x_processed.shape[0], self.clean_speech_model.n_units).contiguous()
+            states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
 
         # GRU forward pass (skip preprocessing since input is already processed)
         output, hidden_states = self.clean_speech_model.gru(x_processed, states)
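
The patching branch itself is unchanged apart from using the precomputed batch_size: it slices the time axis into overlapping windows with unfold and flattens each window into the feature dimension. A shape-checking sketch with assumed patch_size and patch_stride values (not the repository's configuration):

import torch

# Sketch with assumed values; mirrors the unfold-based patching in the diff.
batch_size, time, neural_dim = 2, 20, 8
patch_size, patch_stride = 4, 2
x = torch.randn(batch_size, time, neural_dim)

x = x.unsqueeze(1)                                        # [B, 1, T, D]
x = x.permute(0, 3, 1, 2)                                 # [B, D, 1, T]
x_unfold = x.unfold(3, patch_size, patch_stride)          # [B, D, 1, T', patch_size]
x_unfold = x_unfold.squeeze(2)                            # [B, D, T', patch_size]
x_unfold = x_unfold.permute(0, 2, 3, 1)                   # [B, T', patch_size, D]
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)    # [B, T', patch_size * D]

print(x.shape)  # torch.Size([2, 9, 32]); T' = (20 - 4) // 2 + 1 = 9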
@@ -363,9 +390,11 @@ class TripleGRUDecoder(nn.Module):
 
     def _noisy_forward_with_processed_input(self, x_processed, states=None):
         '''Forward pass for NoisySpeechModel with already processed input'''
-        # Initialize hidden states
+        batch_size = x_processed.size(0)
+
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.noisy_speech_model.h0.expand(2, x_processed.shape[0], self.noisy_speech_model.n_units).contiguous()
+            states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
 
         # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
         output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
@@ -407,23 +436,10 @@ class TripleGRUDecoder(nn.Module):
             noisy_logits = self._noisy_forward_with_processed_input(noise_output,
                                                                     states['noisy'] if states else None)
 
+            # XLA-friendly return - use tuple instead of dict for better compilation
             if return_state:
-                return_states = {
-                    'noise': noise_hidden,
-                    'clean': None,  # CleanSpeechModel doesn't return hidden states in this call
-                    'noisy': None   # NoisySpeechModel doesn't return hidden states in this call
-                }
-                return {
-                    'clean_logits': clean_logits,
-                    'noisy_logits': noisy_logits,
-                    'noise_output': noise_output
-                }, return_states
-
-            return {
-                'clean_logits': clean_logits,
-                'noisy_logits': noisy_logits,
-                'noise_output': noise_output
-            }
+                return (clean_logits, noisy_logits, noise_output), noise_hidden
+            return clean_logits, noisy_logits, noise_output
 
         elif mode == 'inference':
             # Inference mode: only noise model + clean speech model
@@ -440,13 +456,9 @@ class TripleGRUDecoder(nn.Module):
             clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
                                                                     states['clean'] if states else None)
 
+            # XLA-friendly return - use tuple for consistency
             if return_state:
-                return_states = {
-                    'noise': noise_hidden,
-                    'clean': None
-                }
-                return clean_logits, return_states
-
+                return clean_logits, noise_hidden
             return clean_logits
 
         else:
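
With the dict returns replaced by plain tuples, callers now unpack positionally. A hypothetical caller-side sketch; the exact forward signature of TripleGRUDecoder is not shown in this diff, so the argument names and keyword usage below are assumptions:

# Hypothetical usage sketch: 'model' is an already-constructed TripleGRUDecoder,
# and x / day_idx are assumed inputs; only the return shapes follow the diff.

# Training ('full' mode): three outputs, plus the noise hidden state on request.
clean_logits, noisy_logits, noise_output = model(x, day_idx, mode='full')
(clean_logits, noisy_logits, noise_output), noise_hidden = model(
    x, day_idx, mode='full', return_state=True)

# Inference mode: only the clean logits, plus the noise hidden state on request.
clean_logits = model(x, day_idx, mode='inference')
clean_logits, noise_hidden = model(x, day_idx, mode='inference', return_state=True)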