53633acd7c861fd73e3954088a48d0ac8dc42895,niftynet/application/regression_application.py,RegressionApplication,connect_data_and_network,#RegressionApplication#Any#Any#,180

Before Change


    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict["image"], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name="crop-88")

After Change




        if self.is_training:
            data_dict, net_out = tf.cond(self.is_validation,
                                         lambda: data_net(False),
                                         lambda: data_net(True))
            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name="crop-88")
            with tf.name_scope("Optimiser"):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                loss_type=self.action_param.loss_type)

            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get("output", None))
            weight_map = None if data_dict.get("weight", None) is None \
                else crop_layer(data_dict.get("weight", None))
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            // collecting gradients variables
            gradients_collector.add_to_collection([grads])
            // collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name="Loss",
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name="Loss",
                average_over_devices=True, summary_type="scalar",
                collection=TF_SUMMARIES)
        else:
            data_dict, net_out = data_net(for_training=False)
            crop_layer = CropLayer(border=0, name="crop-88")
            post_process_layer = PostProcessingLayer("IDENTITY")
            net_out = post_process_layer(crop_layer(net_out))
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: NifTK/NiftyNet
Commit Name: 53633acd7c861fd73e3954088a48d0ac8dc42895
Time: 2017-11-01
Author: eli.gibson@gmail.com
File Name: niftynet/application/regression_application.py
Class Name: RegressionApplication
Method Name: connect_data_and_network


Project Name: NifTK/NiftyNet
Commit Name: 01f1bcb376dfd967603c785a255f927dea2712b6
Time: 2017-11-15
Author: wenqi.li@ucl.ac.uk
File Name: demos/BRATS17/brats_segmentation.py
Class Name: BRATSApp
Method Name: connect_data_and_network


Project Name: NifTK/NiftyNet
Commit Name: 53633acd7c861fd73e3954088a48d0ac8dc42895
Time: 2017-11-01
Author: eli.gibson@gmail.com
File Name: niftynet/application/regression_application.py
Class Name: RegressionApplication
Method Name: connect_data_and_network


Project Name: NifTK/NiftyNet
Commit Name: dfdad808d0979d6e45419720fa0d73b4cedcbb96
Time: 2017-11-01
Author: eli.gibson@gmail.com
File Name: niftynet/application/segmentation_application.py
Class Name: SegmentationApplication
Method Name: connect_data_and_network