165756cc19a19db0e75c128ccc97eff3579af1c3,examples/nas/enas/search.py,,,#,19

Before Change


    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optimizer,
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                               batch_size=args.batch_size,
                               num_epochs=num_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency,
                               mutator=mutator)
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()

After Change


    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

    if args.v1:
        trainer = enas.EnasTrainer(model,
                                   loss=criterion,
                                   metrics=accuracy,
                                   reward_function=reward_accuracy,
                                   optimizer=optimizer,
                                   callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                                   batch_size=args.batch_size,
                                   num_epochs=num_epochs,
                                   dataset_train=dataset_train,
                                   dataset_valid=dataset_valid,
                                   log_frequency=args.log_frequency,
                                   mutator=mutator)
        if args.visualization:
            trainer.enable_visualization()
        trainer.train()
    else:
        from nni.retiarii.trainer.pytorch.enas import EnasTrainer
        trainer = EnasTrainer(model,
                              loss=criterion,
                              metrics=accuracy,
                              reward_function=reward_accuracy,
                              optimizer=optimizer,
                              batch_size=args.batch_size,
                              num_epochs=num_epochs,
                              dataset=dataset_train,
                              log_frequency=args.log_frequency,
                              ctrl_kwargs=ctrl_kwargs)
        trainer.fit()
[image]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: microsoft/nni
Commit Name: 165756cc19a19db0e75c128ccc97eff3579af1c3
Time: 2020-12-05
Author: Yuge.Zhang@microsoft.com
File Name: examples/nas/enas/search.py
Class Name:
Method Name:


Project Name: microsoft/nni
Commit Name: 165756cc19a19db0e75c128ccc97eff3579af1c3
Time: 2020-12-05
Author: Yuge.Zhang@microsoft.com
File Name: examples/nas/darts/search.py
Class Name:
Method Name:


Project Name: ClimbsRocks/auto_ml
Commit Name: 090b20a6dc31cb152505523e5f48c87f6142278c
Time: 2016-08-24
Author: ClimbsBytes@gmail.com
File Name: auto_ml/predictor.py
Class Name: Predictor
Method Name: train