3210d75451d5d78ebe884c712f31800ca1ccd0b3,train.py,,,#,577

Before Change



    // Train!
    try:
        train(model, data_loader, optimizer, writer,
              init_lr=hparams.initial_learning_rate,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs,
              clip_thresh=hparams.clip_thresh,
              )
    except KeyboardInterrupt:
        save_checkpoint(
            model, optimizer, global_step, checkpoint_dir, global_epoch)

After Change



    // Dataset and Dataloader setup
    data_loaders = {}
    for phase in ["train", "test"]:
        train = phase == "train"
        X = FileSourceDataset(RawAudioDataSource(data_root, speaker_id=speaker_id,
                                                 train=train,
                                                 test_size=hparams.test_size,
                                                 random_state=hparams.random_state))
        if local_conditioning:
            Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id=speaker_id,
                                                      train=train,
                                                      test_size=hparams.test_size,
                                                      random_state=hparams.random_state))
            assert len(X) == len(Mel)
            print("Local conditioning enabled. Shape of a sample: {}.".format(
                Mel[0].shape))
        else:
            Mel = None
        print("[{}]: length of the dataset is {}".format(phase, len(X)))

        if train:
            lengths = np.array(X.file_data_source.lengths)
            // Prepare sampler
            sampler = PartialyRandomizedSimilarTimeLengthSampler(
                lengths, batch_size=hparams.batch_size)
        else:
            sampler = None

        dataset = PyTorchDataset(X, Mel)
        data_loader = data_utils.DataLoader(
            dataset, batch_size=hparams.batch_size,
            num_workers=hparams.num_workers, sampler=sampler,
            collate_fn=collate_fn, pin_memory=hparams.pin_memory)

        data_loaders[phase] = data_loader

    // Model
    model = build_model()
    print(model)
    if use_cuda:
        model = model.cuda()
[image placeholder]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 4

Instances


Project Name: r9y9/wavenet_vocoder
Commit Name: 3210d75451d5d78ebe884c712f31800ca1ccd0b3
Time: 2018-01-03
Author: zryuichi@gmail.com
File Name: train.py
Class Name:
Method Name:


Project Name: snorkel-team/snorkel
Commit Name: 0722cbb28234dfb6f38c375fed80d339e9921324
Time: 2017-08-14
Author: ajratner@gmail.com
File Name: snorkel/learning/utils.py
Class Name: GridSearch
Method Name: _fit_st


Project Name: jindongwang/transferlearning
Commit Name: 376b01c2e338ec63e638f62a76d67f6a9323e47c
Time: 2019-08-14
Author: jindongwang@outlook.com
File Name: code/deep/DeepCoral/DeepCoral.py
Class Name:
Method Name: