dc6ae9f6e3858510a5df96592fac3a3a17590b2a,mpc/pnqp.py,,pnqp,#,5

Before Change


            x_init = -(1./H.squeeze(2))*q
        else:
            H_lu = H.lu()
            x_init = -q.lu_solve(*H_lu) // Clamped in the x assignment.
    else:
        x_init = x_init.clone() // Don"t over-write the original x_init.

    x = util.eclamp(x_init, lower, upper)

    // Active examples in the batch.
    J = torch.ones(n_batch).type_as(x).byte()

    for i in range(n_iter):
        g = util.bmv(H, x) + q
        Ic = ((x == lower) & (g > 0)) | ((x == upper) & (g < 0))
        If = 1-Ic

        if If.is_cuda:
            Hff_I = util.bger(If.float(), If.float()).type_as(If)
            not_Hff_I = 1-Hff_I
            Hfc_I = util.bger(If.float(), Ic.float()).type_as(If)
        else:
            Hff_I = util.bger(If, If)
            not_Hff_I = 1-Hff_I
            Hfc_I = util.bger(If, Ic)

        g_ = g.clone()
        g_[Ic] = 0.
        H_ = H.clone()
        H_[not_Hff_I] = 0.0
        H_ += pnqp_I

        if n == 1:
            dx = -(1./H_.squeeze(2))*g_
        else:
            H_lu_ = H_.lu()
            dx = -g_.lu_solve(*H_lu_)

        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item() // Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i

        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            // Crude way of making sure too much time isn"t being spent
            // doing the line search.
            // assert count < 10

            maybe_x = util.eclamp(x+torch.diag(alpha).mm(dx), lower, upper)
            armijos = (GAMMA+1e-6)*torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x)-obj(maybe_x))[J]/util.bdot(g, x-maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1

        x = maybe_x

    // TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")

After Change


            x_init = -(1./H.squeeze(2))*q
        else:
            H_lu = H.lu()
            x_init = -q.unsqueeze(2).lu_solve(*H_lu).squeeze(2) // Clamped in the x assignment.
    else:
        x_init = x_init.clone() // Don"t over-write the original x_init.

    x = util.eclamp(x_init, lower, upper)

    // Active examples in the batch.
    J = torch.ones(n_batch).type_as(x).byte()

    for i in range(n_iter):
        g = util.bmv(H, x) + q

        // TODO: Could clean up the types here.
        Ic = (((x == lower) & (g > 0)) | ((x == upper) & (g < 0))).float()
        If = 1-Ic

        if If.is_cuda:
            Hff_I = util.bger(If.float(), If.float()).type_as(If)
            not_Hff_I = 1-Hff_I
            Hfc_I = util.bger(If.float(), Ic.float()).type_as(If)
        else:
            Hff_I = util.bger(If, If)
            not_Hff_I = 1-Hff_I
            Hfc_I = util.bger(If, Ic)

        g_ = g.clone()
        g_[Ic.bool()] = 0.
        H_ = H.clone()
        H_[not_Hff_I.bool()] = 0.0
        H_ += pnqp_I

        if n == 1:
            dx = -(1./H_.squeeze(2))*g_
        else:
            H_lu_ = H_.lu()
            dx = -g_.unsqueeze(2).lu_solve(*H_lu_).squeeze(2)

        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item() // Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i

        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            // Crude way of making sure too much time isn"t being spent
            // doing the line search.
            // assert count < 10

            maybe_x = util.eclamp(x+torch.diag(alpha).mm(dx), lower, upper)
            armijos = (GAMMA+1e-6)*torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x)-obj(maybe_x))[J]/util.bdot(g, x-maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1

        x = maybe_x

    // TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 4

Instances


Project Name: locuslab/mpc.pytorch
Commit Name: dc6ae9f6e3858510a5df96592fac3a3a17590b2a
Time: 2019-09-07
Author: brandon.amos.cs@gmail.com
File Name: mpc/pnqp.py
Class Name:
Method Name: pnqp


Project Name: locuslab/mpc.pytorch
Commit Name: dc6ae9f6e3858510a5df96592fac3a3a17590b2a
Time: 2019-09-07
Author: brandon.amos.cs@gmail.com
File Name: mpc/lqr_step.py
Class Name: LQRStep
Method Name: lqr_backward