3210d75451d5d78ebe884c712f31800ca1ccd0b3,train.py,,,#,577
Before Change
// Before-change excerpt: run the training loop and, if interrupted with Ctrl-C,
// write a checkpoint before exiting.
// NOTE(review): indentation is absent in this excerpt; the try/except bodies are
// shown flush-left but belong under `try:` and `except KeyboardInterrupt:`.
// Train!
try:
train(model, data_loader, optimizer, writer,
init_lr=hparams.initial_learning_rate,
checkpoint_dir=checkpoint_dir,
checkpoint_interval=hparams.checkpoint_interval,
nepochs=hparams.nepochs,
clip_thresh=hparams.clip_thresh,
)
except KeyboardInterrupt:
// The interrupt is handled here and not re-raised.
// NOTE(review): global_step / global_epoch are not defined in this excerpt;
// presumably module-level counters advanced inside train() — confirm against the full file.
save_checkpoint(
model, optimizer, global_step, checkpoint_dir, global_epoch)
After Change
// Dataset and Dataloader setup
data_loaders = {}
for phase in ["train", "test"]:
train = phase == "train"
X = FileSourceDataset(RawAudioDataSource(data_root, speaker_id=speaker_id,
train=train,
test_size=hparams.test_size,
random_state=hparams.random_state))
if local_conditioning:
Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id=speaker_id,
train=train,
test_size=hparams.test_size,
random_state=hparams.random_state))
assert len(X) == len(Mel)
print("Local conditioning enabled. Shape of a sample: {}.".format(
Mel[0].shape))
else:
Mel = None
print("[{}]: length of the dataset is {}".format(phase, len(X)))
if train:
lengths = np.array(X.file_data_source.lengths)
// Prepare sampler
sampler = PartialyRandomizedSimilarTimeLengthSampler(
lengths, batch_size=hparams.batch_size)
else:
sampler = None
dataset = PyTorchDataset(X, Mel)
data_loader = data_utils.DataLoader(
dataset, batch_size=hparams.batch_size,
num_workers=hparams.num_workers, sampler=sampler,
collate_fn=collate_fn, pin_memory=hparams.pin_memory)
data_loaders[phase] = data_loader
// Model
model = build_model()
print(model)
if use_cuda:
model = model.cuda()
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 4
Instances
Project Name: r9y9/wavenet_vocoder
Commit Name: 3210d75451d5d78ebe884c712f31800ca1ccd0b3
Time: 2018-01-03
Author: zryuichi@gmail.com
File Name: train.py
Class Name:
Method Name:
Project Name: snorkel-team/snorkel
Commit Name: 0722cbb28234dfb6f38c375fed80d339e9921324
Time: 2017-08-14
Author: ajratner@gmail.com
File Name: snorkel/learning/utils.py
Class Name: GridSearch
Method Name: _fit_st
Project Name: jindongwang/transferlearning
Commit Name: 376b01c2e338ec63e638f62a76d67f6a9323e47c
Time: 2019-08-14
Author: jindongwang@outlook.com
File Name: code/deep/DeepCoral/DeepCoral.py
Class Name:
Method Name: