From 06c4c6c267d6bce3a7b186b62e2b4b360571f84c Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Sun, 12 Oct 2025 23:36:58 +0800 Subject: [PATCH] tpu --- README.md | 2 +- model_training_nnn/rnn_model.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e8f06d6..5082c3c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ - 数据集通过download_data.py文件下载。 - 代码仓库:【dev2分支】 - 个人gitea仓库:(github限制上传文件大小,哎。虽然我后面在这里也把大文件删了,,,,,)http://zchens.cn:3000/zchen/b2txt25/src/branch/dev2 - - github仓库:https://github.com/ZH-CEN/nejm-brain-to-text/tree/dev2 + - github仓库:https://github.com/ZH-CEN/nejm-brain-to-text/tree/dev2 (这个仓库好像后面没维护了) # Idea 这个模型没有记录在论文和ppt中,因为————很晚才想到,前面都在研究那个生成时构建树(只能说逻辑是可以实现的,代码在哪里呢?不知道=-=),这个目前代码主要的已经完工,在gpu环境下可以训练了。但是,参数量比baseline 还大一点点,减少batch_size后能在p100上训练,但是实在是太太太太太慢了。kaggle 的 TPU v5e-8 用起来很很不趁手。就算换5090跑,出了结果(参数量大约增了40%,乐观估计起码训练7小时)也没时间调优,甚至测评代码也没好,所以,罢了。不过我觉得模型设计还是挺好的,但我严重怀疑是有人做过,毕竟学习噪声这点好像是马老师讲的时候提过的,当时就好奇怎么学习噪声,现在才想明白。应该是有人做过了的吧。 diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py index 3e9e257..3ae19fb 100644 --- a/model_training_nnn/rnn_model.py +++ b/model_training_nnn/rnn_model.py @@ -68,7 +68,8 @@ class NoiseModel(nn.Module): day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim] # Use bmm (batch matrix multiply) which is highly optimized in XLA - x = torch.bmm(x, day_weights) + day_biases + # Ensure dtype consistency for mixed precision training + x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) x = self.day_layer_activation(x) # XLA-friendly conditional dropout @@ -167,7 +168,8 @@ class CleanSpeechModel(nn.Module): day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim] # Use bmm (batch matrix multiply) which is highly optimized in XLA - x = torch.bmm(x, day_weights) + day_biases + # Ensure dtype consistency for mixed precision training + x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) x = self.day_layer_activation(x) if self.input_dropout > 0: