245450130274f0fc7bd9ea6196ea4097e6298d0d,ml-agents/mlagents/trainers/tests/torch/test_attention.py,,test_simple_transformer_training,#,85

Before Change


    init_error = -1.0
    for _ in range(250):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:

After Change


    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 2

Instances


Project Name: Unity-Technologies/ml-agents
Commit Name: 245450130274f0fc7bd9ea6196ea4097e6298d0d
Time: 2021-01-12
Author: 54679309+andrewcoh@users.noreply.github.com
File Name: ml-agents/mlagents/trainers/tests/torch/test_attention.py
Class Name:
Method Name: test_simple_transformer_training


Project Name: OpenMined/PySyft
Commit Name: e593cad25d6831623e6a2b6d34bcb04adcbe00f9
Time: 2020-08-31
Author: 33666625+abogaziah@users.noreply.github.com
File Name: test/torch/tensors/test_replicated_shared.py
Class Name:
Method Name: test_view