/*
 * Decompiled with CFR 0.152.
 */
package explicit;

import explicit.Belief;
import explicit.Distribution;
import explicit.MDPSimple;
import explicit.POMDP;
import explicit.rewards.MDPRewards;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import parser.State;
import prism.PrismException;
import prism.PrismUtils;

public class POMDPSimple<Value>
extends MDPSimple<Value>
implements POMDP<Value> {
    /** Observation values; the position in this list is the observation index. */
    protected List<State> observationsList;
    /** Unobservation values; the position in this list is the unobservation index. */
    protected List<State> unobservationsList;
    /**
     * For each observation index, the first state that was assigned that observation
     * (-1 if none yet). Used as the representative state when querying the
     * choices/actions available for an observation.
     */
    protected List<Integer> observationStates;
    /** For each state, its observation index (-1 if unset). */
    protected List<Integer> observablesMap;
    /** For each state, its unobservation index (-1 if unset). */
    protected List<Integer> unobservablesMap;

    /** Constructs an empty POMDP with empty observation bookkeeping. */
    public POMDPSimple() {
        this.initialiseObservables();
    }

    /**
     * Constructs a POMDP with {@code n} states, each with no observation or
     * unobservation assigned yet (both maps hold -1 for every state).
     *
     * @param n number of states
     */
    public POMDPSimple(int n) {
        super(n);
        // Size the per-state maps to match the states created by super(n); an
        // unsized initialisation would make every per-state lookup go out of bounds.
        this.initialiseObservables(n);
    }

    /**
     * Copy constructor: copies the underlying MDP and all observation bookkeeping lists
     * (the lists are new, the contained {@code State} objects are shared).
     *
     * @param pOMDPSimple the POMDP to copy
     */
    public POMDPSimple(POMDPSimple<Value> pOMDPSimple) {
        super(pOMDPSimple);
        this.observationsList = new ArrayList<>(pOMDPSimple.observationsList);
        this.unobservationsList = new ArrayList<>(pOMDPSimple.unobservationsList);
        this.observationStates = new ArrayList<>(pOMDPSimple.observationStates);
        this.observablesMap = new ArrayList<>(pOMDPSimple.observablesMap);
        this.unobservablesMap = new ArrayList<>(pOMDPSimple.unobservablesMap);
    }

    /**
     * Copy constructor that renumbers states: state {@code s} of the source becomes
     * state {@code nArray[s]} of the new POMDP.
     *
     * @param pOMDPSimple the POMDP to copy
     * @param nArray      mapping from old state index to new state index
     */
    public POMDPSimple(POMDPSimple<Value> pOMDPSimple, int[] nArray) {
        super(pOMDPSimple, nArray);
        this.observationsList = new ArrayList<>(pOMDPSimple.observationsList);
        this.unobservationsList = new ArrayList<>(pOMDPSimple.unobservationsList);
        // Representative states must be renumbered too; -1 ("none yet") is preserved.
        int numObs = pOMDPSimple.getNumObservations();
        this.observationStates = new ArrayList<>(numObs);
        for (int o = 0; o < numObs; ++o) {
            int rep = pOMDPSimple.observationStates.get(o);
            this.observationStates.add(rep == -1 ? -1 : nArray[rep]);
        }
        // Pre-fill so that set(nArray[s], ...) below is always in range.
        this.observablesMap = new ArrayList<>(this.getNumStates());
        this.unobservablesMap = new ArrayList<>(this.getNumStates());
        for (int s = 0; s < this.numStates; ++s) {
            this.observablesMap.add(-1);
            this.unobservablesMap.add(-1);
        }
        for (int s = 0; s < this.numStates; ++s) {
            this.observablesMap.set(nArray[s], pOMDPSimple.observablesMap.get(s));
            this.unobservablesMap.set(nArray[s], pOMDPSimple.unobservablesMap.get(s));
        }
    }

    /**
     * Builds a POMDP from an MDP in which every state gets its own observation:
     * state {@code s} has observation index {@code s} and is that observation's
     * representative. Unobservations are left unset (-1).
     * <p>
     * NOTE(review): {@code observationsList}/{@code unobservationsList} stay empty here;
     * if {@code getNumObservations()} is derived from that list it will not count these
     * observations — confirm against the POMDP interface.
     *
     * @param mDPSimple the MDP to copy
     */
    public POMDPSimple(MDPSimple<Value> mDPSimple) {
        super(mDPSimple);
        this.initialiseObservables(mDPSimple.numStates);
        // Each observation has exactly one state, so no action-consistency check is
        // needed; register the representative directly instead of going through
        // setObservation (which would index an observationStates entry that doesn't exist).
        for (int s = 0; s < this.numStates; ++s) {
            this.observationStates.add(s);
            this.observablesMap.set(s, s);
        }
    }

    /** Resets all observation bookkeeping to empty lists. */
    protected void initialiseObservables() {
        this.observationsList = new ArrayList<>();
        this.unobservationsList = new ArrayList<>();
        this.observationStates = new ArrayList<>();
        this.observablesMap = new ArrayList<>();
        this.unobservablesMap = new ArrayList<>();
    }

    /**
     * Resets observation bookkeeping for {@code n} states: observation lists empty,
     * per-state observation/unobservation maps filled with -1.
     *
     * @param n number of states
     */
    protected void initialiseObservables(int n) {
        this.observationsList = new ArrayList<>();
        this.unobservationsList = new ArrayList<>();
        this.observationStates = new ArrayList<>();
        this.observablesMap = new ArrayList<>(n);
        this.unobservablesMap = new ArrayList<>(n);
        for (int s = 0; s < n; ++s) {
            this.observablesMap.add(-1);
            this.unobservablesMap.add(-1);
        }
    }

    /** Clears state {@code n} in the MDP and resets its observation/unobservation to -1. */
    @Override
    public void clearState(int n) {
        super.clearState(n);
        this.observablesMap.set(n, -1);
        this.unobservablesMap.set(n, -1);
    }

    /** Adds {@code n} states, each with unset (-1) observation and unobservation. */
    @Override
    public void addStates(int n) {
        super.addStates(n);
        for (int i = 0; i < n; ++i) {
            this.observablesMap.add(-1);
            this.unobservablesMap.add(-1);
        }
    }

    /** Replaces the observations list (stored by reference, not copied). */
    public void setObservationsList(List<State> list) {
        this.observationsList = list;
    }

    /** Replaces the unobservations list (stored by reference, not copied). */
    public void setUnobservationsList(List<State> list) {
        this.unobservationsList = list;
    }

    /**
     * Assigns an observation and an unobservation to state {@code n}, registering
     * either value in its list if it has not been seen before.
     *
     * @param n      state index
     * @param state  observation value
     * @param state2 unobservation value
     * @param list   optional variable names used to format {@code state} in error messages
     * @throws PrismException if the state's actions differ from those of the observation's
     *                        representative state
     */
    public void setObservation(int n, State state, State state2, List<String> list) throws PrismException {
        int obs = this.observationsList.indexOf(state);
        if (obs == -1) {
            this.observationsList.add(state);
            obs = this.observationsList.size() - 1;
            this.observationStates.add(-1);
        }
        try {
            this.setObservation(n, obs);
        }
        catch (PrismException prismException) {
            String string = list == null ? state.toString() : state.toString(list);
            throw new PrismException("Problem with observation " + string + ": " + prismException.getMessage());
        }
        int unobs = this.unobservationsList.indexOf(state2);
        if (unobs == -1) {
            this.unobservationsList.add(state2);
            unobs = this.unobservationsList.size() - 1;
        }
        this.unobservablesMap.set(n, unobs);
    }

    /**
     * Sets the observation index of state {@code n} to {@code n2}. The first state given
     * an observation becomes its representative; later states must have exactly the same
     * actions (same order) as the representative.
     *
     * @throws PrismException if the actions do not match the representative's
     */
    protected void setObservation(int n, int n2) throws PrismException {
        this.observablesMap.set(n, n2);
        int rep = this.observationStates.get(n2);
        if (rep == -1) {
            this.observationStates.set(n2, n);
        } else {
            this.checkActionsMatchExactly(n, rep);
        }
    }

    /**
     * Checks that states {@code n} and {@code n2} have the same number of choices and the
     * same action (compared with {@code equals}, nulls allowed) at every choice index.
     *
     * @throws PrismException if they differ
     */
    protected void checkActionsMatchExactly(int n, int n2) throws PrismException {
        int numChoices = this.getNumChoices(n);
        boolean match = numChoices == this.getNumChoices(n2);
        for (int i = 0; match && i < numChoices; ++i) {
            match = Objects.equals(this.getAction(n, i), this.getAction(n2, i));
        }
        if (!match) {
            throw new PrismException("Differing actions found in states: " + this.getAvailableActions(n) + " vs. " + this.getAvailableActions(n2));
        }
    }

    /**
     * Checks that states {@code n} and {@code n2} have the same multiset of action labels,
     * ignoring order (null actions are treated as "").
     *
     * @throws PrismException if they differ
     */
    protected void checkActionsMatch(int n, int n2) throws PrismException {
        List<String> actions1 = this.sortedActionNames(n);
        List<String> actions2 = this.sortedActionNames(n2);
        if (!actions1.equals(actions2)) {
            throw new PrismException("Differing actions found in states: " + actions1 + " vs. " + actions2);
        }
    }

    /** Returns the action labels of state {@code s} as strings (null mapped to ""), sorted. */
    private List<String> sortedActionNames(int s) {
        int numChoices = this.getNumChoices(s);
        List<String> names = new ArrayList<>(numChoices);
        for (int i = 0; i < numChoices; ++i) {
            Object action = this.getAction(s, i);
            names.add(action == null ? "" : action.toString());
        }
        Collections.sort(names);
        return names;
    }

    @Override
    public List<State> getObservationsList() {
        return this.observationsList;
    }

    @Override
    public List<State> getUnobservationsList() {
        return this.unobservationsList;
    }

    /** Returns the observation index of state {@code n}, or -1 if no map exists. */
    @Override
    public int getObservation(int n) {
        return this.observablesMap == null ? -1 : this.observablesMap.get(n);
    }

    /** Returns the unobservation index of state {@code n} (-1 if unset). */
    @Override
    public int getUnobservation(int n) {
        return this.unobservablesMap.get(n);
    }

    /** Returns the number of choices of observation {@code n}'s representative state. */
    @Override
    public int getNumChoicesForObservation(int n) {
        return this.getNumChoices(this.observationStates.get(n));
    }

    /** Returns action {@code n2} of observation {@code n}'s representative state. */
    @Override
    public Object getActionForObservation(int n, int n2) {
        return this.getAction(this.observationStates.get(n), n2);
    }

    /** Returns the uniform belief over the initial states. */
    @Override
    public Belief getInitialBelief() {
        return new Belief(this.getInitialBeliefInDist(), this);
    }

    /** Returns a per-state distribution with equal mass on each initial state. */
    @Override
    public double[] getInitialBeliefInDist() {
        double[] dist = new double[this.numStates];
        for (Integer s : this.initialStates) {
            dist[s.intValue()] = 1.0;
        }
        PrismUtils.normalise(dist);
        return dist;
    }

    /** Returns the belief reached from {@code belief} by taking choice {@code n}. */
    @Override
    public Belief getBeliefAfterChoice(Belief belief, int n) {
        double[] dist = belief.toDistributionOverStates(this);
        return new Belief(this.getBeliefInDistAfterChoice(dist, n), this);
    }

    /**
     * Pushes a per-state distribution through choice {@code n} of every state. Source
     * states with mass below 1e-6 (or NaN) are skipped. The result is not normalised.
     */
    @Override
    public double[] getBeliefInDistAfterChoice(double[] dArray, int n) {
        int size = dArray.length;
        double[] next = new double[size];
        for (int s = 0; s < size; ++s) {
            double p = dArray[s];
            if (!(p >= 1.0E-6)) {
                continue;
            }
            Distribution distribution = this.getChoice(s, n);
            for (Map.Entry entry : distribution) {
                int target = (Integer) entry.getKey();
                double prob = (Double) entry.getValue();
                next[target] += p * prob;
            }
        }
        return next;
    }

    /** Returns the belief after taking choice {@code n} and then observing {@code n2}. */
    @Override
    public Belief getBeliefAfterChoiceAndObservation(Belief belief, int n, int n2) {
        double[] dist = belief.toDistributionOverStates(this);
        Belief result = new Belief(this.getBeliefInDistAfterChoiceAndObservation(dist, n, n2), this);
        assert (result.so == n2);
        return result;
    }

    /**
     * Takes choice {@code n}, weights each state by the probability of observation
     * {@code n2} there, and normalises the result.
     */
    @Override
    public double[] getBeliefInDistAfterChoiceAndObservation(double[] dArray, int n, int n2) {
        double[] dist = this.getBeliefInDistAfterChoice(dArray, n);
        for (int s = 0; s < dist.length; ++s) {
            dist[s] *= this.getObservationProb(s, n2);
        }
        PrismUtils.normalise(dist);
        return dist;
    }

    @Override
    public double getObservationProbAfterChoice(Belief belief, int n, int n2) {
        return this.getObservationProbAfterChoice(belief.toDistributionOverStates(this), n, n2);
    }

    /** Returns the probability of observing {@code n2} after taking choice {@code n}. */
    @Override
    public double getObservationProbAfterChoice(double[] dArray, int n, int n2) {
        double[] dist = this.getBeliefInDistAfterChoice(dArray, n);
        double total = 0.0;
        for (int s = 0; s < dist.length; ++s) {
            total += dist[s] * this.getObservationProb(s, n2);
        }
        return total;
    }

    /**
     * Returns, for each observation reachable by choice {@code n}, its total probability.
     * States with post-choice mass of at most 1e-6 are ignored.
     */
    @Override
    public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] dArray, int n) {
        HashMap<Integer, Double> probs = new HashMap<>();
        double[] dist = this.getBeliefInDistAfterChoice(dArray, n);
        for (int s = 0; s < dist.length; ++s) {
            double p = dist[s];
            if (p > 1.0E-6) {
                probs.merge(this.getObservation(s), p, Double::sum);
            }
        }
        return probs;
    }

    @Override
    public double getRewardAfterChoice(Belief belief, int n, MDPRewards<Double> mDPRewards) {
        return this.getRewardAfterChoice(belief.toDistributionOverStates(this), n, mDPRewards);
    }

    /**
     * Returns the expected (state + transition) reward of taking choice {@code n} under the
     * given per-state distribution. Zero-probability states are not queried.
     */
    @Override
    public double getRewardAfterChoice(double[] dArray, int n, MDPRewards<Double> mDPRewards) {
        double expected = 0.0;
        for (int s = 0; s < dArray.length; ++s) {
            double p = dArray[s];
            if (p != 0.0) {
                expected += p * (mDPRewards.getTransitionReward(s, n) + mDPRewards.getStateReward(s));
            }
        }
        return expected;
    }

    /**
     * Collapses a per-state distribution into a {@code Belief} over unobservations. The
     * observation is taken from the last state with non-zero mass (all such states are
     * presumably expected to share it). Returns null, after printing to stderr, if the
     * distribution has no non-zero entry.
     */
    protected Belief beliefInDistToBelief(double[] dArray) {
        int obs = -1;
        double[] unobsDist = new double[this.getNumUnobservations()];
        for (int s = 0; s < dArray.length; ++s) {
            if (dArray[s] == 0.0) {
                continue;
            }
            obs = this.getObservation(s);
            unobsDist[this.getUnobservation(s)] += dArray[s];
        }
        if (obs == -1) {
            System.err.println("Something wrong in POMDPSimple.beliefInDistToBelief(double[] beliefInDist)");
            return null;
        }
        return new Belief(obs, unobsDist);
    }

    /**
     * Lists each state as {@code s(obs/unobs): [action:distribution,...]}. Choices with a
     * null action omit the action prefix.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[ ");
        for (int s = 0; s < this.numStates; ++s) {
            if (s > 0) {
                sb.append(", ");
            }
            sb.append(s).append('(').append(this.getObservation(s)).append('/')
              .append(this.getUnobservation(s)).append("): ").append('[');
            int numChoices = this.getNumChoices(s);
            for (int j = 0; j < numChoices; ++j) {
                if (j > 0) {
                    sb.append(',');
                }
                Object action = this.getAction(s, j);
                if (action != null) {
                    sb.append(action).append(':');
                }
                sb.append(((List<?>) this.trans.get(s)).get(j));
            }
            sb.append(']');
        }
        sb.append(" ]\n");
        return sb.toString();
    }

    /**
     * Two POMDPs are equal if they have the same number of states, initial states,
     * transitions, and per-state observation/unobservation assignment. Actions are
     * not compared.
     */
    @Override
    public boolean equals(Object object) {
        if (!(object instanceof POMDPSimple)) {
            return false;
        }
        POMDPSimple<?> other = (POMDPSimple<?>) object;
        return this.numStates == other.numStates
                && this.initialStates.equals(other.initialStates)
                && this.trans.equals(other.trans)
                // The observation function is part of a POMDP's identity.
                && Objects.equals(this.observablesMap, other.observablesMap)
                && Objects.equals(this.unobservablesMap, other.unobservablesMap);
    }
}

