6e6f6c993efc6345f0582c32d96c778055f165cb,train_pt.py,,main,#,389

Before Change


    if hasattr(net, "module"):
        input_image_size = net.module.in_size[0] if hasattr(net.module, "in_size") else args.input_size
    else:
        input_image_size = net.in_size[0] if hasattr(net, "in_size") else args.input_size

    train_data = get_train_data_loader(
        data_dir=args.data_dir,

After Change


    assert (hasattr(real_net, "num_classes"))
    num_classes = real_net.num_classes

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    train_data = get_train_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)
    val_data = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)

    optimizer, lr_scheduler, start_epoch = prepare_trainer(
        net=net,
        optimizer_name=args.optimizer_name,
        wd=args.wd,
        momentum=args.momentum,
        lr_mode=args.lr_mode,
        lr=args.lr,
        lr_decay_period=args.lr_decay_period,
        lr_decay_epoch=args.lr_decay_epoch,
        lr_decay=args.lr_decay,
        # warmup_epochs=args.warmup_epochs,
        # batch_size=batch_size,
        num_epochs=args.num_epochs,
        # num_training_samples=num_training_samples,
        state_file_path=args.resume_state)

    if args.save_dir and args.save_interval:
        param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + ["Train.Loss", "LR"]
        lp_saver = TrainLogParamSaver(
            checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label, args.model),
            last_checkpoint_file_name_suffix="last",
            best_checkpoint_file_name_suffix=None,
            last_checkpoint_dir_path=args.save_dir,
            best_checkpoint_dir_path=None,
            last_checkpoint_file_count=2,
            best_checkpoint_file_count=2,
            checkpoint_file_save_callback=save_params,
            checkpoint_file_exts=(".pth", ".states"),
            save_interval=args.save_interval,
            num_epochs=args.num_epochs,
            param_names=param_names,
            acc_ind=ds_metainfo.saver_acc_ind,
            # bigger=[True],
            # mask=None,
            score_log_file_path=os.path.join(args.save_dir, "score.log"),
            score_log_attempt_value=args.attempt,
            best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
    else:
        lp_saver = None

    train_net(
        batch_size=batch_size,
        num_epochs=args.num_epochs,
        start_epoch1=args.start_epoch,
        train_data=train_data,
        val_data=val_data,
        net=net,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lp_saver=lp_saver,
        log_interval=args.log_interval,
        num_classes=num_classes,
        val_metric=get_composite_metric(ds_metainfo.val_metric_names, ds_metainfo.val_metric_extra_kwargs),
        train_metric=get_composite_metric(ds_metainfo.train_metric_names, ds_metainfo.train_metric_extra_kwargs),
        opt_metric_name=ds_metainfo.val_metric_names[ds_metainfo.saver_acc_ind],
        use_cuda=use_cuda)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: osmr/imgclsmob
Commit Name: 6e6f6c993efc6345f0582c32d96c778055f165cb
Time: 2019-05-10
Author: osemery@gmail.com
File Name: train_pt.py
Class Name:
Method Name: main


Project Name: osmr/imgclsmob
Commit Name: 6e6f6c993efc6345f0582c32d96c778055f165cb
Time: 2019-05-10
Author: osemery@gmail.com
File Name: train_pt.py
Class Name:
Method Name: main


Project Name: osmr/imgclsmob
Commit Name: 018dbdad3c999d897abc565058bdf56a15bab5ca
Time: 2019-05-04
Author: osemery@gmail.com
File Name: eval_gl_cifar.py
Class Name:
Method Name: main


Project Name: osmr/imgclsmob
Commit Name: 9766d2c13ddca10d5e45280450ee9cab649afb18
Time: 2019-05-10
Author: osemery@gmail.com
File Name: eval_pt.py
Class Name:
Method Name: main