# NOTE(review): these statements use `self` but no enclosing `def` is visible
# here; presumably this is the body of a module's forward pass — confirm.
x = self.lrelu(self.conv1(x))
x = self.lrelu(self.conv2(x))
x = self.conv3(x)
# Capsule squash: split along dim 1 into 128 groups, apply `squash` (defined
# elsewhere) to each group independently, and concatenate back along dim 1.
x = torch.cat([squash(capsule) for capsule in torch.chunk(x, chunks=128, dim=1)], dim=1)
x = self.conv4(x)
# Same grouped squash, now with 64 groups along dim 1.
x = torch.cat([squash(capsule) for capsule in torch.chunk(x, chunks=64, dim=1)], dim=1)
x = self.conv5(x)
# Same grouped squash, now with 32 groups along dim 1.
# NOTE(review): torch.chunk may return fewer / uneven chunks when the size of
# dim 1 is not divisible by the requested count — confirm the conv output
# widths are multiples of 128 / 64 / 32.
x = torch.cat([squash(capsule) for capsule in torch.chunk(x, chunks=32, dim=1)], dim=1)
x = self.conv6(x)
After Change
def forward(self, x):
    """Run conv1..conv5 (each followed by ``self.lrelu``), a grouped squash, then conv6.

    After conv5, the tensor is split along dim 1 into 8 groups with
    ``torch.chunk``. Each group is passed through ``squash``, which is defined
    elsewhere (presumably the capsule-network squashing nonlinearity — confirm).
    The groups are then concatenated back along dim 1 and fed to ``self.conv6``.

    Args:
        x: Input tensor. Dim 1 is presumably the channel axis (NCHW-style),
            given it is the axis split into capsule groups — TODO confirm.
    """
    x = self.lrelu(self.conv1(x))
    x = self.lrelu(self.conv2(x))
    x = self.lrelu(self.conv3(x))
    x = self.lrelu(self.conv4(x))
    x = self.lrelu(self.conv5(x))
    # Capsule squash: treat each of the 8 chunks along dim 1 as one capsule
    # and squash it independently before recombining.
    # NOTE(review): torch.chunk may return fewer / uneven chunks if dim 1 is
    # not divisible by 8 — confirm conv5's output width.
    x = torch.cat([squash(capsule) for capsule in torch.chunk(x, chunks=8, dim=1)], dim=1)
    x = self.conv6(x)
    # NOTE(review): no `return` is visible in this snippet; if the method truly
    # ends here it returns None — confirm that `return x` follows.