forked from openvpi/DiffSinger
Commit
Congratulations on our new ACL 2022 work & some updates
1 parent b7c14f1 · commit fa17f25
Showing 7 changed files with 200 additions and 10 deletions.
@@ -0,0 +1,113 @@
import math
import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-compute the encoding table for the default maximum length.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            # Reuse the cached table if it is long enough; only move it
            # to the input's dtype/device when they differ.
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            # Positions run from T-1 down to 0 (used by the relative variant).
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        # Geometric progression of wavelengths, as in "Attention Is All You Need".
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2 of https://arxiv.org/abs/1809.08895

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        # Learnable scale applied to the positional table instead of scaling x.
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See Appendix B of https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor with the (reversed) positional
                embedding added (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x) + self.dropout(pos_emb)
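As a quick sanity check, here is an illustrative way to exercise the three classes above; the sizes and dropout rate are made up for the example and are not taken from the commit:

# Illustrative only: d_model, batch, and time are arbitrary example sizes.
import torch

d_model, batch, time = 256, 2, 100
x = torch.randn(batch, time, d_model)

abs_pe = PositionalEncoding(d_model, dropout_rate=0.1)
scaled_pe = ScaledPositionalEncoding(d_model, dropout_rate=0.1)
rel_pe = RelPositionalEncoding(d_model, dropout_rate=0.1)

print(abs_pe(x).shape)     # torch.Size([2, 100, 256]): x * sqrt(d_model) + PE
print(scaled_pe(x).shape)  # torch.Size([2, 100, 256]): x + alpha * PE
print(rel_pe(x).shape)     # torch.Size([2, 100, 256]): dropout(scaled x) + dropout(reversed PE)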
@@ -0,0 +1,35 @@
base_config:
  - usr/configs/popcs_ds_beta6.yaml
  - usr/configs/midi/cascade/opencs/opencpop_statis.yaml

binarizer_cls: data_gen.singing.binarize.OpencpopBinarizer
binary_data_dir: 'data/binary/opencpop-midi-dp'

#switch_midi2f0_step: 174000
use_midi: true  # for midi exp
use_gt_dur: false  # for further midi exp
lambda_ph_dur: 1.0
lambda_sent_dur: 1.0
lambda_word_dur: 1.0
predictor_grad: 0.1
dur_predictor_layers: 5  # *

fs2_ckpt: ''  #
#num_valid_plots: 0
task_cls: usr.diffsinger_task.DiffSingerMIDITask

K_step: 100
max_tokens: 40000
max_updates: 160000
gaussian_start: True

use_pitch_embed: false
use_gt_f0: false  # for midi exp

lambda_f0: 0.
lambda_uv: 0.
dilation_cycle_length: 4  # *
predictor_layers: 5
pe_enable: true
pe_ckpt: 'checkpoints/0102_xiaoma_pe'
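The base_config list pulls the two referenced YAMLs into this file, with local keys overriding anything they define. A minimal sketch of how such inheritance can be resolved, assuming a shallow recursive merge; the load_config helper below is hypothetical, and the repository's own loader may merge differently:

import yaml  # requires PyYAML

def load_config(path):
    # Hypothetical helper: recursively resolve `base_config` inheritance.
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    merged = {}
    for base_path in cfg.pop('base_config', []):
        merged.update(load_config(base_path))  # later bases override earlier ones
    merged.update(cfg)  # keys in this file override every base
    return merged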
@@ -0,0 +1,36 @@
base_config:
  - usr/configs/popcs_ds_beta6.yaml
  - usr/configs/midi/cascade/opencs/opencpop_statis.yaml

binarizer_cls: data_gen.singing.binarize.OpencpopBinarizer
binary_data_dir: 'data/binary/opencpop-midi-dp'

#switch_midi2f0_step: 174000
use_midi: true  # for midi exp
use_gt_dur: false  # for further midi exp
lambda_ph_dur: 1.0
lambda_sent_dur: 1.0
lambda_word_dur: 1.0
predictor_grad: 0.1
dur_predictor_layers: 5  # *

fs2_ckpt: ''  #
#num_valid_plots: 0
task_cls: usr.diffsinger_task.DiffSingerMIDITask

K_step: 100
max_tokens: 40000
max_updates: 160000
gaussian_start: True

use_pitch_embed: false
use_gt_f0: false  # for midi exp

lambda_f0: 0.
lambda_uv: 0.
dilation_cycle_length: 4  # *
rel_pos: true
predictor_layers: 5
pe_enable: true
pe_ckpt: 'checkpoints/0102_xiaoma_pe'
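This second config differs from the previous one only by rel_pos: true, which plausibly switches the model from the absolute PositionalEncoding to the RelPositionalEncoding added in this commit. A hedged sketch of what that selection could look like; the function and the hparams access pattern are assumptions, not the commit's actual wiring:

def build_positional_encoding(hparams, d_model, dropout_rate=0.1):
    # Assumed selector: the real model code may wire `rel_pos` differently.
    if hparams.get('rel_pos', False):
        return RelPositionalEncoding(d_model, dropout_rate)
    return PositionalEncoding(d_model, dropout_rate)

enc = build_positional_encoding({'rel_pos': True}, d_model=256)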