/*
 * Decompiled with CFR 0.152.
 */
package com.mentalfrostbyte.jello.util.game.player.rotation;

import com.mentalfrostbyte.Client;
import com.mentalfrostbyte.jello.util.game.player.rotation.JelloAI;
import com.mentalfrostbyte.jello.util.game.player.rotation.NeuralNetwork;
import com.mentalfrostbyte.jello.util.game.player.rotation.TrainingManager;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import net.minecraft.client.Minecraft;
import net.minecraft.entity.Entity;

/**
 * Turns hit/miss feedback on entities into training data for the JelloAI rotation model.
 *
 * <p>A hit is queued {@link #HIT_SAMPLE_REPEATS} times into the {@link TrainingManager} and is
 * also passed to {@link NeuralNetwork#trainNetworkImmediate}. A miss only queues a single
 * sample carrying {@link #MISS_PENALTY}.
 *
 * <p>Per-entity bookkeeping is kept in {@link WeakHashMap}s. Entries are never removed
 * explicitly, so weak keys let entities the game no longer references be garbage collected
 * instead of being retained by this manager indefinitely.
 */
public class ReinforcementManager {
    private static final Minecraft mc = Minecraft.getInstance();
    private static final float HIT_REWARD = 1.0f;
    private static final float MOVING_HIT_REWARD = 1.5f;
    private static final float MISS_PENALTY = -0.2f;
    /** Number of copies of each hit sample queued into the training manager (a miss queues one). */
    private static final int HIT_SAMPLE_REPEATS = 3;
    /** Length of the feature vector produced by {@link #getEntityInputs(Entity)}. */
    private static final int INPUT_SIZE = 8;

    /** Wall-clock time (ms) of the most recent hit per entity; not currently read in this class. */
    private final Map<Entity, Long> lastHitTimes = new WeakHashMap<Entity, Long>();
    /** Running sum of hit rewards per entity. */
    private final Map<Entity, Float> entityRewards = new WeakHashMap<Entity, Float>();
    private final NeuralNetwork neuralNetwork;
    private final TrainingManager trainingManager;

    /**
     * Creates a manager that feeds an existing network and trainer.
     *
     * @param sharedNetwork network that receives immediate training on hits
     * @param sharedTrainer trainer that receives queued hit/miss samples
     */
    public ReinforcementManager(NeuralNetwork sharedNetwork, TrainingManager sharedTrainer) {
        this.neuralNetwork = sharedNetwork;
        this.trainingManager = sharedTrainer;
    }

    /** Creates a manager with its own freshly constructed network and trainer. */
    public ReinforcementManager() {
        this.neuralNetwork = new NeuralNetwork();
        this.trainingManager = new TrainingManager();
    }

    /**
     * Records a successful hit and trains on the rotation that produced it.
     *
     * <p>Does nothing if {@code entity} is {@code null} or there is no local player.
     *
     * @param entity the entity that was hit
     * @param wasMoving whether the hit counts as a "moving" hit; selects
     *     {@link #MOVING_HIT_REWARD} instead of {@link #HIT_REWARD}
     * @param currentYaw yaw used for the hit; normalized via {@code JelloAI.normalizeRotations}
     * @param currentPitch pitch used for the hit; normalized via {@code JelloAI.normalizeRotations}
     */
    public void recordHit(Entity entity, boolean wasMoving, float currentYaw, float currentPitch) {
        if (entity == null || mc.player == null) {
            return;
        }
        float reward = wasMoving ? MOVING_HIT_REWARD : HIT_REWARD;
        this.lastHitTimes.put(entity, System.currentTimeMillis());
        float totalReward = this.entityRewards.merge(entity, reward, Float::sum);

        float[] inputs = this.getEntityInputs(entity);
        float[] normalizedOutputs = JelloAI.normalizeRotations(currentYaw, currentPitch);
        // The same sample is queued several times so a hit weighs more in the pool than a miss.
        for (int i = 0; i < HIT_SAMPLE_REPEATS; ++i) {
            this.trainingManager.addTrainingSample(inputs, normalizedOutputs, reward);
        }
        this.neuralNetwork.trainNetworkImmediate(inputs, normalizedOutputs, reward);
        Client.logger.info("JelloAI: Recorded hit on " + entity.getName().getString() + " with reward " + reward + " (Total: " + totalReward + ")");
    }

    /**
     * Records a miss by queuing one sample with {@link #MISS_PENALTY} as its reward.
     *
     * <p>Does nothing if {@code entity}, {@code inputs} or {@code idealOutputs} is {@code null},
     * or there is no local player.
     *
     * @param entity the entity that was targeted
     * @param inputs feature vector for the attempt (presumably built like
     *     {@link #getEntityInputs(Entity)} — confirm with callers)
     * @param idealOutputs target outputs to associate with the penalty
     */
    public void recordMiss(Entity entity, float[] inputs, float[] idealOutputs) {
        if (entity == null || mc.player == null || inputs == null || idealOutputs == null) {
            return;
        }
        this.trainingManager.addTrainingSample(inputs, idealOutputs, MISS_PENALTY);
        Client.logger.info("JelloAI: Recorded miss with penalty " + MISS_PENALTY);
    }

    /**
     * Builds the {@value #INPUT_SIZE}-element feature vector describing {@code entity} relative
     * to the local player.
     *
     * <p>Layout:
     * <ul>
     *   <li>[0..2] eye-to-eye offset (x/20, y/10, z/20)</li>
     *   <li>[3..5] entity motion (x/2, y/2, z/2)</li>
     *   <li>[6] player yaw / 180</li>
     *   <li>[7] player pitch / 90</li>
     * </ul>
     *
     * @return the feature vector, or an all-zero vector if {@code entity} or the player is
     *     {@code null}
     */
    private float[] getEntityInputs(Entity entity) {
        float[] inputs = new float[INPUT_SIZE];
        // Read mc.player once so the null check and the reads below see the same instance.
        Entity player = mc.player;
        if (entity == null || player == null) {
            return inputs;
        }
        double diffX = entity.getPosX() - player.getPosX();
        double diffY = (entity.getPosY() + (double) entity.getEyeHeight())
                - (player.getPosY() + (double) player.getEyeHeight());
        double diffZ = entity.getPosZ() - player.getPosZ();
        inputs[0] = (float) (diffX / 20.0);
        inputs[1] = (float) (diffY / 10.0);
        inputs[2] = (float) (diffZ / 20.0);
        inputs[3] = (float) (entity.getMotion().x / 2.0);
        inputs[4] = (float) (entity.getMotion().y / 2.0);
        inputs[5] = (float) (entity.getMotion().z / 2.0);
        inputs[6] = player.rotationYaw / 180.0f;
        inputs[7] = player.rotationPitch / 90.0f;
        return inputs;
    }
}

