@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", ["numpy", torch_if_found])
def test_chain(size, backend):
    """Pre-change version of the matrix-chain test.

    NOTE(review): this snapshot appears truncated -- it stops right after
    building the einsum input spec, so the contraction part is not visible.
    """
    # `size` random 2x2 operands forming the chain.
    xs = [np.random.normal(size=(2, 2)) for _ in range(size)]
    # size + 1 index symbols are needed for a chain of `size` operands.
    alphabet = "".join(get_symbol(i) for i in range(size + 1))
    # Operand i uses indices alphabet[i] and alphabet[i + 1], so neighbouring
    # operands share exactly one index (assumes get_symbol returns one char
    # per call -- confirm).
    names = [alphabet[i:i+2] for i in range(size)]
    # Comma-separated einsum input spec built from the per-operand names.
    inputs = ",".join(names)
After Change
@pytest.mark.parametrize("backend", backends)
def test_chain(size, backend):
    """Contract a matrix chain down to each single index under sharing.

    Builds ``size`` random 2x2 operands laid out as a chain (operand ``i``
    carries indices ``alphabet[i]`` and ``alphabet[i + 1]``). Inside one
    ``shared_intermediates()`` context, the chain is contracted to each of
    the ``size + 1`` indices in turn. Running all contractions in the same
    sharing context is presumably meant to let later ones reuse intermediates
    from earlier ones. Each result is therefore checked against an
    independent ``np.einsum`` reference, so incorrect reuse fails the test.

    NOTE(review): ``size`` is not parametrized by the decorator shown here;
    presumably a ``parametrize("size", ...)`` decorator sits just above this
    one -- confirm.
    """
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = "".join(get_symbol(i) for i in range(size + 1))
    # Neighbouring operands share exactly one index, forming a chain.
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ",".join(names)

    with shared_intermediates():
        print(inputs)
        for target in alphabet:
            eq = "{}->{}".format(inputs, target)
            path_info = contract_path(eq, *xs)
            # Printed output is captured by pytest and shown on failure.
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            result = expr(*xs, backend=backend)
            # np.asarray lets a backend hand back an array-like (e.g. a CPU
            # tensor) and still be compared against the numpy reference.
            np.testing.assert_allclose(np.asarray(result), np.einsum(eq, *xs))
        print("-" * 40)