/*
 * Decompiled with CFR 0.152.
 */
package ghidra.pcode.emu.jit.analysis;

import ghidra.app.plugin.processors.sleigh.SleighLanguage;
import ghidra.pcode.emu.jit.alloc.AlignedMpIntHandler;
import ghidra.pcode.emu.jit.alloc.DoubleVarAlloc;
import ghidra.pcode.emu.jit.alloc.FloatVarAlloc;
import ghidra.pcode.emu.jit.alloc.IntInIntHandler;
import ghidra.pcode.emu.jit.alloc.IntInLongHandler;
import ghidra.pcode.emu.jit.alloc.IntVarAlloc;
import ghidra.pcode.emu.jit.alloc.JvmLocal;
import ghidra.pcode.emu.jit.alloc.LongInLongHandler;
import ghidra.pcode.emu.jit.alloc.LongVarAlloc;
import ghidra.pcode.emu.jit.alloc.NoHandler;
import ghidra.pcode.emu.jit.alloc.ShiftedMpIntHandler;
import ghidra.pcode.emu.jit.alloc.SimpleVarHandler;
import ghidra.pcode.emu.jit.alloc.VarHandler;
import ghidra.pcode.emu.jit.analysis.JitAnalysisContext;
import ghidra.pcode.emu.jit.analysis.JitDataFlowModel;
import ghidra.pcode.emu.jit.analysis.JitType;
import ghidra.pcode.emu.jit.analysis.JitTypeBehavior;
import ghidra.pcode.emu.jit.analysis.JitTypeModel;
import ghidra.pcode.emu.jit.analysis.JitVarScopeModel;
import ghidra.pcode.emu.jit.gen.util.Local;
import ghidra.pcode.emu.jit.gen.util.Scope;
import ghidra.pcode.emu.jit.gen.util.Types;
import ghidra.pcode.emu.jit.var.JitConstVal;
import ghidra.pcode.emu.jit.var.JitFailVal;
import ghidra.pcode.emu.jit.var.JitMemoryVar;
import ghidra.pcode.emu.jit.var.JitVal;
import ghidra.pcode.emu.jit.var.JitVarnodeVar;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressFactory;
import ghidra.program.model.address.AddressSpace;
import ghidra.program.model.lang.Endian;
import ghidra.program.model.lang.Language;
import ghidra.program.model.lang.Register;
import ghidra.program.model.pcode.Varnode;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.TreeMap;

/**
 * Allocates JVM locals for the varnodes seen by a {@link JitDataFlowModel} and assigns a
 * {@link VarHandler} to every value in that model.
 *
 * <p>
 * Construction runs {@link #analyze()}, which tallies type "votes" for each coalesced varnode
 * (as reported by the {@link JitVarScopeModel}). No locals are declared until
 * {@link #allocate(Scope)} is called. At that point each coalesced varnode receives either a
 * single local of its winning simple type, or one local per leg type when the winner is a
 * {@link JitType.MpIntJitType}. After allocation, every value gets a handler describing how it
 * maps onto those locals.
 */
public class JitAllocationModel {
    private final JitDataFlowModel dfm;
    private final JitVarScopeModel vsm;
    private final JitTypeModel tm;
    private final SleighLanguage language;
    private final Endian endian;
    /** Handler per data-flow value, filled by {@link #allocate(Scope)}. */
    private final Map<JitVal, VarHandler> handlers = new HashMap<>();
    /** Cache so that values sharing a varnode share one handler instance. */
    private final Map<Varnode, VarHandler> handlersPerVarnode = new HashMap<>();
    /** Declared locals keyed by the start address of the varnode each one holds. */
    private final NavigableMap<Address, JvmLocal<?, ?>> locals = new TreeMap<>();
    /** Type votes per coalesced varnode, filled by {@link #analyze()}. */
    private final Map<Varnode, TypeContest> typeContests = new HashMap<>();

    /**
     * Creates the model and immediately runs the type-voting analysis.
     *
     * @param context supplies the language and its endianness
     * @param dfm the data flow model whose values need handlers
     * @param vsm the scope model used to find each varnode's coalesced varnode
     * @param tm the type model consulted for each value's type
     */
    public JitAllocationModel(JitAnalysisContext context, JitDataFlowModel dfm, JitVarScopeModel vsm, JitTypeModel tm) {
        this.dfm = dfm;
        this.vsm = vsm;
        this.tm = tm;
        this.endian = context.getEndian();
        this.language = context.getLanguage();
        this.analyze();
    }

    /**
     * Declares a single local of the given simple type in {@code scope}. The local is tied to
     * the varnode described by {@code desc}.
     */
    private <T extends Types.BPrim<?>, JT extends JitType.SimpleJitType<T, JT>> JvmLocal<T, JT> declareLocal(Scope scope, JT type, String name, VarDesc desc) {
        Local<T> local = scope.decl(type.bType(), name);
        return JvmLocal.of(local, type, desc.toVarnode());
    }

    /**
     * Declares one local per entry of {@code types}. The locals are laid out contiguously,
     * starting at {@code desc}'s offset, and local {@code i} is named {@code name + "_" + i}.
     *
     * @return an unmodifiable list of the declared locals, in the same order as {@code types}
     */
    private <T extends Types.BPrim<?>, JT extends JitType.SimpleJitType<T, JT>> List<JvmLocal<T, JT>> declareLocals(Scope scope, List<JT> types, String name, VarDesc desc) {
        List<JvmLocal<T, JT>> result = new ArrayList<>(types.size());
        long offset = desc.offset();
        for (int i = 0; i < types.size(); i++) {
            JT legType = types.get(i);
            VarDesc legDesc = new VarDesc(desc.spaceId(), offset, legType.size(), legType, language);
            result.add(declareLocal(scope, legType, name + "_" + i, legDesc));
            offset += legType.size();
        }
        return List.copyOf(result);
    }

    /**
     * Creates the handler for a varnode that exactly matches one local, choosing the handler
     * class by the local's type.
     *
     * @throws AssertionError if the local's type is not int, long, float, or double
     */
    @SuppressWarnings("unchecked")
    private <T extends Types.BPrim<?>, JT extends JitType.SimpleJitType<T, JT>> SimpleVarHandler<T, JT> createSimpleHandler(JvmLocal<T, JT> local) {
        SimpleVarHandler<?, ?> handler = switch (local.type()) {
            case JitType.IntJitType t -> new IntVarAlloc(local.castOf(t), t);
            case JitType.LongJitType t -> new LongVarAlloc(local.castOf(t), t);
            case JitType.FloatJitType t -> new FloatVarAlloc(local.castOf(t), t);
            case JitType.DoubleJitType t -> new DoubleVarAlloc(local.castOf(t), t);
            default -> throw new AssertionError();
        };
        // Safe: each arm was built from local.castOf(local.type()), i.e. the same T/JT.
        return (SimpleVarHandler<T, JT>) handler;
    }

    /**
     * Computes how many bytes {@code part}'s least-significant end lies from the
     * least-significant end of the storage spanning {@code first}..{@code last}.
     *
     * <p>
     * For little-endian languages this is measured from the low address of {@code first}. For
     * big-endian languages it is measured from the max address of {@code last}. The result is
     * 0 when {@code part} is itself its coalesced varnode.
     */
    private int computeByteShift(Varnode part, Varnode first, Varnode last) {
        if (vsm.getCoalesced(part).equals(part)) {
            return 0;
        }
        return switch (endian) {
            case BIG -> (int) JitVarScopeModel.maxAddr(last).subtract(JitVarScopeModel.maxAddr(part));
            case LITTLE -> (int) part.getAddress().subtract(first.getAddress());
        };
    }

    /**
     * Creates a handler for a varnode that does not exactly match a single local.
     *
     * <p>
     * If the varnode's integer type is simple and the varnode fits inside the local at or
     * below its address, the handler accesses a sub-piece of that one local. Otherwise the
     * varnode is treated as a multi-precision integer spanning every local from that one up to
     * the varnode's max address.
     *
     * @throws NullPointerException if no local is allocated at or below {@code vn}'s address
     */
    private VarHandler createComplicatedHandler(Varnode vn) {
        JitType type = JitTypeBehavior.INTEGER.type(vn.getSize());
        Map.Entry<Address, JvmLocal<?, ?>> firstEntry = locals.floorEntry(vn.getAddress());
        Objects.requireNonNull(firstEntry, () -> "No local allocated at or below " + vn);
        JvmLocal<?, ?> first = firstEntry.getValue();
        assert JitVarScopeModel.overlapsLeft(first.vn(), vn);

        if (type instanceof JitType.SimpleJitType<?, ?> st && first.vn().contains(JitVarScopeModel.maxAddr(vn))) {
            int byteShift = computeByteShift(vn, first.vn(), first.vn());
            return createPartialHandler(st, first, vn, byteShift);
        }
        return createMultiLocalHandler(vn, type, firstEntry.getKey());
    }

    /** Creates a handler for an int/long varnode held inside a single (possibly larger) local. */
    private VarHandler createPartialHandler(JitType.SimpleJitType<?, ?> type, JvmLocal<?, ?> local, Varnode vn, int byteShift) {
        return switch (type) {
            case JitType.IntJitType t -> switch (local.type()) {
                case JitType.IntJitType ct -> new IntInIntHandler(local.castOf(ct), t, vn, byteShift);
                case JitType.LongJitType ct -> new IntInLongHandler(local.castOf(ct), t, vn, byteShift);
                default -> throw new AssertionError();
            };
            case JitType.LongJitType t -> switch (local.type()) {
                case JitType.LongJitType ct -> new LongInLongHandler(local.castOf(ct), t, vn, byteShift);
                default -> throw new AssertionError();
            };
            default -> throw new AssertionError();
        };
    }

    /**
     * Creates a multi-precision handler over the int locals from {@code min} through
     * {@code vn}'s max address. For big-endian languages the legs are passed in reversed
     * address order.
     */
    private VarHandler createMultiLocalHandler(Varnode vn, JitType type, Address min) {
        JitType.MpIntJitType mpType = JitType.MpIntJitType.forSize(type.size());
        assert mpType.legsAlloc() > 1;
        List<JvmLocal<Types.TInt, JitType.IntJitType>> parts = new ArrayList<>();
        for (JvmLocal<?, ?> local : locals.subMap(min, true, JitVarScopeModel.maxAddr(vn), true).values()) {
            assert local.type() instanceof JitType.IntJitType;
            // Multi-precision values are only ever allocated as int legs (see allocate).
            @SuppressWarnings("unchecked")
            JvmLocal<Types.TInt, JitType.IntJitType> leg = (JvmLocal<Types.TInt, JitType.IntJitType>) local;
            parts.add(leg);
        }
        int byteShift = computeByteShift(vn, parts.getFirst().vn(), parts.getLast().vn());
        if (endian == Endian.BIG) {
            Collections.reverse(parts);
        }
        return byteShift == 0
                ? new AlignedMpIntHandler(parts, mpType, vn)
                : new ShiftedMpIntHandler(parts, mpType, vn, byteShift);
    }

    /**
     * Returns the (cached) handler for a varnode variable. An exact match with a declared
     * local yields a simple handler; anything else yields a complicated handler.
     */
    private VarHandler getOrCreateHandlerForVarnodeVar(JitVarnodeVar vv) {
        return handlersPerVarnode.computeIfAbsent(vv.varnode(), vn -> {
            JvmLocal<?, ?> exact = locals.get(vn.getAddress());
            if (exact != null && exact.vn().equals(vn)) {
                return createSimpleHandler(exact);
            }
            return createComplicatedHandler(vn);
        });
    }

    /**
     * Chooses the handler for one data-flow value. Constants, failure values, and memory
     * variables need no local, so they get {@link NoHandler#INSTANCE}.
     *
     * @throws AssertionError for any other kind of value
     */
    private VarHandler createHandler(JitVal v) {
        // JitMemoryVar must be tested before JitVarnodeVar.
        if (v instanceof JitConstVal || v instanceof JitFailVal || v instanceof JitMemoryVar) {
            return NoHandler.INSTANCE;
        }
        if (v instanceof JitVarnodeVar vv) {
            return getOrCreateHandlerForVarnodeVar(vv);
        }
        throw new AssertionError();
    }

    /**
     * Tallies a type vote for every non-memory varnode variable, grouped by coalesced varnode.
     *
     * <p>
     * A variable whose varnode is its own coalesced varnode votes with its type from the type
     * model. A variable that is only part of a coalesced varnode votes for an integer type the
     * size of the whole coalesced varnode.
     */
    private void analyze() {
        for (JitVal v : dfm.allValues()) {
            if (!(v instanceof JitVarnodeVar vv) || v instanceof JitMemoryVar) {
                continue;
            }
            Varnode vn = vv.varnode();
            Varnode coalesced = vsm.getCoalesced(vn);
            TypeContest contest = typeContests.computeIfAbsent(coalesced, k -> new TypeContest());
            contest.vote(vn.equals(coalesced)
                    ? tm.typeOf(v)
                    : JitTypeBehavior.INTEGER.type(coalesced.getSize()));
        }
    }

    /**
     * Declares locals in {@code scope} for every coalesced varnode, in address order, using
     * each varnode's winning type. It then assigns a handler to every value of the data flow
     * model.
     *
     * @param scope the scope in which to declare the locals
     * @throws AssertionError if a winning type is neither simple nor multi-precision
     */
    // Raw SimpleJitType: its self-referential type parameter cannot be captured from a wildcard.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void allocate(Scope scope) {
        List<Map.Entry<Varnode, TypeContest>> contests = typeContests.entrySet()
                .stream()
                .sorted(Comparator.comparing((Map.Entry<Varnode, TypeContest> e) -> e.getKey().getAddress()))
                .toList();
        for (Map.Entry<Varnode, TypeContest> entry : contests) {
            Varnode vn = entry.getKey();
            VarDesc desc = VarDesc.fromVarnode(vn, entry.getValue().winner(), language);
            switch (desc.type()) {
                case JitType.SimpleJitType t -> locals.put(vn.getAddress(), declareLocal(scope, t, desc.name(), desc));
                case JitType.MpIntJitType t -> {
                    for (JvmLocal<?, ?> leg : declareLocals(scope, t.legTypesBE(), desc.name(), desc)) {
                        locals.put(leg.vn().getAddress(), leg);
                    }
                }
                default -> throw new AssertionError();
            }
        }
        for (JitVal v : dfm.allValuesSorted()) {
            handlers.put(v, createHandler(v));
        }
    }

    /**
     * Returns the handler assigned to {@code v} by {@link #allocate(Scope)}, or {@code null}
     * if none was assigned.
     */
    public VarHandler getHandler(JitVal v) {
        return handlers.get(v);
    }

    /** Returns a read-only view of all declared locals, ordered by address. */
    public Collection<JvmLocal<?, ?>> allLocals() {
        return Collections.unmodifiableCollection(locals.values());
    }

    /**
     * Returns a read-only view of the locals that may hold part of {@code vn}. The view spans
     * from the nearest local at or below {@code vn}'s address through its max address.
     */
    public Collection<JvmLocal<?, ?>> localsForVn(Varnode vn) {
        Address min = vn.getAddress();
        Address floor = locals.floorKey(min);
        if (floor != null) {
            min = floor;
        }
        return Collections.unmodifiableCollection(
                locals.subMap(min, true, JitVarScopeModel.maxAddr(vn), true).values());
    }

    /**
     * Describes a storage location (space, offset, size) together with the JIT type chosen
     * for it. Used to name locals and to rebuild their varnodes.
     */
    private record VarDesc(int spaceId, long offset, int size, JitType type, Language language) {
        static VarDesc fromVarnode(Varnode vn, JitType type, Language language) {
            return new VarDesc(vn.getSpace(), vn.getOffset(), vn.getSize(), type, language);
        }

        /**
         * Builds a local name. It uses the register name when the location is exactly a
         * register, otherwise space id, hex offset, and size.
         */
        public String name() {
            AddressFactory factory = language.getAddressFactory();
            AddressSpace space = factory.getAddressSpace(spaceId);
            Register reg = language.getRegister(space, offset, size);
            if (reg != null) {
                return "%s_%d_%s".formatted(reg.getName(), size, type.nm());
            }
            return "s%d_%x_%d_%s".formatted(spaceId, offset, size, type.nm());
        }

        /** Rebuilds the varnode this description covers. */
        public Varnode toVarnode() {
            AddressFactory factory = language.getAddressFactory();
            return new Varnode(factory.getAddressSpace(spaceId).getAddress(offset), size);
        }
    }

    /** Counts votes for the type a coalesced varnode should be allocated with. */
    record TypeContest(Map<JitType, Integer> map) {
        public TypeContest() {
            this(new HashMap<>());
        }

        /** Records one vote for {@code type.ext()}. */
        public void vote(JitType type) {
            map.merge(type.ext(), 1, Integer::sum);
        }

        /**
         * Returns the type with the most votes, breaking ties by the smallest {@code pref()}.
         *
         * @throws java.util.NoSuchElementException if no votes were cast
         */
        public JitType winner() {
            int max = map.values().stream().max(Integer::compare).orElseThrow();
            return map.entrySet()
                    .stream()
                    .filter(e -> e.getValue() == max)
                    .map(Map.Entry::getKey)
                    .min(Comparator.comparing(JitType::pref))
                    .orElseThrow();
        }
    }
}

