245450130274f0fc7bd9ea6196ea4097e6298d0d,ml-agents/mlagents/trainers/tests/torch/test_attention.py,,test_simple_transformer_training,#,85
Before Change
init_error = -1.0
for _ in range(250):
center = torch.rand((batch_size, size)) * point_range * 2 - point_range
key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
with torch.no_grad():
// create the target : The key closest to the query in euclidean distance
distance = torch.sum(
(center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
if init_error == -1.0:
init_error = error.item()
else:
After Change
batch_size = 200
for _ in range(200):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
// create the target : The key closest to the query in euclidean distance
distance = torch.sum(
(center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
print(error.item())
optimizer.zero_grad()
error.backward()
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 2
Instances
Project Name: Unity-Technologies/ml-agents
Commit Name: 245450130274f0fc7bd9ea6196ea4097e6298d0d
Time: 2021-01-12
Author: 54679309+andrewcoh@users.noreply.github.com
File Name: ml-agents/mlagents/trainers/tests/torch/test_attention.py
Class Name:
Method Name: test_simple_transformer_training
Project Name: OpenMined/PySyft
Commit Name: e593cad25d6831623e6a2b6d34bcb04adcbe00f9
Time: 2020-08-31
Author: 33666625+abogaziah@users.noreply.github.com
File Name: test/torch/tensors/test_replicated_shared.py
Class Name:
Method Name: test_view