97d7f88df32ae58875d7534c7c517905a3eb1341,deepchem/models/torch_models/gat.py,GAT,forward,#GAT#Any#,113

Before Change


      If mode == "classification", the shape is `(batch_size, n_tasks, n_classes)` (n_tasks > 1)
      or `(batch_size, n_classes)` (n_tasks == 1) and the output values are probabilities of each class label.
    
    node_feat, edge_index = data.x, data.edge_index
    node_feat = self.embedding(node_feat)

    # convolutional layer
    for conv in self.conv_layers:
      node_feat = conv(node_feat, edge_index)

    # pooling
    graph_feat = self.pooling(node_feat, data.batch)
    graph_feat = F.leaky_relu(self.fc(graph_feat))
    out = self.out(graph_feat)

    if self.mode == "regression":
      return out
    else:
      logits = out.view(-1, self.n_tasks, self.n_classes)
      # for n_tasks == 1 case
      logits = torch.squeeze(logits)
      proba = F.softmax(logits, dim=-1)
      return proba, logits

After Change


            This is only returned when self.mode = "classification", the output consists of the
            logits for classes before softmax.
        
    node_feats = g.ndata[self.nfeat_name]
    out = self.model(g, node_feats)

    if self.mode == "classification":
      if self.n_tasks == 1:
        logits = out.view(-1, self.n_classes)
        softmax_dim = 1
      else:
        logits = out.view(-1, self.n_tasks, self.n_classes)
        softmax_dim = 2
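
The record does not say why the reshape logic changed, and the "After Change" snippet ends before `softmax_dim` is used. One observable difference between the two versions is how they treat a batch of size 1. The sketch below compares the resulting shapes using plain tuples instead of real tensors. The helper names (`squeeze_shape`, `before_change_shape`, `after_change_shape`) are hypothetical and appear in neither commit. The only library behavior assumed is that `torch.squeeze` called without a `dim` argument drops every size-1 dimension.

```python
def squeeze_shape(shape):
    """Shape left after removing every size-1 dimension (torch.squeeze with no dim)."""
    return tuple(d for d in shape if d != 1)


def before_change_shape(batch_size, n_tasks, n_classes):
    """Before: view(-1, n_tasks, n_classes) followed by an unconditional squeeze."""
    return squeeze_shape((batch_size, n_tasks, n_classes))


def after_change_shape(batch_size, n_tasks, n_classes):
    """After: branch on n_tasks explicitly; returns (logits shape, softmax_dim)."""
    if n_tasks == 1:
        return (batch_size, n_classes), 1
    return (batch_size, n_tasks, n_classes), 2


# Single task, batch of 4: both versions agree on (batch_size, n_classes).
print(before_change_shape(4, 1, 2))    # (4, 2)
print(after_change_shape(4, 1, 2)[0])  # (4, 2)

# Single task, batch of 1: the squeeze also removes the batch dimension.
print(before_change_shape(1, 1, 2))    # (2,)
print(after_change_shape(1, 1, 2)[0])  # (1, 2)
```

In both branches of the later version, `softmax_dim` points at the last axis of `logits`, so it selects the same axis as the `dim=-1` used in the earlier version.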
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 8

Instances


Project Name: deepchem/deepchem
Commit Name: 97d7f88df32ae58875d7534c7c517905a3eb1341
Time: 2020-11-05
Author: mufeili1996@gmail.com
File Name: deepchem/models/torch_models/gat.py
Class Name: GAT
Method Name: forward


Project Name: tryolabs/luminoth
Commit Name: 8dd1ec5bfef42af8ad75a38a4ac6d58e7d7763c9
Time: 2017-10-09
Author: javirey@gmail.com
File Name: luminoth/utils/image_vis.py
Class Name:
Method Name: draw_rpn_bbox_pred


Project Name: deepchem/deepchem
Commit Name: f766c92a63fb67a99744dcb82c8c78406fbf905a
Time: 2020-11-04
Author: mufeili1996@gmail.com
File Name: deepchem/models/torch_models/gat.py
Class Name: GAT
Method Name: forward