package aima.core.learning.reinforcement.agent;

import aima.core.agent.Action;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.probability.mdp.PolicyEvaluation;
import aima.core.probability.mdp.RewardFunction;
import aima.core.probability.mdp.TransitionProbabilityFunction;
import aima.core.probability.mdp.impl.MDP;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:lib/aima-core-3.0.0.jar:aima/core/learning/reinforcement/agent/PassiveADPAgent.class */
/**
 * A passive reinforcement-learning agent based on adaptive dynamic programming (ADP).
 *
 * <p>It executes a fixed policy, learns a transition model and per-state rewards from its
 * percepts, and re-evaluates the policy's utilities after every percept using the supplied
 * {@link PolicyEvaluation}.
 *
 * <p>The learned model is presented to the evaluator as an {@link MDP}. That MDP's transition and
 * reward functions read this agent's live estimates ({@code P} and {@code R}), so each evaluation
 * sees the most recent counts.
 *
 * @param <S> the state type
 * @param <A> the action type
 */
public class PassiveADPAgent<S, A extends Action> extends ReinforcementAgent<S, A> {
    /** The fixed policy being followed; a private copy of the map given to the constructor. */
    private final Map<S, A> pi = new HashMap<>();
    /** Estimated model P(s' | s, a), keyed by (s', (s, a)); a missing key means probability 0. */
    private final Map<Pair<S, Pair<S, A>>, Double> P = new HashMap<>();
    /** Reward recorded when a state is first perceived (i.e. when it is not yet in U). */
    private final Map<S, Double> R = new HashMap<>();
    /** Current utility estimates; replaced after every policy evaluation. */
    private Map<S, Double> U = new HashMap<>();
    /** N_sa: how often action a has been executed in state s. */
    private final FrequencyCounter<Pair<S, A>> Nsa = new FrequencyCounter<>();
    /** N_s'|sa: how often executing a in s was followed by s', keyed by (s', (s, a)). */
    private final FrequencyCounter<Pair<S, Pair<S, A>>> NsDelta_sa = new FrequencyCounter<>();
    /** The learned MDP handed to the policy evaluator. */
    private final MDP<S, A> mdp;
    /** Procedure used to compute the utilities of pi under the learned model. */
    private final PolicyEvaluation<S, A> policyEvaluation;
    /** Previous state, or null at the start and after a terminal state was perceived. */
    private S s = null;
    /** Action chosen in the previous state, or null alongside {@code s}. */
    private A a = null;

    /**
     * Creates an agent that follows a fixed policy.
     *
     * @param fixedPolicy the policy pi to execute; its entries are copied, so later changes to the
     *     caller's map do not affect the agent
     * @param states the state set used by the learned MDP
     * @param initialState the initial state of the learned MDP
     * @param actionsFunction the actions function of the learned MDP; the agent treats a state for
     *     which the MDP reports no actions as terminal
     * @param policyEvaluation the policy-evaluation procedure run after every percept
     */
    public PassiveADPAgent(Map<S, A> fixedPolicy, Set<S> states, S initialState,
            ActionsFunction<S, A> actionsFunction, PolicyEvaluation<S, A> policyEvaluation) {
        this.pi.putAll(fixedPolicy);
        this.mdp = new MDP<>(states, initialState, actionsFunction,
                new TransitionProbabilityFunction<S, A>() {
                    @Override
                    public double probability(S sDelta, S state, A action) {
                        Pair<S, Pair<S, A>> key = new Pair<>(sDelta, new Pair<>(state, action));
                        Double p = P.get(key);
                        // Transitions never observed have no entry and count as impossible.
                        return null == p ? 0.0d : p.doubleValue();
                    }
                }, new RewardFunction<S>() {
                    @Override
                    public double reward(S state) {
                        // Unboxing throws NullPointerException for a state never perceived.
                        return R.get(state).doubleValue();
                    }
                });
        this.policyEvaluation = policyEvaluation;
    }

    /**
     * Processes one percept.
     *
     * <p>Processing has four steps. It records the reward of a newly seen state and updates the
     * transition counts and estimates for the previous (state, action) pair. It then re-evaluates
     * the fixed policy and returns the policy's action for the perceived state. A null action is
     * returned when that state is terminal, or when the policy has no entry for it.
     *
     * @param perceptStateReward the perceived state and the reward received in it
     * @return the action to execute next, or null
     */
    @Override
    public A execute(PerceptStateReward<S> perceptStateReward) {
        S sDelta = perceptStateReward.state();
        double rDelta = perceptStateReward.reward();
        if (!U.containsKey(sDelta)) {
            // First visit: seed the utility with the immediate reward and remember the reward.
            U.put(sDelta, rDelta);
            R.put(sDelta, rDelta);
        }
        if (null != s) {
            Pair<S, A> sa = new Pair<>(s, a);
            Nsa.incrementFor(sa);
            NsDelta_sa.incrementFor(new Pair<>(sDelta, sa));
            // N_sa[s, a] is fixed for the rest of this update; read it once.
            double saCount = Nsa.getCount(sa).doubleValue();
            // N_sa changed, so every nonzero estimate P(t | s, a) must be renormalized.
            for (S t : mdp.states()) {
                Pair<S, Pair<S, A>> tsa = new Pair<>(t, sa);
                int tCount = NsDelta_sa.getCount(tsa).intValue();
                if (0 != tCount) {
                    P.put(tsa, tCount / saCount);
                }
            }
        }
        // Copy the evaluator's result. The next call mutates U (put above), and the
        // PolicyEvaluation contract does not promise a mutable, unshared map.
        U = new HashMap<>(policyEvaluation.evaluate(pi, U, mdp));
        if (isTerminal(sDelta)) {
            s = null;
            a = null;
        } else {
            s = sDelta;
            a = pi.get(sDelta);
        }
        return a;
    }

    /**
     * Returns a read-only view of the agent's current utility estimates. {@link #execute} replaces
     * the underlying map, so call this again after further percepts to see updated values.
     *
     * @return an unmodifiable view of the utility map
     */
    @Override
    public Map<S, Double> getUtility() {
        return Collections.unmodifiableMap(U);
    }

    /**
     * Discards everything learned: the transition estimates, rewards, utilities, counts and the
     * previous state/action. The fixed policy and the learned MDP's state set are kept.
     */
    @Override
    public void reset() {
        P.clear();
        R.clear();
        U = new HashMap<>();
        Nsa.clear();
        NsDelta_sa.clear();
        s = null;
        a = null;
    }

    /** Returns true if the learned MDP offers no actions in {@code state}. */
    private boolean isTerminal(S state) {
        return mdp.actions(state).isEmpty();
    }
}
