3210d75451d5d78ebe884c712f31800ca1ccd0b3,train.py,,,#,577

Before Change



    # NOTE(review): `//` is not Python comment syntax; the `//` lines in this
    # excerpt were presumably `#` comments in the original file — confirm.
    // Train!
    # Run the training loop. Learning rate, checkpoint cadence, epoch count and
    # `clip_thresh` (presumably a gradient-clipping threshold) all come from
    # the `hparams` object.
    try:
        train(model, data_loader, optimizer, writer,
              init_lr=hparams.initial_learning_rate,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs,
              clip_thresh=hparams.clip_thresh,
              )
    except KeyboardInterrupt:
        # On Ctrl-C, save a checkpoint of the current model and optimizer
        # into `checkpoint_dir` instead of losing the progress made so far.
        # `global_step` / `global_epoch` are not defined in this excerpt;
        # presumably module-level counters advanced by `train` — verify.
        save_checkpoint(
            model, optimizer, global_step, checkpoint_dir, global_epoch)

After Change



    # NOTE(review): `//` is not Python comment syntax; the `//` lines in this
    # excerpt were presumably `#` comments in the original file — confirm.
    // Dataset and Dataloader setup
    # Build one DataLoader per phase and collect them in `data_loaders`,
    # keyed by phase name ("train" / "test").
    data_loaders = {}
    for phase in ["train", "test"]:
        # NOTE(review): this rebinds the name `train` to a bool, shadowing any
        # `train()` training function in this scope (the before-change snippet
        # calls one). A later `train(...)` call here would raise TypeError
        # ("'bool' object is not callable") — verify the function was renamed.
        train = phase == "train"
        # Raw-audio source. `train`, `test_size` and `random_state` presumably
        # select a reproducible train/test split — confirm in RawAudioDataSource.
        X = FileSourceDataset(RawAudioDataSource(data_root, speaker_id=speaker_id,
                                                 train=train,
                                                 test_size=hparams.test_size,
                                                 random_state=hparams.random_state))
        if local_conditioning:
            # Mel-spectrogram features are built with the same split arguments
            # as the audio source, so the two datasets are expected to line up
            # item for item.
            Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id=speaker_id,
                                                      train=train,
                                                      test_size=hparams.test_size,
                                                      random_state=hparams.random_state))
            # Sanity check only: `assert` statements are skipped under `python -O`.
            assert len(X) == len(Mel)
            print("Local conditioning enabled. Shape of a sample: {}.".format(
                Mel[0].shape))
        else:
            Mel = None
        print("[{}]: length of the dataset is {}".format(phase, len(X)))

        if train:
            # Per-item lengths feed the length-aware batch sampler below.
            lengths = np.array(X.file_data_source.lengths)
            // Prepare sampler
            # Judging by its name, this sampler groups items of similar length
            # with partial randomization — confirm in its definition.
            sampler = PartialyRandomizedSimilarTimeLengthSampler(
                lengths, batch_size=hparams.batch_size)
        else:
            # No custom sampler for the test phase. Assuming `data_utils` is
            # torch.utils.data, the DataLoader then iterates in sequential
            # order, because `shuffle` defaults to False.
            sampler = None

        dataset = PyTorchDataset(X, Mel)
        data_loader = data_utils.DataLoader(
            dataset, batch_size=hparams.batch_size,
            num_workers=hparams.num_workers, sampler=sampler,
            collate_fn=collate_fn, pin_memory=hparams.pin_memory)

        data_loaders[phase] = data_loader

    # NOTE(review): `//` is not Python comment syntax; presumably `#` in the
    # original file — confirm.
    // Model
    # Construct the network and print its structure for the log.
    model = build_model()
    print(model)
    # `use_cuda` is set outside this excerpt. Presumably `model` is a torch
    # nn.Module, so `.cuda()` moves its parameters to the GPU — confirm.
    if use_cuda:
        model = model.cuda()
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: r9y9/wavenet_vocoder
Commit Name: 3210d75451d5d78ebe884c712f31800ca1ccd0b3
Time: 2018-01-03
Author: zryuichi@gmail.com
File Name: train.py
Class Name:
Method Name:


Project Name: mariogeiger/se3cnn
Commit Name: 49a2ea975624090307c652e91258e5b6f02cda41
Time: 2018-10-25
Author: geiger.mario@gmail.com
File Name: examples/tetris.py
Class Name:
Method Name: main


Project Name: PyMVPA/PyMVPA
Commit Name: a243ad678a264c898e133bb1f97efed72703ee06
Time: 2016-10-02
Author: matteo.visconti.gr@dartmouth.edu
File Name: mvpa2/tests/test_rsa.py
Class Name:
Method Name: test_CDist