/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.lsat.common.ludus.backend.games.ratio.solvers.policy;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import org.eclipse.lsat.common.ludus.backend.datastructures.tuple.Quadruple;
import org.eclipse.lsat.common.ludus.backend.datastructures.tuple.Triple;
import org.eclipse.lsat.common.ludus.backend.datastructures.tuple.Tuple;
import org.eclipse.lsat.common.ludus.backend.games.StrategyVector;
import org.eclipse.lsat.common.ludus.backend.games.algorithms.DoubleFunctions;
import org.eclipse.lsat.common.ludus.backend.games.algorithms.GraphChecks;
import org.eclipse.lsat.common.ludus.backend.games.ratio.solvers.policy.RatioGamePolicyIteration;

/**
 * Policy-iteration solver for two-player ratio games whose edge weights are {@code Double}s.
 *
 * <p>Under a fixed positional strategy every vertex follows its successor chain into a cycle.
 * The ratio of that cycle is the sum of the first weights on the cycle divided by the sum of the
 * second weights. Each vertex additionally receives a distance (potential) measured with the
 * reweighted edge cost {@code w1 - ratio * w2}.
 *
 * <p>The two players improve their strategies as follows:
 * <ul>
 *   <li>Player 0 (vertices in {@code game.getV0()}) switches to a successor that offers a strictly
 *       larger ratio, or an equal ratio and a strictly larger distance.</li>
 *   <li>Player 1 (vertices in {@code game.getV1()}) answers every player-0 strategy by repeatedly
 *       switching to a successor that offers a strictly smaller ratio, or an equal ratio and a
 *       strictly smaller distance.</li>
 * </ul>
 * The outer loop stops when no player-0 vertex can switch any more.
 *
 * <p>Ratio comparisons use the tolerance {@code epsilon}; distance comparisons use {@code delta}.
 * Both are passed straight to {@link DoubleFunctions}.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public final class PolicyIterationDoubleVars {
    /** Initial ratio of every vertex before the first evaluation. */
    private static final double MINUS_INFTY = Double.NEGATIVE_INFINITY;

    /** Default tolerance for ratio comparisons used by {@link #solve(RatioGamePolicyIteration)}. */
    private static final double DEFAULT_EPSILON = 1.0E-4;

    /** Default tolerance for distance comparisons used by {@link #solve(RatioGamePolicyIteration)}. */
    private static final double DEFAULT_DELTA = 1.0E-4;

    /** Message printed when the input graph has a vertex without outgoing edge. */
    private static final String INVALID_GRAPH_MESSAGE =
            "Input game graph is not valid. Not every vertex has a successor.";

    private PolicyIterationDoubleVars() {
    }

    /**
     * Solves the game with default tolerances ({@code 1.0E-4} for both ratios and distances),
     * starting from a strategy produced by {@code StrategyVector#initializeRandomStrategy}.
     *
     * @param game the ratio game to solve
     * @return a tuple of (ratio per vertex, final strategy), or {@code null} if some vertex has no
     *     successor (a message is then printed to standard output)
     */
    public static <V, E> Tuple<Map<V, Double>, StrategyVector<V, E>> solve(
            RatioGamePolicyIteration<V, E, Double> game) {
        return solve(game, DEFAULT_EPSILON, DEFAULT_DELTA);
    }

    /**
     * Solves the game starting from a strategy produced by
     * {@code StrategyVector#initializeRandomStrategy}.
     *
     * @param game the ratio game to solve
     * @param epsilon tolerance for comparing ratios
     * @param delta tolerance for comparing distances
     * @return a tuple of (ratio per vertex, final strategy), or {@code null} if some vertex has no
     *     successor (a message is then printed to standard output)
     */
    public static <V, E> Tuple<Map<V, Double>, StrategyVector<V, E>> solve(
            RatioGamePolicyIteration<V, E, Double> game, Double epsilon, Double delta) {
        if (!GraphChecks.checkEachNodeHasSuccessor(game)) {
            System.out.println(INVALID_GRAPH_MESSAGE);
            return null;
        }
        StrategyVector<V, E> initialStrategy = new StrategyVector<>();
        initialStrategy.initializeRandomStrategy(game);
        return policyIteration(game, initialStrategy, epsilon, delta);
    }

    /**
     * Solves the game starting from the given strategy.
     *
     * <p>The iteration works on a copy made via the {@code StrategyVector} copy constructor.
     * NOTE(review): {@code initialStrategy} therefore appears to stay unmodified, assuming that
     * constructor copies the successor mapping — confirm.
     *
     * @param game the ratio game to solve
     * @param initialStrategy strategy to start from
     * @param epsilon tolerance for comparing ratios
     * @param delta tolerance for comparing distances
     * @return a tuple of (ratio per vertex, final strategy), or {@code null} if some vertex has no
     *     successor (a message is then printed to standard output)
     */
    public static <V, E> Tuple<Map<V, Double>, StrategyVector<V, E>> solve(
            RatioGamePolicyIteration<V, E, Double> game, StrategyVector<V, E> initialStrategy,
            Double epsilon, Double delta) {
        if (!GraphChecks.checkEachNodeHasSuccessor(game)) {
            System.out.println(INVALID_GRAPH_MESSAGE);
            return null;
        }
        return policyIteration(game, initialStrategy, epsilon, delta);
    }

    /**
     * Outer policy-iteration loop.
     *
     * <p>Each round lets player 1 respond optimally to the current strategy, then switches every
     * player-0 vertex that has an improving edge. It stops when a round produces no player-0
     * switch. There is no iteration cap; termination relies on the tolerances.
     */
    private static <V, E> Tuple<Map<V, Double>, StrategyVector<V, E>> policyIteration(
            RatioGamePolicyIteration<V, E, Double> game, StrategyVector<V, E> initialStrategy,
            Double epsilon, Double delta) {
        Set<V> vertices = game.getVertices();
        Map<V, Double> distVector = initializeVector(vertices, Double.POSITIVE_INFINITY);
        // dw2 (accumulated second weights) is only threaded through; no decision in this class reads it.
        Map<V, Double> dw2vector = initializeVector(vertices, Double.POSITIVE_INFINITY);
        Map<V, Double> ratioVector = initializeVector(vertices, MINUS_INFTY);
        StrategyVector<V, E> currentStrategy = initialStrategy;

        boolean improvement = true;
        while (improvement) {
            improvement = false;
            Quadruple<Map<V, Double>, Map<V, Double>, StrategyVector<V, E>, Map<V, Double>> result =
                    improveStrategyPlayer1(game, currentStrategy, distVector, ratioVector, dw2vector,
                            epsilon, delta);
            distVector = result.getLeft();
            ratioVector = result.getMiddleLeft();
            currentStrategy = result.getMiddleRight();
            dw2vector = result.getRight();

            for (V v : game.getV0()) {
                for (E e : game.outgoingEdgesOf(v)) {
                    V u = game.getEdgeTarget(e);
                    double ratioU = ratioVector.get(u);
                    double w1 = game.getWeight1(e);
                    double w2 = game.getWeight2(e);
                    double reweighted = w1 - ratioU * w2;
                    Double distV = distVector.get(v);
                    Double distViaU = distVector.get(u) + reweighted;
                    // If several edges improve, the last improving edge visited wins.
                    if (player0PrefersSuccessor(ratioVector.get(v), ratioVector.get(u), distV,
                            distViaU, epsilon, delta)) {
                        currentStrategy.setSuccessor(v, u);
                        improvement = true;
                    }
                }
            }
        }
        return Tuple.of(ratioVector, currentStrategy);
    }

    /**
     * Player 0 prefers successor u over its current choice when u yields a strictly larger ratio,
     * or an equal ratio (within epsilon) and a strictly larger distance (within delta).
     */
    private static boolean player0PrefersSuccessor(Double ratioV, Double ratioU, Double distV,
            Double distViaU, Double epsilon, Double delta) {
        return DoubleFunctions.lessThan(ratioV, ratioU, epsilon)
                || (DoubleFunctions.equalTo(ratioV, ratioU, epsilon)
                        && DoubleFunctions.lessThan(distV, distViaU, delta));
    }

    /**
     * Player 1 prefers successor u over its current choice when u yields a strictly smaller ratio,
     * or an equal ratio (within epsilon) and a strictly smaller distance (within delta).
     */
    private static boolean player1PrefersSuccessor(Double ratioV, Double ratioU, Double distV,
            Double distViaU, Double epsilon, Double delta) {
        return DoubleFunctions.greaterThan(ratioV, ratioU, epsilon)
                || (DoubleFunctions.equalTo(ratioV, ratioU, epsilon)
                        && DoubleFunctions.greaterThan(distV, distViaU, delta));
    }

    /**
     * Inner loop: keeps player 0's choices fixed and lets player 1 switch until no player-1 vertex
     * improves.
     *
     * <p>NOTE(review): every inner evaluation is seeded with the values passed into this call
     * ({@code prevDistances}, {@code prevRatios}, {@code prevDw2}), not with the latest inner
     * evaluation. This preserves the original behavior; confirm it is intended.
     *
     * @return (distances, ratios, strategy, accumulated second weights) after player 1 converged;
     *     the strategy is a copy of {@code currentStrategy}
     */
    private static <V, E> Quadruple<Map<V, Double>, Map<V, Double>, StrategyVector<V, E>, Map<V, Double>>
            improveStrategyPlayer1(RatioGamePolicyIteration<V, E, Double> game,
                    StrategyVector<V, E> currentStrategy, Map<V, Double> prevDistances,
                    Map<V, Double> prevRatios, Map<V, Double> prevDw2, Double epsilon, Double delta) {
        StrategyVector<V, E> strategy = new StrategyVector<>(currentStrategy);
        Map<V, Double> distances;
        Map<V, Double> ratios;
        Map<V, Double> dw2;
        boolean improvement;
        do {
            improvement = false;
            Triple<Map<V, Double>, Map<V, Double>, Map<V, Double>> evaluation =
                    evaluateStrategy(game, strategy, prevDistances, prevRatios, prevDw2, epsilon);
            distances = evaluation.getLeft();
            ratios = evaluation.getMiddle();
            dw2 = evaluation.getRight();

            for (V v : game.getV1()) {
                for (E e : game.outgoingEdgesOf(v)) {
                    V u = game.getEdgeTarget(e);
                    double ratioU = ratios.get(u);
                    double w1 = game.getWeight1(e);
                    double w2 = game.getWeight2(e);
                    double reweighted = w1 - ratioU * w2;
                    Double distV = distances.get(v);
                    Double distViaU = distances.get(u) + reweighted;
                    if (player1PrefersSuccessor(ratios.get(v), ratios.get(u), distV, distViaU,
                            epsilon, delta)) {
                        strategy.setSuccessor(v, u);
                        improvement = true;
                    }
                }
            }
        } while (improvement);
        return Quadruple.of(distances, ratios, strategy, dw2);
    }

    /**
     * Evaluates a strategy: finds the cycles of the restricted graph (one outgoing edge per
     * vertex), then propagates ratios and distances from those cycles to all vertices.
     *
     * @return (distances, ratios, accumulated second weights) per vertex
     */
    private static <V, E> Triple<Map<V, Double>, Map<V, Double>, Map<V, Double>> evaluateStrategy(
            RatioGamePolicyIteration<V, E, Double> game, StrategyVector<V, E> strategy,
            Map<V, Double> distanceVector, Map<V, Double> ratioVector, Map<V, Double> dw2,
            Double epsilon) {
        Tuple<Set<V>, Map<V, Double>> cycles = findCyclesInRestrictedGraph(game, strategy);
        return computeDistances(game, strategy, cycles.getLeft(), cycles.getRight(), distanceVector,
                ratioVector, dw2, epsilon);
    }

    /**
     * Finds every cycle of the graph restricted to the strategy's successor edges.
     *
     * <p>For each cycle, one representative is selected: the cycle vertex with the smallest
     * {@code game.getId}. Its ratio is {@code sum(w1) / sum(w2)} over the cycle edges.
     * NOTE(review): a cycle whose {@code w2} sum is 0 yields an infinite or NaN ratio.
     *
     * @return (selected representatives, ratio per representative)
     */
    private static <V, E> Tuple<Set<V>, Map<V, Double>> findCyclesInRestrictedGraph(
            RatioGamePolicyIteration<V, E, Double> game, StrategyVector<V, E> currentStrategy) {
        Set<V> selectedVertices = new HashSet<>();
        Map<V, Double> cycleRatios = new HashMap<>();
        // owner.get(x) is the start vertex of the walk that first reached x; null means unvisited.
        Map<V, V> owner = initializeVector(game.getVertices(), null);

        for (V start : game.getVertices()) {
            if (owner.get(start) != null) {
                continue;
            }
            V u = start;
            while (owner.get(u) == null) {
                owner.put(u, start);
                u = currentStrategy.getSuccessor(u);
            }
            // Identity comparison: did this walk close on itself (new cycle), or run into an
            // earlier walk whose cycle is already recorded?
            if (owner.get(u) != start) {
                continue;
            }

            V representative = u;
            V x = currentStrategy.getSuccessor(u);
            E firstEdge = game.getEdge(u, x);
            double w1Sum = game.getWeight1(firstEdge);
            double w2Sum = game.getWeight2(firstEdge);
            while (x != u) {
                if (game.getId(x) < game.getId(representative)) {
                    representative = x;
                }
                V next = currentStrategy.getSuccessor(x);
                E edge = game.getEdge(x, next);
                w1Sum += game.getWeight1(edge);
                w2Sum += game.getWeight2(edge);
                x = next;
            }
            cycleRatios.put(representative, w1Sum / w2Sum);
            selectedVertices.add(representative);
        }
        return Tuple.of(selectedVertices, cycleRatios);
    }

    /**
     * Propagates ratios and distances backwards along the strategy's successor edges, starting
     * from the selected cycle representatives.
     *
     * <p>A representative keeps its previous distance and dw2 when its ratio is unchanged (within
     * epsilon). Otherwise both are reset to 0. Every other vertex x with successor u gets:
     * <ul>
     *   <li>{@code ratio(x) = ratio(u)}</li>
     *   <li>{@code dist(x) = dist(u) + w1 - ratio(u) * w2}</li>
     *   <li>{@code dw2(x) = dw2(u) + w2}</li>
     * </ul>
     *
     * <p>Side effect: {@code r_i_t} is filled in place with the ratio of every vertex and returned.
     *
     * @return (distances, ratios, accumulated second weights) per vertex
     */
    private static <V, E> Triple<Map<V, Double>, Map<V, Double>, Map<V, Double>> computeDistances(
            RatioGamePolicyIteration<V, E, Double> game, StrategyVector<V, E> currentStrategy,
            Collection<V> selectedVertices, Map<V, Double> r_i_t, Map<V, Double> d_prev,
            Map<V, Double> r_prev, Map<V, Double> dw2_prev, Double epsilon) {
        Map<V, Boolean> visited = initializeVector(game.getVertices(), false);
        Map<V, Double> distances = new HashMap<>();
        Map<V, Double> dw2 = new HashMap<>();

        for (V s : selectedVertices) {
            if (DoubleFunctions.equalTo(r_i_t.get(s), r_prev.get(s), epsilon)) {
                distances.put(s, d_prev.get(s));
                dw2.put(s, dw2_prev.get(s));
            } else {
                distances.put(s, 0.0);
                dw2.put(s, 0.0);
            }
            visited.put(s, true);
        }

        Deque<V> stack = new ArrayDeque<>();
        for (V v : game.getVertices()) {
            if (visited.get(v)) {
                continue;
            }
            // Walk forward until a vertex with known values is reached...
            V u = v;
            while (!visited.get(u)) {
                visited.put(u, true);
                stack.push(u);
                u = currentStrategy.getSuccessor(u);
            }
            // ...then unwind, deriving each vertex's values from its successor.
            while (!stack.isEmpty()) {
                V x = stack.pop();
                E e = game.getEdge(x, u);
                double w1 = game.getWeight1(e);
                double w2 = game.getWeight2(e);
                double cycleRatio = r_i_t.get(u);
                double reweighted = w1 - cycleRatio * w2;
                r_i_t.put(x, cycleRatio);
                distances.put(x, distances.get(u) + reweighted);
                dw2.put(x, dw2.get(u) + w2);
                u = x;
            }
        }
        return Triple.of(distances, r_i_t, dw2);
    }

    /** Returns a new mutable map assigning {@code value} to every element of {@code vertices}. */
    private static <V, T> Map<V, T> initializeVector(Collection<V> vertices, T value) {
        Map<V, T> vector = new HashMap<>();
        for (V v : vertices) {
            vector.put(v, value);
        }
        return vector;
    }

    /** Debug helper: formats a vertex-valued map as {@code "id: value, id: value, "}. */
    private static <V, E> String printVector(RatioGamePolicyIteration<V, E, Double> game,
            Map<V, Double> values) {
        StringBuilder output = new StringBuilder();
        for (Map.Entry<V, Double> entry : values.entrySet()) {
            output.append(game.getId(entry.getKey())).append(": ").append(entry.getValue()).append(", ");
        }
        return output.toString();
    }

    /** Debug helper: formats a strategy as {@code "{id-->succId, ...}"}. */
    private static <V, E> String printStrategy(RatioGamePolicyIteration<V, E, Double> game,
            StrategyVector<V, E> strategy) {
        return strategy.getVertices().stream()
                .map(v -> game.getId(v) + "-->" + game.getId(strategy.getSuccessor(v)))
                .collect(Collectors.joining(", ", "{", "}"));
    }
}

