-
Notifications
You must be signed in to change notification settings - Fork 153
/
model.py
156 lines (135 loc) · 5.58 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from torch.nn import functional
from audio_zen.acoustics.feature import drop_band
from audio_zen.model.base_model import BaseModel
from audio_zen.model.module.sequence_model import SequenceModel
class Model(BaseModel):
    """FullSubNet: full-band / sub-band fusion model predicting a cIRM mask.

    A full-band recurrent model sees the whole spectrum per frame; its output
    is unfolded into per-frequency sub-band units, concatenated with the
    unfolded noisy magnitude, and fed to a sub-band model that emits the two
    cIRM channels (real / imaginary) for every frequency bin.
    """

    def __init__(
        self,
        num_freqs,
        look_ahead,
        sequence_model,
        fb_num_neighbors,
        sb_num_neighbors,
        fb_output_activate_function,
        sb_output_activate_function,
        fb_model_hidden_size,
        sb_model_hidden_size,
        norm_type="offline_laplace_norm",
        num_groups_in_drop_band=2,
        weight_init=True,
    ):
        """Build the full-band and sub-band sequence models.

        Args:
            num_freqs: frequency dimension of the input spectrogram.
            look_ahead: number of future frames the model may use.
            sequence_model: recurrent cell type, "GRU" or "LSTM".
            fb_num_neighbors: neighbor frequencies taken on each side when
                unfolding the full-band model's output.
            sb_num_neighbors: neighbor frequencies taken on each side when
                unfolding the noisy spectrogram.
            fb_output_activate_function: activation of the full-band output.
            sb_output_activate_function: activation of the sub-band output.
            fb_model_hidden_size: hidden size of the full-band recurrent model.
            sb_model_hidden_size: hidden size of the sub-band recurrent model.
            norm_type: normalization variant (see "BaseModel" for details).
            num_groups_in_drop_band: group count passed to drop_band when
                thinning frequencies during training.
            weight_init: apply BaseModel's weight initialization when True.
        """
        super().__init__()
        assert sequence_model in (
            "GRU",
            "LSTM",
        ), f"{self.__class__.__name__} only support GRU and LSTM."

        # Full-band stage: one sequence over the entire spectrum per frame.
        self.fb_model = SequenceModel(
            input_size=num_freqs,
            output_size=num_freqs,
            hidden_size=fb_model_hidden_size,
            num_layers=2,
            bidirectional=False,
            sequence_model=sequence_model,
            output_activate_function=fb_output_activate_function,
        )

        # Sub-band stage: consumes a noisy neighborhood plus full-band
        # context for each frequency and emits the two cIRM channels.
        self.sb_model = SequenceModel(
            input_size=(sb_num_neighbors * 2 + 1) + (fb_num_neighbors * 2 + 1),
            output_size=2,
            hidden_size=sb_model_hidden_size,
            num_layers=2,
            bidirectional=False,
            sequence_model=sequence_model,
            output_activate_function=sb_output_activate_function,
        )

        self.sb_num_neighbors = sb_num_neighbors
        self.fb_num_neighbors = fb_num_neighbors
        self.look_ahead = look_ahead
        self.norm = self.norm_wrapper(norm_type)
        self.num_groups_in_drop_band = num_groups_in_drop_band

        if weight_init:
            self.apply(self.weight_init)

    def forward(self, noisy_mag):
        """Estimate a cIRM mask from a noisy magnitude spectrogram.

        Args:
            noisy_mag: [B, 1, F, T] noisy magnitude spectrogram.

        Returns:
            [B, 2, F, T] tensor holding the real and imaginary mask parts.
        """
        assert noisy_mag.dim() == 4

        # Right-pad the time axis so the model can look at future frames.
        noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead])
        n_batch, n_chan, n_freq, n_frame = noisy_mag.size()
        assert (
            n_chan == 1
        ), f"{self.__class__.__name__} takes the mag feature as inputs."

        # Full-band pass over the whole spectrum, [B, F, T] in and out.
        fullband_in = self.norm(noisy_mag).reshape(n_batch, n_chan * n_freq, n_frame)
        fullband_out = self.fb_model(fullband_in).reshape(n_batch, 1, n_freq, n_frame)

        # Unfold the full-band output into per-frequency units, [B, F, F_f, T].
        fb_unfolded = self.freq_unfold(fullband_out, num_neighbors=self.fb_num_neighbors)
        fb_unfolded = fb_unfolded.reshape(
            n_batch, n_freq, self.fb_num_neighbors * 2 + 1, n_frame
        )

        # Unfold the noisy magnitude the same way, [B, F, F_s, T].
        mag_unfolded = self.freq_unfold(noisy_mag, num_neighbors=self.sb_num_neighbors)
        mag_unfolded = mag_unfolded.reshape(
            n_batch, n_freq, self.sb_num_neighbors * 2 + 1, n_frame
        )

        # Sub-band input: noisy neighborhood + full-band context, [B, F, (F_s + F_f), T].
        subband_in = self.norm(torch.cat([mag_unfolded, fb_unfolded], dim=2))

        # Drop frequency groups when batching — speeds up training with
        # negligible performance impact.
        if n_batch > 1:
            subband_in = drop_band(
                subband_in.permute(0, 2, 1, 3), num_groups=self.num_groups_in_drop_band
            )  # [B, (F_s + F_f), F//num_groups, T]
            n_freq = subband_in.shape[2]
            subband_in = subband_in.permute(0, 2, 1, 3)  # [B, F//num_groups, (F_s + F_f), T]

        subband_in = subband_in.reshape(
            n_batch * n_freq,
            (self.sb_num_neighbors * 2 + 1) + (self.fb_num_neighbors * 2 + 1),
            n_frame,
        )

        # [B * F, (F_s + F_f), T] -> [B * F, 2, T] -> [B, 2, F, T]
        mask = self.sb_model(subband_in)
        mask = (
            mask.reshape(n_batch, n_freq, 2, n_frame)
            .permute(0, 2, 1, 3)
            .contiguous()
        )

        # Discard the padded look-ahead frames so output aligns with input.
        return mask[:, :, :, self.look_ahead :]
if __name__ == "__main__":
    # Quick shape sanity check: feed a random magnitude spectrogram through
    # an untrained model and print the output shape.
    with torch.no_grad():
        config = dict(
            num_freqs=257,
            look_ahead=2,
            sequence_model="LSTM",
            fb_num_neighbors=0,
            sb_num_neighbors=15,
            fb_output_activate_function="ReLU",
            sb_output_activate_function=False,
            fb_model_hidden_size=512,
            sb_model_hidden_size=384,
            norm_type="offline_laplace_norm",
            num_groups_in_drop_band=2,
            weight_init=False,
        )
        model = Model(**config)
        dummy_mag = torch.rand(1, 1, 257, 63)  # [B, 1, F, T]
        print(model(dummy_mag).shape)