9ddaacb9e73ac061c08403e16ac1a4d5364dca9b,mushroom_rl/algorithms/value/dqn/rainbow.py,RainbowNetwork,forward,#RainbowNetwork#,37

Before Change


        a_pa = [self._pa[i](features) for i in range(self._n_output)]
        a_pa = torch.stack(a_pa, dim=0)
        mean_a_pa = a_pa.mean(0)
        softmax = [F.softmax(a_pv + a_pa[i] - mean_a_pa, -1) for i in range(self._n_output)]
        softmax = torch.stack(softmax, dim=1)

        if not get_distribution:
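Read together, the per-action lines in the excerpt above appear to apply the dueling decomposition to the distributional logits. Each action's advantage logits are shifted by the value logits and by the mean advantage over all actions, and a softmax over the atom axis (`-1`) follows. The formula below is a sketch under that reading. The mapping to the code is an interpretation, not taken from the source: $v_j(s)$ is entry $j$ of `a_pv`, $A_j(s,a)$ is entry $j$ of `self._pa[a](features)`, $\bar{A}_j(s)$ is `mean_a_pa`, $|\mathcal{A}|$ is `self._n_output`, and $N$ is the number of atoms. Both the before and after versions produce a tensor of shape (batch, `self._n_output`, $N$).

```latex
p_j(s, a) = \frac{\exp\!\big(v_j(s) + A_j(s, a) - \bar{A}_j(s)\big)}
                 {\sum_{k=1}^{N} \exp\!\big(v_k(s) + A_k(s, a) - \bar{A}_k(s)\big)},
\qquad
\bar{A}_j(s) = \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A_j(s, a')
```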

After Change


        a_pa = torch.stack(a_pa, dim=1)
        a_pv = a_pv.unsqueeze(1).repeat(1, self._n_output, 1)
        mean_a_pa = a_pa.mean(1, keepdim=True).repeat(1, self._n_output, 1)
        softmax = F.softmax(a_pv + a_pa - mean_a_pa, dim=-1)

        if not get_distribution:
            q = torch.empty(softmax.shape[:-1])
            for i in range(softmax.shape[0]):
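The change replaces one softmax per action with a single batched softmax over a (batch, actions, atoms) tensor. The sketch below is a minimal check that the two excerpts compute the same distribution. It makes several assumptions:

- It uses PyTorch.
- Random tensors stand in for the value stream `a_pv` and the per-action heads `self._pa[i](features)`.
- The function names `dueling_distribution_loop`, `dueling_distribution_vectorized` and `expected_q` are hypothetical, not MushroomRL API.
- The excerpt is truncated before the `get_distribution=False` branch finishes. `expected_q` therefore assumes the standard C51 reduction $Q(s,a) = \sum_j z_j\, p_j(s,a)$ over a hypothetical support `torch.linspace(-10, 10, 51)`.

```python
import torch
import torch.nn.functional as F


def dueling_distribution_loop(a_pv, a_pa_list):
    """Per-action loop, mirroring the 'Before Change' excerpt.

    a_pv: (batch, n_atoms) value-stream logits.
    a_pa_list: list of n_actions tensors, each (batch, n_atoms).
    """
    n_actions = len(a_pa_list)
    a_pa = torch.stack(a_pa_list, dim=0)  # (n_actions, batch, n_atoms)
    mean_a_pa = a_pa.mean(0)  # mean advantage over actions: (batch, n_atoms)
    softmax = [F.softmax(a_pv + a_pa[i] - mean_a_pa, -1) for i in range(n_actions)]
    return torch.stack(softmax, dim=1)  # (batch, n_actions, n_atoms)


def dueling_distribution_vectorized(a_pv, a_pa_list):
    """Single batched softmax, mirroring the 'After Change' excerpt."""
    n_actions = len(a_pa_list)
    a_pa = torch.stack(a_pa_list, dim=1)  # (batch, n_actions, n_atoms)
    a_pv = a_pv.unsqueeze(1).repeat(1, n_actions, 1)  # (batch, n_actions, n_atoms)
    mean_a_pa = a_pa.mean(1, keepdim=True).repeat(1, n_actions, 1)  # same shape
    return F.softmax(a_pv + a_pa - mean_a_pa, dim=-1)  # softmax over atoms


def expected_q(probs, support):
    """Assumed C51-style reduction: Q(s, a) = sum_j z_j * p_j(s, a)."""
    return (probs * support).sum(-1)  # (batch, n_actions)


torch.manual_seed(0)
batch, n_actions, n_atoms = 4, 3, 51
a_pv = torch.randn(batch, n_atoms)
a_pa_list = [torch.randn(batch, n_atoms) for _ in range(n_actions)]

p_loop = dueling_distribution_loop(a_pv, a_pa_list)
p_vec = dueling_distribution_vectorized(a_pv, a_pa_list)
support = torch.linspace(-10.0, 10.0, n_atoms)  # hypothetical atom support

print(p_vec.shape)  # torch.Size([4, 3, 51])
print(torch.allclose(p_loop, p_vec, atol=1e-6))  # True
print(expected_q(p_vec, support).shape)  # torch.Size([4, 3])
```

As a design note, broadcasting makes the `.repeat` calls unnecessary. `F.softmax(a_pv.unsqueeze(1) + a_pa - a_pa.mean(1, keepdim=True), dim=-1)` yields the same (batch, actions, atoms) result without materializing the repeated copies.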
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 2

Instances


Project Name: AIRLab-POLIMI/mushroom
Commit Name: 9ddaacb9e73ac061c08403e16ac1a4d5364dca9b
Time: 2021-02-10
Author: carlo.deramo@gmail.com
File Name: mushroom_rl/algorithms/value/dqn/rainbow.py
Class Name: RainbowNetwork
Method Name: forward


Project Name: Kaixhin/Rainbow
Commit Name: a4df4a6f648053e4ad05bbcfc06a87c6fa984e53
Time: 2018-02-15
Author: kaia@fb.com
File Name: model.py
Class Name: DQN
Method Name: forward


Project Name: AIRLab-POLIMI/mushroom
Commit Name: ff55647673a7279b47d1268699839b92aca3e953
Time: 2021-02-10
Author: carlo.deramo@gmail.com
File Name: mushroom_rl/algorithms/value/dqn/rainbow.py
Class Name: RainbowNetwork
Method Name: forward