6e6f6c993efc6345f0582c32d96c778055f165cb,train_pt.py,,main,#,389
Before Change
if hasattr(net, "module"):
input_image_size = net.module.in_size[0] if hasattr(net.module, "in_size") else args.input_size
else:
input_image_size = net.in_size[0] if hasattr(net, "in_size") else args.input_size
train_data = get_train_data_loader(
data_dir=args.data_dir,
After Change
assert (hasattr(real_net, "num_classes"))
num_classes = real_net.num_classes
ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
ds_metainfo.update(args=args)
train_data = get_train_data_source(
ds_metainfo=ds_metainfo,
batch_size=batch_size,
num_workers=args.num_workers)
val_data = get_val_data_source(
ds_metainfo=ds_metainfo,
batch_size=batch_size,
num_workers=args.num_workers)
optimizer, lr_scheduler, start_epoch = prepare_trainer(
net=net,
optimizer_name=args.optimizer_name,
wd=args.wd,
momentum=args.momentum,
lr_mode=args.lr_mode,
lr=args.lr,
lr_decay_period=args.lr_decay_period,
lr_decay_epoch=args.lr_decay_epoch,
lr_decay=args.lr_decay,
// warmup_epochs=args.warmup_epochs,
// batch_size=batch_size,
num_epochs=args.num_epochs,
// num_training_samples=num_training_samples,
state_file_path=args.resume_state)
if args.save_dir and args.save_interval:
param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + ["Train.Loss", "LR"]
lp_saver = TrainLogParamSaver(
checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label, args.model),
last_checkpoint_file_name_suffix="last",
best_checkpoint_file_name_suffix=None,
last_checkpoint_dir_path=args.save_dir,
best_checkpoint_dir_path=None,
last_checkpoint_file_count=2,
best_checkpoint_file_count=2,
checkpoint_file_save_callback=save_params,
checkpoint_file_exts=(".pth", ".states"),
save_interval=args.save_interval,
num_epochs=args.num_epochs,
param_names=param_names,
acc_ind=ds_metainfo.saver_acc_ind,
// bigger=[True],
// mask=None,
score_log_file_path=os.path.join(args.save_dir, "score.log"),
score_log_attempt_value=args.attempt,
best_map_log_file_path=os.path.join(args.save_dir, "best_map.log"))
else:
lp_saver = None
train_net(
batch_size=batch_size,
num_epochs=args.num_epochs,
start_epoch1=args.start_epoch,
train_data=train_data,
val_data=val_data,
net=net,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
lp_saver=lp_saver,
log_interval=args.log_interval,
num_classes=num_classes,
val_metric=get_composite_metric(ds_metainfo.val_metric_names, ds_metainfo.val_metric_extra_kwargs),
train_metric=get_composite_metric(ds_metainfo.train_metric_names, ds_metainfo.train_metric_extra_kwargs),
opt_metric_name=ds_metainfo.val_metric_names[ds_metainfo.saver_acc_ind],
use_cuda=use_cuda)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 7
Instances
Project Name: osmr/imgclsmob
Commit Name: 6e6f6c993efc6345f0582c32d96c778055f165cb
Time: 2019-05-10
Author: osemery@gmail.com
File Name: train_pt.py
Class Name:
Method Name: main
Project Name: osmr/imgclsmob
Commit Name: 6e6f6c993efc6345f0582c32d96c778055f165cb
Time: 2019-05-10
Author: osemery@gmail.com
File Name: train_pt.py
Class Name:
Method Name: main
Project Name: osmr/imgclsmob
Commit Name: 018dbdad3c999d897abc565058bdf56a15bab5ca
Time: 2019-05-04
Author: osemery@gmail.com
File Name: eval_gl_cifar.py
Class Name:
Method Name: main
Project Name: osmr/imgclsmob
Commit Name: 9766d2c13ddca10d5e45280450ee9cab649afb18
Time: 2019-05-10
Author: osemery@gmail.com
File Name: eval_pt.py
Class Name:
Method Name: main