24ef297cca3a32b6f73d14d865cee120f97674c5,torchdiffeq/_impl/rk_common.py,,_runge_kutta_step,#,33

Before Change


    k = _UncheckedAssign.apply(k, f0, (..., 0))
    for i, (alpha_i, beta_i) in enumerate(zip(tableau.alpha, tableau.beta)):
        ti = t0 + alpha_i * dt
        yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0).type_as(y0)  // tableau is float 64 so cast back
        f = func(ti, yi)
        k = _UncheckedAssign.apply(k, f, (..., i + 1))

    if not (tableau.c_sol[-1] == 0 and (tableau.c_sol[:-1] == tableau.beta[-1]).all()):
        // This property (true for Dormand-Prince) lets us save a few FLOPs.
        yi = y0 + k.matmul(dt * tableau.c_sol).view_as(f0).type_as(y0)  // tableau is float 64 so cast back

    y1 = yi
    f1 = k[..., -1]
    y1_error = k.matmul(dt * tableau.c_error)
    return y1, f1, y1_error, k

After Change


    

    t0 = t0.type_as(y0)
    dt = dt.type_as(y0)

    // We use an unchecked assign to put data into k without incrementing its _version counter, so that the backward
    // doesn't throw an (overzealous) error about in-place correctness. We know that it's actually correct.
    k = torch.empty(*f0.shape, len(tableau.alpha) + 1, dtype=y0.dtype, device=y0.device)
    k = _UncheckedAssign.apply(k, f0, (..., 0))
    for i, (alpha_i, beta_i) in enumerate(zip(tableau.alpha, tableau.beta)):
        ti = t0 + alpha_i * dt
        yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0)
        f = func(ti, yi)
        k = _UncheckedAssign.apply(k, f, (..., i + 1))

    if not (tableau.c_sol[-1] == 0 and (tableau.c_sol[:-1] == tableau.beta[-1]).all()):
        // This property (true for Dormand-Prince) lets us save a few FLOPs.
        yi = y0 + k.matmul(dt * tableau.c_sol).view_as(f0)

    y1 = yi
    f1 = k[..., -1]
    y1_error = k.matmul(dt * tableau.c_error)
    return y1, f1, y1_error, k

In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: rtqichen/torchdiffeq
Commit Name: 24ef297cca3a32b6f73d14d865cee120f97674c5
Time: 2020-08-03
Author: 33688385+patrick-kidger@users.noreply.github.com
File Name: torchdiffeq/_impl/rk_common.py
Class Name:
Method Name: _runge_kutta_step


Project Name: rtqichen/torchdiffeq
Commit Name: 0a7f9083918008145ee254b92228c8ba7c2f9c56
Time: 2020-08-03
Author: 33688385+patrick-kidger@users.noreply.github.com
File Name: torchdiffeq/_impl/rk_common.py
Class Name:
Method Name: _runge_kutta_step