Before Change

with torch.no_grad():
    mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
    wave_len = (mels.size(-1) - 1) * self.hop_length
    mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
    mels, aux = self.upsample(mels.transpose(1, 2))

    if batched:
        mels = self.fold_with_overlap(mels, target, overlap)
        if aux is not None:
            aux = self.fold_with_overlap(aux, target, overlap)

    b_size, seq_len, _ = mels.size()

    h1 = torch.zeros(b_size, self.rnn_dims).cuda()
    h2 = torch.zeros(b_size, self.rnn_dims).cuda()
    x = torch.zeros(b_size, 1).cuda()

    if self.use_aux_net:
        d = self.aux_dims
        aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]

    for i in range(seq_len):
        m_t = mels[:, i, :]

        if self.use_aux_net:
            a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)

        x = (
            torch.cat([x, m_t, a1_t], dim=1)
            if self.use_aux_net
            else torch.cat([x, m_t], dim=1)
        )
        x = self.I(x)
        h1 = rnn1(x, h1)

        x = x + h1
        inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x
        h2 = rnn2(inp, h2)

        x = x + h2
        x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x
        x = F.relu(self.fc2(x))

        logits = self.fc3(x)

        if self.mode == "mold":
            sample = sample_from_discretized_mix_logistic(
                logits.unsqueeze(0).transpose(1, 2)
            )
            output.append(sample.view(-1))
            x = sample.transpose(0, 1).cuda()
        elif self.mode == "gauss":
            sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
            output.append(sample.view(-1))
            x = sample.transpose(0, 1).cuda()
        elif isinstance(self.mode, int):
            posterior = F.softmax(logits, dim=1)
            distrib = torch.distributions.Categorical(posterior)
            sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
            output.append(sample)
            x = sample.unsqueeze(-1)
        else:
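The version above pins every tensor to the GPU with hard-coded `.cuda()` calls, so generation fails outright on a CPU-only machine. The revision below threads a `device` handle through the same code instead, leaving the rest of the sampling loop untouched.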
After Change
rnn2 = self.get_gru_cell(self.rnn2)

with torch.no_grad():
    # Move the conditioning mel spectrogram to the active device
    # instead of hard-coding .cuda()
    mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
    # mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
    wave_len = (mels.size(-1) - 1) * self.hop_length
    mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
    mels, aux = self.upsample(mels.transpose(1, 2))

    if batched:
        # Fold long spectrograms into overlapping segments so they can be
        # generated in parallel and cross-faded back together afterwards.
        mels = self.fold_with_overlap(mels, target, overlap)
        if aux is not None:
            aux = self.fold_with_overlap(aux, target, overlap)

    b_size, seq_len, _ = mels.size()

    # GRU hidden states and the initial sample, allocated on the same device
    h1 = torch.zeros(b_size, self.rnn_dims).to(device)
    h2 = torch.zeros(b_size, self.rnn_dims).to(device)
    x = torch.zeros(b_size, 1).to(device)

    if self.use_aux_net:
        # Split the upsampler's auxiliary features into four equal chunks,
        # one per conditioning site below.
        d = self.aux_dims
        aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]

    # Autoregressive loop: one output sample per upsampled frame
    for i in range(seq_len):
        m_t = mels[:, i, :]

        if self.use_aux_net:
            a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)

        x = (
            torch.cat([x, m_t, a1_t], dim=1)
            if self.use_aux_net
            else torch.cat([x, m_t], dim=1)
        )
        x = self.I(x)
        h1 = rnn1(x, h1)

        x = x + h1
        inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x
        h2 = rnn2(inp, h2)

        x = x + h2
        x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x
        x = F.relu(self.fc2(x))

        logits = self.fc3(x)

        if self.mode == "mold":
            sample = sample_from_discretized_mix_logistic(
                logits.unsqueeze(0).transpose(1, 2)
            )
            output.append(sample.view(-1))
            x = sample.transpose(0, 1).to(device)
        elif self.mode == "gauss":
            sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
            output.append(sample.view(-1))
            x = sample.transpose(0, 1).to(device)
        elif isinstance(self.mode, int):
            # Categorical output: sample a class index and rescale to [-1, 1]
            posterior = F.softmax(logits, dim=1)
            distrib = torch.distributions.Categorical(posterior)
            sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
            output.append(sample)
            x = sample.unsqueeze(-1)
        else:
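Neither excerpt shows where `device` comes from; it is assumed to already be in scope. A minimal sketch of one common way to define it (this setup is an assumption, not part of the shown change):

import torch

# Assumed setup, not shown in the diff: prefer CUDA when available,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors created with .to(device) then follow whichever device was picked.
h = torch.zeros(1, 512).to(device)

With this pattern the `.to(device)` calls above become no-ops on CPU-only hosts, instead of failing at runtime the way the `.cuda()` version does on a machine without a GPU.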