else:
    inputs, labels = Variable(inputs), Variable(labels)
# if we use low precision, the inputs also need to be fp16
if fp16:
    inputs = inputs.half()
# zero the parameter gradients
optimizer.zero_grad()
# forward; during validation, run under torch.no_grad() to skip autograd bookkeeping and save memory
if phase == "val":
    with torch.no_grad():
        outputs = model(inputs)
else:
    outputs = model(inputs)
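# With PCB the model returns a list of per-part logits (one per horizontal
# stripe of the feature map) rather than a single tensor, so prediction and
# loss are handled per part in the branch below.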
if not opt.PCB:
    _, preds = torch.max(outputs.data, 1)
    loss = criterion(outputs, labels)
else:
    part = {}
    sm = nn.Softmax(dim=1)
    num_part = 6
    for i in range(num_part):
        part[i] = outputs[i]
    score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(part[3]) + sm(part[4]) + sm(part[5])
    _, preds = torch.max(score.data, 1)
    loss = criterion(part[0], labels)
    for i in range(1, num_part):
        loss += criterion(part[i], labels)
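    # Summing the softmaxed part scores ensembles the six part classifiers
    # for prediction, while the loss is the sum of the per-part
    # cross-entropy terms, following the PCB training scheme.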
# backward + optimize only if in training phase
if phase == "train":
    if fp16:  # with fp16 we call backward through the optimizer
        optimizer.backward(loss)
    else:
        loss.backward()
    optimizer.step()
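# Note: optimizer.backward(loss) assumes the optimizer was wrapped for mixed
# precision before the training loop, e.g. with NVIDIA apex's FP16_Optimizer
# (a sketch; the static_loss_scale value is an assumption):
#
#   from apex.fp16_utils import FP16_Optimizer
#   optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
#
# The wrapper keeps fp32 master weights and scales the loss before the
# backward pass, which is why loss.backward() is replaced by
# optimizer.backward(loss) on the fp16 path.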