/*
 * Decompiled with CFR 0.152.
 */
package symbolic.build;

import common.SafeCast;
import explicit.CTMCSimple;
import explicit.DTMC;
import explicit.DTMCSimple;
import explicit.Distribution;
import explicit.MDP;
import explicit.MDPSimple;
import explicit.Model;
import explicit.ModelExplicit;
import explicit.SuccessorsIterator;
import explicit.rewards.Rewards;
import explicit.rewards.Rewards2RewardGenerator;
import explicit.rewards.RewardsExplicit;
import explicit.rewards.RewardsSimple;
import io.IOUtils;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.function.IntConsumer;
import jdd.JDD;
import jdd.JDDNode;
import jdd.JDDVars;
import odd.ODDNode;
import prism.Evaluator;
import prism.Pair;
import prism.PrismComponent;
import prism.PrismException;
import prism.RewardGenerator;
import prism.RewardInfo;
import symbolic.model.NondetModel;
import symbolic.model.ProbModel;
import symbolic.states.StateListMTBDD;

/**
 * Converts symbolically (MTBDD) represented models, and their reward
 * structures, into explicit-state representations.
 *
 * <p>DTMCs, CTMCs and MDPs are supported; any other model type is rejected
 * with a {@link PrismException}. State indices of the explicit model are
 * obtained by walking the symbolic model's ODD in lock-step with the decision
 * diagram being traversed.
 */
public class MTBDD2ExplicitModel extends PrismComponent {
    /** Creates a converter that is not attached to a parent component. */
    public MTBDD2ExplicitModel() {
    }

    /**
     * Creates a converter attached to a parent component.
     *
     * @param prismComponent the parent component, passed to the superclass constructor
     */
    public MTBDD2ExplicitModel(PrismComponent prismComponent) {
        super(prismComponent);
    }

    /**
     * Builds an explicit model equivalent to the given symbolic model.
     *
     * <p>Transitions (with actions), initial states, labels, deadlock states
     * and the list of states are all transferred.
     *
     * @param model the symbolic model to convert
     * @return the newly built explicit model
     * @throws PrismException if the model type is not DTMC, CTMC or MDP, or if
     *         a traversal fails
     */
    public Model<Double> convertModel(symbolic.model.Model model) throws PrismException {
        int numStates = model.getNumStates();
        ModelExplicit<Double> explicitModel;
        switch (model.getModelType()) {
            case DTMC: {
                DTMCSimple<Double> dtmc = new DTMCSimple<>(numStates);
                explicitModel = dtmc;
                // Actions are installed before any transition is added
                explicitModel.setActions(model.getActions());
                this.convertMarkovChainTransitions((ProbModel) model, dtmc);
                break;
            }
            case CTMC: {
                CTMCSimple<Double> ctmc = new CTMCSimple<>(numStates);
                explicitModel = ctmc;
                explicitModel.setActions(model.getActions());
                this.convertMarkovChainTransitions((ProbModel) model, ctmc);
                break;
            }
            case MDP: {
                MDPSimple<Double> mdp = new MDPSimple<>(numStates);
                explicitModel = mdp;
                explicitModel.setActions(model.getActions());
                this.convertMDPTransitions((NondetModel) model, mdp);
                break;
            }
            default: {
                throw new PrismException("Can't do symbolic-explicit conversion for " + model.getModelType() + "s");
            }
        }

        // Initial states
        this.traverseStatesBDD(model.getStart(), model.getAllDDRowVars(), model.getODD(), explicitModel::addInitialState);

        // Labels: collect the satisfying state indices of each label into a BitSet
        for (String labelName : model.getLabels()) {
            BitSet labelStates = new BitSet();
            this.traverseStatesBDD(model.getLabelDD(labelName), model.getAllDDRowVars(), model.getODD(), labelStates::set);
            explicitModel.addLabel(labelName, labelStates);
        }

        // Deadlocks (only if the symbolic model provides them), added in ascending index order
        if (model.getDeadlocks() != null) {
            BitSet deadlockStates = new BitSet();
            this.traverseStatesBDD(model.getDeadlocks(), model.getAllDDRowVars(), model.getODD(), deadlockStates::set);
            for (int state = deadlockStates.nextSetBit(0); state >= 0; state = deadlockStates.nextSetBit(state + 1)) {
                explicitModel.addDeadlockState(state);
            }
        }

        // State list, taken from the reachable-states list of the symbolic model
        StateListMTBDD reachableStates = (StateListMTBDD) model.getReachableStates();
        explicitModel.setStatesList(reachableStates.getAsListOfStates());
        return explicitModel;
    }

    /**
     * Returns a reward generator whose reward objects are produced by
     * {@link #convertRewards} from the symbolic model's rewards.
     *
     * @param model the symbolic model holding the reward MTBDDs
     * @param model2 the explicit model that was converted from {@code model}
     * @param rewardInfo information about the available reward structures
     * @return a reward generator backed by the converted rewards
     * @throws PrismException declared for callers; not thrown directly here
     */
    public RewardGenerator<Double> getRewardConverter(final symbolic.model.Model model, final Model<Double> model2, RewardInfo rewardInfo) throws PrismException {
        return new Rewards2RewardGenerator<Double>(rewardInfo, model2, Evaluator.forDouble()) {

            @Override
            public Rewards<Double> getTheRewardObject(int n) throws PrismException {
                return MTBDD2ExplicitModel.this.convertRewards(model, model2, n, this.rewardInfo);
            }
        };
    }

    /**
     * Converts one reward structure of a symbolic model to explicit rewards.
     *
     * <p>State rewards and transition rewards are only read from the symbolic
     * model if {@code rewardInfo} reports that the structure has them.
     *
     * @param model the symbolic model holding the reward MTBDDs
     * @param model2 the explicit model that was converted from {@code model}
     * @param n index of the reward structure
     * @param rewardInfo information about the available reward structures
     * @return the explicit rewards
     * @throws PrismException if transition rewards are requested for a model
     *         type other than DTMC, CTMC or MDP, or if a traversal fails
     */
    public Rewards<Double> convertRewards(symbolic.model.Model model, Model<Double> model2, int n, RewardInfo rewardInfo) throws PrismException {
        RewardsSimple<Double> rewards = new RewardsSimple<>(model.getNumStates());
        if (rewardInfo.rewardStructHasStateRewards(n)) {
            JDDNode stateRewards = model.getStateRewards(n);
            this.traverseVectorDD(stateRewards, model.getAllDDRowVars(), model.getODD(), rewards::addToStateReward);
        }
        if (rewardInfo.rewardStructHasTransitionRewards(n)) {
            JDDNode transRewards = model.getTransRewards(n);
            switch (model.getModelType()) {
                case DTMC:
                case CTMC: {
                    this.convertMarkovChainTransitionRewards((ProbModel) model, (DTMC<Double>) model2, transRewards, rewards);
                    break;
                }
                case MDP: {
                    this.convertMDPTransitionRewards((NondetModel) model, (MDP<Double>) model2, transRewards, rewards);
                    break;
                }
                default: {
                    throw new PrismException("Can't do symbolic-explicit reward conversion for " + model.getModelType() + "s");
                }
            }
        }
        return rewards;
    }

    /**
     * Copies the transitions of a symbolic Markov chain into an explicit one.
     *
     * <p>If per-action transition matrices are available, each is traversed
     * separately and its transitions are tagged with the corresponding action;
     * otherwise the single transition matrix is traversed (the traversal
     * supplies {@code null} as the action in that case).
     */
    private void convertMarkovChainTransitions(ProbModel symbolicModel, DTMCSimple<Double> explicitChain) throws PrismException {
        List<Object> actions = symbolicModel.getActions();
        JDDNode[] transPerAction = symbolicModel.getTransPerAction();
        if (transPerAction == null) {
            this.traverseMatrixDD(symbolicModel.getTrans(), symbolicModel.getAllDDRowVars(), symbolicModel.getAllDDColVars(), symbolicModel.getODD(), explicitChain::addToProbability);
            return;
        }
        int numMatrices = transPerAction.length;
        for (int i = 0; i < numMatrices; ++i) {
            // The i-th matrix is labelled with the i-th action
            Object action = actions.get(i);
            this.traverseMatrixDD(transPerAction[i], symbolicModel.getAllDDRowVars(), symbolicModel.getAllDDColVars(), symbolicModel.getODD(),
                    (source, target, prob, unused) -> explicitChain.addToProbability(source, target, prob, action));
        }
    }

    /**
     * Copies the choices of a symbolic MDP into an explicit one.
     *
     * <p>The transition MTBDD is split over the nondeterminism variables; each
     * resulting piece contributes at most one choice (distribution) per state.
     * The action of a choice is looked up by rounding the value stored in the
     * matching piece of the transition-actions MTBDD to an index into the
     * model's action list; indices beyond the end of the list and {@code null}
     * list entries leave the choice without an action.
     */
    private void convertMDPTransitions(NondetModel symbolicModel, MDPSimple<Double> explicitMdp) throws PrismException {
        List<Object> actions = symbolicModel.getActions();
        List<Pair<JDDNode, JDDNode>> choices = this.splitMDPDD(symbolicModel.getTrans(), symbolicModel.getTransActions(), symbolicModel.getAllDDNondetVars());
        for (Pair<JDDNode, JDDNode> choice : choices) {
            // Build one distribution per source state for this choice
            JDDNode choiceTrans = choice.getKey();
            HashMap<Integer, Distribution<Double>> distributions = new HashMap<>();
            this.traverseMatrixDD(choiceTrans, symbolicModel.getAllDDRowVars(), symbolicModel.getAllDDColVars(), symbolicModel.getODD(), (source, target, prob, unused) -> {
                distributions.computeIfAbsent(source, key -> Distribution.ofDouble()).add(target, prob);
            });

            // Look up the action attached to this choice in each state
            JDDNode choiceActions = choice.getValue();
            HashMap<Integer, Object> actionOfState = new HashMap<>();
            this.traverseVectorDD(choiceActions, symbolicModel.getAllDDRowVars(), symbolicModel.getODD(), (state, value) -> {
                int actionIndex = (int) Math.round(value);
                if (actionIndex < actions.size()) {
                    Object action = actions.get(actionIndex);
                    if (action != null) {
                        actionOfState.put(state, action);
                    }
                }
            });

            // States with no recorded action get null as the action label
            distributions.forEach((state, distribution) -> explicitMdp.addActionLabelledChoice((int) state, distribution, actionOfState.get(state)));

            // The split produced referenced copies; release them
            JDD.Deref(choiceTrans);
            JDD.Deref(choiceActions);
        }
    }

    /**
     * Copies the transition rewards of a symbolic Markov chain.
     *
     * <p>For each non-zero entry (source, target) of the reward matrix, the
     * successors of {@code source} in the explicit chain are scanned in
     * iteration order; the reward is stored under the position of the first
     * successor equal to {@code target}. Entries with no matching successor
     * are silently skipped.
     */
    private void convertMarkovChainTransitionRewards(ProbModel symbolicModel, DTMC<Double> explicitChain, JDDNode transRewards, RewardsExplicit<Double> rewards) throws PrismException {
        this.traverseMatrixDD(transRewards, symbolicModel.getAllDDRowVars(), symbolicModel.getAllDDColVars(), symbolicModel.getODD(), (source, target, reward, unused) -> {
            SuccessorsIterator successors = explicitChain.getSuccessors(source);
            for (int position = 0; successors.hasNext(); ++position) {
                if (successors.nextInt() == target) {
                    rewards.setTransitionReward(source, position, reward);
                    return;
                }
            }
        });
    }

    /**
     * Copies the transition rewards of a symbolic MDP.
     *
     * <p>The 0-1 transition MTBDD and the reward MTBDD are split together over
     * the nondeterminism variables. A per-state counter tracks how many pieces
     * so far contained a choice for that state, which yields the index of the
     * current choice; the reward stored for a (state, choice) pair is the
     * maximum of the reward piece over all column variables.
     *
     * <p>NOTE(review): the nodes held by each pair are not dereferenced here
     * explicitly, presumably because {@code JDD.ThereExists} and
     * {@code JDD.MaxAbstract} consume their first operand - confirm against
     * the JDD API.
     */
    private void convertMDPTransitionRewards(NondetModel symbolicModel, MDP<Double> explicitMdp, JDDNode transRewards, RewardsExplicit<Double> rewards) throws PrismException {
        int numStates = symbolicModel.getNumStates();
        // choicesSeen[s] = number of choices of state s encountered so far
        int[] choicesSeen = new int[numStates];
        List<Pair<JDDNode, JDDNode>> choices = this.splitMDPDD(symbolicModel.getTrans01(), transRewards, symbolicModel.getAllDDNondetVars());
        for (Pair<JDDNode, JDDNode> choice : choices) {
            // States that have this choice
            JDDNode statesWithChoice = JDD.ThereExists(choice.getKey(), symbolicModel.getAllDDColVars());
            this.traverseStatesBDD(statesWithChoice, symbolicModel.getAllDDRowVars(), symbolicModel.getODD(), state -> choicesSeen[state]++);

            // Reward of this choice per state; the counter was just incremented, hence the -1
            JDDNode choiceRewards = JDD.MaxAbstract(choice.getValue(), symbolicModel.getAllDDColVars());
            this.traverseVectorDD(choiceRewards, symbolicModel.getAllDDRowVars(), symbolicModel.getODD(),
                    (state, reward) -> rewards.setTransitionReward(state, choicesSeen[state] - 1, reward));

            JDD.Deref(statesWithChoice);
            JDD.Deref(choiceRewards);
        }
    }

    /**
     * Splits a transition MTBDD, and a second MTBDD in parallel, into one pair
     * per assignment of the given nondeterminism variables for which the
     * transition part is not zero.
     *
     * @return pairs of (transition piece, parallel piece); both elements are
     *         copies obtained via {@code copy()} and are later released by the
     *         callers in this class
     */
    private List<Pair<JDDNode, JDDNode>> splitMDPDD(JDDNode trans, JDDNode other, JDDVars nondetVars) {
        List<Pair<JDDNode, JDDNode>> pieces = new ArrayList<>();
        this.splitMDPDDRec(trans, other, nondetVars, 0, pieces);
        return pieces;
    }

    /**
     * Recursive worker for {@link #splitMDPDD}: descends one nondeterminism
     * variable per level, else-branch first, and appends a pair once all
     * variables have been consumed.
     */
    private void splitMDPDDRec(JDDNode trans, JDDNode other, JDDVars nondetVars, int level, List<Pair<JDDNode, JDDNode>> pieces) {
        // Nothing to collect below a zero transition part
        if (trans.equals(JDD.ZERO)) {
            return;
        }
        if (level == nondetVars.n()) {
            pieces.add(new Pair<JDDNode, JDDNode>(trans.copy(), other.copy()));
            return;
        }
        int varIndex = nondetVars.getVarIndex(level);

        // A node whose index lies beyond this variable does not depend on it:
        // both branches are the node itself
        boolean transSkipsVar = trans.getIndex() > varIndex;
        JDDNode transElse = transSkipsVar ? trans : trans.getElse();
        JDDNode transThen = transSkipsVar ? trans : trans.getThen();

        boolean otherSkipsVar = other.getIndex() > varIndex;
        JDDNode otherElse = otherSkipsVar ? other : other.getElse();
        JDDNode otherThen = otherSkipsVar ? other : other.getThen();

        this.splitMDPDDRec(transElse, otherElse, nondetVars, level + 1, pieces);
        this.splitMDPDDRec(transThen, otherThen, nondetVars, level + 1, pieces);
    }

    /**
     * Calls {@code consumer} with the index of every state for which the
     * given decision diagram is not zero.
     */
    private void traverseStatesBDD(JDDNode dd, JDDVars rowVars, ODDNode odd, IntConsumer consumer) throws PrismException {
        this.traverseStatesBDDRec(dd, rowVars, 0, odd, 0L, consumer);
    }

    /**
     * Recursive worker for {@link #traverseStatesBDD}. The state index is
     * accumulated in {@code index} by adding the ODD node's else-offset
     * whenever a then-branch is taken.
     */
    private void traverseStatesBDDRec(JDDNode dd, JDDVars rowVars, int level, ODDNode odd, long index, IntConsumer consumer) throws PrismException {
        if (dd.equals(JDD.ZERO)) {
            return;
        }
        if (level == rowVars.n()) {
            consumer.accept(SafeCast.toInt(index));
            return;
        }
        // A node beyond this variable does not depend on it
        boolean skipsVar = dd.getIndex() > rowVars.getVarIndex(level);
        JDDNode elseBranch = skipsVar ? dd : dd.getElse();
        JDDNode thenBranch = skipsVar ? dd : dd.getThen();

        this.traverseStatesBDDRec(elseBranch, rowVars, level + 1, odd.getElse(), index, consumer);
        this.traverseStatesBDDRec(thenBranch, rowVars, level + 1, odd.getThen(), index + odd.getEOff(), consumer);
    }

    /**
     * Calls {@code consumer} with (state index, value) for every state whose
     * value in the given vector MTBDD is not zero.
     */
    private void traverseVectorDD(JDDNode dd, JDDVars rowVars, ODDNode odd, IOUtils.StateValueConsumer<Double> consumer) throws PrismException {
        this.traverseVectorDDRec(dd, rowVars, 0, odd, 0L, consumer);
    }

    /**
     * Recursive worker for {@link #traverseVectorDD}; same descent as
     * {@link #traverseStatesBDDRec}, but the terminal's value is passed on.
     */
    private void traverseVectorDDRec(JDDNode dd, JDDVars rowVars, int level, ODDNode odd, long index, IOUtils.StateValueConsumer<Double> consumer) throws PrismException {
        if (dd.equals(JDD.ZERO)) {
            return;
        }
        if (level == rowVars.n()) {
            consumer.accept(SafeCast.toInt(index), dd.getValue());
            return;
        }
        // A node beyond this variable does not depend on it
        boolean skipsVar = dd.getIndex() > rowVars.getVarIndex(level);
        JDDNode elseBranch = skipsVar ? dd : dd.getElse();
        JDDNode thenBranch = skipsVar ? dd : dd.getThen();

        this.traverseVectorDDRec(elseBranch, rowVars, level + 1, odd.getElse(), index, consumer);
        this.traverseVectorDDRec(thenBranch, rowVars, level + 1, odd.getThen(), index + odd.getEOff(), consumer);
    }

    /**
     * Calls {@code consumer} with (row index, column index, value, null) for
     * every non-zero entry of the given matrix MTBDD. The same ODD is used to
     * index rows and columns.
     */
    private void traverseMatrixDD(JDDNode dd, JDDVars rowVars, JDDVars colVars, ODDNode odd, IOUtils.MCTransitionConsumer<Double> consumer) throws PrismException {
        this.traverseMatrixDD(dd, rowVars, colVars, 0, odd, odd, 0L, 0L, consumer);
    }

    /**
     * Recursive worker for the matrix traversal. Each level consumes one row
     * variable and one column variable, so the node is expanded into four
     * sub-nodes (row else/then x column else/then), visited in the order
     * else/else, else/then, then/else, then/then.
     *
     * <p>The index comparisons below rely on the row variable of a level
     * preceding its column variable in the variable ordering.
     */
    private void traverseMatrixDD(JDDNode dd, JDDVars rowVars, JDDVars colVars, int level, ODDNode rowOdd, ODDNode colOdd, long row, long col, IOUtils.MCTransitionConsumer<Double> consumer) throws PrismException {
        if (dd.equals(JDD.ZERO)) {
            return;
        }
        if (level == rowVars.n()) {
            consumer.accept(SafeCast.toInt(row), SafeCast.toInt(col), dd.getValue(), null);
            return;
        }

        // Naming: <row branch><column branch>
        JDDNode elseElse;
        JDDNode elseThen;
        JDDNode thenElse;
        JDDNode thenThen;
        int colVarIndex = colVars.getVarIndex(level);
        if (dd.getIndex() > colVarIndex) {
            // Depends on neither variable of this level
            elseElse = dd;
            elseThen = dd;
            thenElse = dd;
            thenThen = dd;
        } else if (dd.getIndex() > rowVars.getVarIndex(level)) {
            // Depends on the column variable only
            elseElse = dd.getElse();
            thenElse = elseElse;
            elseThen = dd.getThen();
            thenThen = elseThen;
        } else {
            // Depends on the row variable; each row branch may or may not
            // depend on the column variable
            JDDNode rowElse = dd.getElse();
            if (rowElse.getIndex() > colVarIndex) {
                elseElse = rowElse;
                elseThen = rowElse;
            } else {
                elseElse = rowElse.getElse();
                elseThen = rowElse.getThen();
            }
            JDDNode rowThen = dd.getThen();
            if (rowThen.getIndex() > colVarIndex) {
                thenElse = rowThen;
                thenThen = rowThen;
            } else {
                thenElse = rowThen.getElse();
                thenThen = rowThen.getThen();
            }
        }

        this.traverseMatrixDD(elseElse, rowVars, colVars, level + 1, rowOdd.getElse(), colOdd.getElse(), row, col, consumer);
        this.traverseMatrixDD(elseThen, rowVars, colVars, level + 1, rowOdd.getElse(), colOdd.getThen(), row, col + colOdd.getEOff(), consumer);
        this.traverseMatrixDD(thenElse, rowVars, colVars, level + 1, rowOdd.getThen(), colOdd.getElse(), row + rowOdd.getEOff(), col, consumer);
        this.traverseMatrixDD(thenThen, rowVars, colVars, level + 1, rowOdd.getThen(), colOdd.getThen(), row + rowOdd.getEOff(), col + colOdd.getEOff(), consumer);
    }
}

