/*
 * Decompiled with CFR 0.152.
 */
package explicit;

import common.IterableBitSet;
import common.StopWatch;
import common.iterable.FunctionalPrimitiveIterator;
import explicit.IncomingChoiceRelation;
import explicit.MDP;
import explicit.rewards.MDPRewards;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.function.IntPredicate;
import prism.PrismComponent;

/**
 * Computes upper bounds for minimal expected rewards (Rmin) in an MDP using the
 * "Dijkstra Sweep for Monotone Pessimistic Initialization" algorithm.
 * <p>
 * Starting from the target states, the states in {@code unknown} are finalized one at a time in
 * priority order: smallest {@code 1 - p} first, with ties broken by smallest {@code w}. When a
 * state {@code s} is finalized, the values of its currently best choice {@code pi[s]} become
 * {@code pState[s]} and {@code wState[s]}. A scaling factor {@code lambda} is then computed, and
 * the bound reported for {@code s} is {@code wState[s] + lambda * (1 - pState[s])}.
 */
public class DijkstraSweepMPI {
    /** Set to {@code true} to print the computed bound vector to the log. */
    private static final boolean DEBUG = false;

    private final MDP<Double> mdp;
    private final MDPRewards<Double> rewards;
    /** Pending queue entries; may contain stale entries for states that are already finalized. */
    private final PriorityQueue<QueueEntry> queue = new PriorityQueue<>();
    /** Per state: probability value (1.0 for target states, set on finalization for unknown states). */
    private final double[] pState;
    /** Per state: reward value (set on finalization for unknown states). */
    private final double[] wState;
    /** Partially accumulated (p, w) values for every choice of every unknown state. */
    private final HashMap<IncomingChoiceRelation.Choice, ChoiceValues> choiceValues = new HashMap<>();
    /** Per state: best queue entry offered so far, or {@code null} if none yet. */
    private final QueueEntry[] pri;
    /** Per state: index of the choice that produced {@code pri[state]}. */
    private final int[] pi;
    private final BitSet unknown;
    private final BitSet target;
    /** States that have been finalized by the sweep. */
    private final BitSet fin = new BitSet();
    private final IncomingChoiceRelation incoming;
    /** Scaling factor computed by {@link #computeLambda()}. */
    private double lambda;

    /**
     * Initializes the per-choice values, seeds the queue from the target states, runs the sweep
     * and computes {@code lambda}. The results are read by {@link #computeUpperBounds}.
     *
     * @param parent  component passed on to {@code IncomingChoiceRelation.forModel}
     * @param mdp     the model
     * @param rewards state and transition rewards
     * @param target  the target states
     * @param unknown the states for which bounds are computed
     */
    private DijkstraSweepMPI(PrismComponent parent, MDP<Double> mdp, MDPRewards<Double> rewards, BitSet target, BitSet unknown) {
        this.mdp = mdp;
        this.rewards = rewards;
        this.target = target;
        this.unknown = unknown;
        this.incoming = IncomingChoiceRelation.forModel(parent, mdp);

        int numStates = mdp.getNumStates();
        this.pState = new double[numStates];
        this.wState = new double[numStates];
        this.pri = new QueueEntry[numStates];
        this.pi = new int[numStates];

        // Every choice of an unknown state starts with p = 0 and w = its own one-step reward.
        for (int s : IterableBitSet.getSetBits(unknown)) {
            for (int ch = 0, numChoices = mdp.getNumChoices(s); ch < numChoices; ch++) {
                double w = rewards.getStateReward(s);
                w += rewards.getTransitionReward(s, ch);
                choiceValues.put(new IncomingChoiceRelation.Choice(s, ch), new ChoiceValues(0.0, w));
            }
        }

        for (int t : IterableBitSet.getSetBits(target)) {
            pState[t] = 1.0;
        }

        // Seed the queue: each choice leading into the target set is credited at most once with
        // its full one-step probability of moving into the target set. The choice is marked as
        // seen before the other checks, so a skipped choice is not reconsidered later.
        HashSet<IncomingChoiceRelation.Choice> seen = new HashSet<>();
        for (int t : IterableBitSet.getSetBits(target)) {
            for (IncomingChoiceRelation.Choice choice : incoming.getIncomingChoices(t)) {
                if (seen.add(choice) && unknown.get(choice.getState()) && validChoice(choice)) {
                    update(choice, target);
                }
            }
        }

        sweep();
        computeLambda();
    }

    /**
     * Finalizes states in queue order. When a state is polled for the first time, it takes the
     * current values of its best choice {@code pi[s]}. Every valid incoming choice of a
     * not-yet-finalized unknown state is then updated with those values.
     */
    private void sweep() {
        while (!queue.isEmpty()) {
            int s = queue.poll().state;
            if (fin.get(s)) {
                continue; // stale entry: s was already finalized by an earlier entry
            }
            fin.set(s);
            ChoiceValues best = choiceValues.get(new IncomingChoiceRelation.Choice(s, pi[s]));
            wState[s] = best.w;
            pState[s] = best.p;
            for (IncomingChoiceRelation.Choice choice : incoming.getIncomingChoices(s)) {
                int pred = choice.getState();
                if (!fin.get(pred) && unknown.get(pred) && validChoice(choice)) {
                    update(choice, s);
                }
            }
        }
    }

    /**
     * Checks whether {@code choice} qualifies for the sweep: it is valid unless
     * {@code mdp.someSuccessorsMatch} reports a successor that lies neither in {@code unknown}
     * nor in {@code target}.
     */
    private boolean validChoice(IncomingChoiceRelation.Choice choice) {
        IntPredicate outsideRelevant = t -> !unknown.get(t) && !target.get(t);
        return !mdp.someSuccessorsMatch(choice.getState(), choice.getChoice(), outsideRelevant);
    }

    /**
     * Credits {@code choice} with the contribution of the newly finalized successor {@code n}.
     * The contribution is the transition probability to {@code n} times {@code wState[n]}
     * (respectively {@code pState[n]}). The updated values are then offered to the queue.
     */
    private void update(IncomingChoiceRelation.Choice choice, int n) {
        final double wN = wState[n];
        final double pN = pState[n];
        int s = choice.getState();
        int ch = choice.getChoice();
        double wAdd = mdp.sumOverDoubleTransitions(s, ch, (src, t, prob) -> t == n ? prob * wN : 0.0);
        double pAdd = mdp.sumOverDoubleTransitions(s, ch, (src, t, prob) -> t == n ? prob * pN : 0.0);
        ChoiceValues values = choiceValues.get(choice);
        assert values != null;
        values.p += pAdd;
        values.w += wAdd;
        offer(choice, values);
    }

    /**
     * Credits {@code choice} with its one-step probability of moving into {@code targets}, then
     * offers the updated values to the queue.
     */
    private void update(IncomingChoiceRelation.Choice choice, BitSet targets) {
        double pAdd = mdp.sumOverDoubleTransitions(choice.getState(), choice.getChoice(), (src, t, prob) -> targets.get(t) ? prob : 0.0);
        ChoiceValues values = choiceValues.get(choice);
        assert values != null;
        values.p += pAdd;
        offer(choice, values);
    }

    /**
     * Records {@code choice} as the best choice of its state if the entry built from
     * {@code values} is strictly smaller than the best entry offered so far, and enqueues it.
     */
    private void offer(IncomingChoiceRelation.Choice choice, ChoiceValues values) {
        int s = choice.getState();
        QueueEntry entry = new QueueEntry(s, 1.0 - values.p, values.w);
        if (pri[s] == null || entry.compareTo(pri[s]) < 0) {
            pri[s] = entry;
            pi[s] = choice.getChoice();
            queue.add(entry);
        }
    }

    /**
     * Computes and stores {@code lambda} as the maximum (starting from 0) over all unknown states
     * {@code s}, with {@code a = pi[s]}, of
     * {@code (r(s) + r(s,a) + sum_t P(s,a,t)*wState[t] - wState[s]) / (sum_t P(s,a,t)*pState[t] - pState[s])}.
     * The quotient is only used where {@code pState[s] < sum_t P(s,a,t)*pState[t]}; other states
     * contribute 0.
     * <p>
     * NOTE(review): an unknown state that was never finalized keeps {@code pi[s] = 0} and
     * {@code pState[s] = wState[s] = 0}. Presumably callers guarantee that every unknown state is
     * reached by the sweep; confirm.
     *
     * @return the computed {@code lambda}
     */
    private double computeLambda() {
        lambda = 0.0;
        for (int s : IterableBitSet.getSetBits(unknown)) {
            int ch = pi[s];
            double succP = mdp.sumOverDoubleTransitions(s, ch, (src, t, prob) -> prob * pState[t]);
            double candidate = 0.0;
            if (pState[s] < succP) {
                double succW = mdp.sumOverDoubleTransitions(s, ch, (src, t, prob) -> prob * wState[t]);
                double numerator = rewards.getStateReward(s) + rewards.getTransitionReward(s, ch);
                numerator += succW;
                numerator -= wState[s];
                // succP is reused: the original recomputed exactly the same sum here.
                double denominator = succP - pState[s];
                candidate = numerator / denominator;
            }
            lambda = Double.max(lambda, candidate);
        }
        return lambda;
    }

    /**
     * Computes an upper bound for Rmin for every state in {@code unknown}.
     *
     * @param parent  component whose log receives progress and timing output
     * @param mdp     the model
     * @param rewards state and transition rewards
     * @param target  the target states
     * @param unknown the states for which bounds are computed
     * @return array indexed by state; for {@code s} in {@code unknown} the entry is
     *         {@code wState[s] + lambda * (1 - pState[s])}, all other entries are 0.0
     */
    public static double[] computeUpperBounds(PrismComponent parent, MDP<Double> mdp, MDPRewards<Double> rewards, BitSet target, BitSet unknown) {
        StopWatch timer = new StopWatch(parent.getLog());
        timer.start("computing upper bound(s) for Rmin using the DSI-MP algorithm");
        parent.getLog().println("Computing upper bound(s) for Rmin using the Dijkstra Sweep for Monotone Pessimistic Initialization algorithm...");
        double[] bounds = new double[mdp.getNumStates()];
        DijkstraSweepMPI sweep = new DijkstraSweepMPI(parent, mdp, rewards, target, unknown);
        for (int s : IterableBitSet.getSetBits(unknown)) {
            bounds[s] = sweep.wState[s] + sweep.lambda * (1.0 - sweep.pState[s]);
        }
        if (DEBUG) {
            parent.getLog().println(bounds);
        }
        timer.stop();
        return bounds;
    }

    /**
     * Computes a single upper bound for Rmin.
     *
     * @return the maximum of {@link #computeUpperBounds} over {@code unknown}, or 0.0 if
     *         {@code unknown} is empty
     */
    public static double computeUpperBound(PrismComponent parent, MDP<Double> mdp, MDPRewards<Double> rewards, BitSet target, BitSet unknown) {
        double[] bounds = computeUpperBounds(parent, mdp, rewards, target, unknown);
        double max = 0.0;
        for (int s : IterableBitSet.getSetBits(unknown)) {
            max = Double.max(max, bounds[s]);
        }
        return max;
    }

    /** Queue entry for a state, ordered by {@code p} (here {@code 1 - probability}) and then by {@code w}. */
    private static final class QueueEntry implements Comparable<QueueEntry> {
        final int state;
        final double p;
        final double w;

        QueueEntry(int state, double p, double w) {
            this.state = state;
            this.p = p;
            this.w = w;
        }

        @Override
        public int compareTo(QueueEntry other) {
            int byP = Double.compare(p, other.p);
            return byP != 0 ? byP : Double.compare(w, other.w);
        }
    }

    /** Mutable accumulator of the probability {@code p} and reward {@code w} of one choice. */
    private static final class ChoiceValues {
        double p;
        double w;

        ChoiceValues(double p, double w) {
            this.p = p;
            this.w = w;
        }
    }
}

