/*
 * Decompiled with CFR 0.152.
 */
package org.logicng.modelcounting;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.logicng.datastructures.Assignment;
import org.logicng.formulas.FType;
import org.logicng.formulas.Formula;
import org.logicng.formulas.FormulaFactory;
import org.logicng.formulas.Literal;
import org.logicng.formulas.Variable;
import org.logicng.graphs.algorithms.ConnectedComponentsComputation;
import org.logicng.graphs.datastructures.Graph;
import org.logicng.graphs.datastructures.Node;
import org.logicng.graphs.generators.ConstraintGraphGenerator;
import org.logicng.knowledgecompilation.dnnf.DnnfFactory;
import org.logicng.knowledgecompilation.dnnf.datastructures.Dnnf;
import org.logicng.knowledgecompilation.dnnf.functions.DnnfModelCountFunction;
import org.logicng.transformations.PureExpansionTransformation;
import org.logicng.transformations.cnf.CNFConfig;
import org.logicng.transformations.cnf.CNFEncoder;
import org.logicng.util.FormulaHelper;

/**
 * Model counting for collections of formulas.
 * <p>
 * The counting pipeline used by {@link #count(Collection, SortedSet)} is:
 * <ol>
 *   <li>each formula is expanded with {@link PureExpansionTransformation} and encoded as CNF,</li>
 *   <li>formulas that are single literals are collected into an assignment and all formulas are
 *       restricted by it; formulas that become {@code TRUE} are dropped,</li>
 *   <li>the remaining formulas are split by the connected components of their constraint graph,
 *       each component is compiled into a DNNF and its model count is computed; the component
 *       counts are multiplied,</li>
 *   <li>the result is multiplied by two for every requested variable that occurs neither in the
 *       simplified formulas nor in the collected literals.</li>
 * </ol>
 */
public final class ModelCounter {

    /** Utility class; not instantiable. */
    private ModelCounter() {
        // intentionally empty
    }

    /**
     * Computes the number of models of the conjunction of the given formulas over the given variables.
     *
     * @param formulas  the formulas whose conjunction is counted
     * @param variables the variables over which the models are counted; must contain every variable
     *                  occurring in {@code formulas}. When non-empty, the formula factory is taken
     *                  from its first variable.
     * @return the model count. If {@code variables} is empty, this is {@code 1} when every formula
     *         is of type {@code TRUE}, otherwise {@code 0}.
     * @throws IllegalArgumentException if {@code variables} does not contain all variables of
     *                                  {@code formulas}
     */
    public static BigInteger count(final Collection<Formula> formulas, final SortedSet<Variable> variables) {
        if (!variables.containsAll(FormulaHelper.variables(formulas))) {
            throw new IllegalArgumentException("Expected variables to contain all of the formulas' variables.");
        }
        if (variables.isEmpty()) {
            // The check above guarantees the formulas contain no variables either, so the only
            // question is whether every one of them is TRUE.
            final boolean allTrue = formulas.stream().allMatch(formula -> formula.type() == FType.TRUE);
            return allTrue ? BigInteger.ONE : BigInteger.ZERO;
        }
        final FormulaFactory f = variables.first().factory();
        final List<Formula> cnfs = encodeAsCnf(formulas, f);
        final SimplificationResult simplification = simplify(cnfs);
        final BigInteger count = countModelsByComponent(simplification.simplifiedFormulas, f);
        final SortedSet<Variable> dontCareVariables = simplification.getDontCareVariables(variables);
        // Every variable that is unconstrained in the simplified problem doubles the count,
        // i.e. the result is count * 2^|dontCareVariables|.
        return count.shiftLeft(dontCareVariables.size());
    }

    /**
     * Converts each formula to CNF.
     * <p>
     * All formulas are first expanded with {@link PureExpansionTransformation}. The expanded
     * formulas are then encoded with the {@code ADVANCED} CNF algorithm, configured to fall back
     * to {@code TSEITIN}.
     * NOTE(review): auxiliary variables from the Tseitin fallback end up in the counted CNFs; this
     * presumably relies on that encoding preserving the model count — confirm.
     *
     * @param formulas the formulas to encode
     * @param f        the formula factory used by the CNF encoder
     * @return the CNF encodings, in the iteration order of {@code formulas}
     */
    private static List<Formula> encodeAsCnf(final Collection<Formula> formulas, final FormulaFactory f) {
        final PureExpansionTransformation expander = new PureExpansionTransformation();
        final List<Formula> expandedFormulas = formulas.stream()
                .map(formula -> formula.transform(expander))
                .collect(Collectors.toList());
        final CNFEncoder cnfEncoder = new CNFEncoder(f, CNFConfig.builder()
                .algorithm(CNFConfig.Algorithm.ADVANCED)
                .fallbackAlgorithmForAdvancedEncoding(CNFConfig.Algorithm.TSEITIN)
                .build());
        return expandedFormulas.stream()
                .map(cnfEncoder::encode)
                .collect(Collectors.toList());
    }

    /**
     * Simplifies the formulas using the literals that occur as stand-alone formulas.
     * <p>
     * Every formula of type {@code LITERAL} is added to an assignment, and its variable is
     * recorded as a backbone variable. Every formula is then restricted by that assignment, and
     * formulas that become {@code TRUE} are dropped.
     *
     * @param formulas the (CNF) formulas to simplify
     * @return the remaining formulas together with the backbone variables
     */
    private static SimplificationResult simplify(final Collection<Formula> formulas) {
        final Assignment simpleBackbone = new Assignment();
        final SortedSet<Variable> backboneVariables = new TreeSet<>();
        for (final Formula formula : formulas) {
            if (formula.type() == FType.LITERAL) {
                final Literal lit = (Literal) formula;
                simpleBackbone.addLiteral(lit);
                backboneVariables.add(lit.variable());
            }
        }
        final List<Formula> simplified = new ArrayList<>();
        for (final Formula formula : formulas) {
            final Formula restricted = formula.restrict(simpleBackbone);
            if (restricted.type() != FType.TRUE) {
                simplified.add(restricted);
            }
        }
        return new SimplificationResult(simplified, backboneVariables);
    }

    /**
     * Counts the models of the conjunction of the given formulas, component by component.
     * <p>
     * The formulas are split by the connected components of their constraint graph. The
     * conjunction of each component is compiled into a DNNF, and the per-component model counts
     * are multiplied. For an empty collection no component is processed and {@code 1} is returned.
     *
     * @param formulas the formulas to count
     * @param f        the formula factory used to build each component's conjunction
     * @return the product of the per-component DNNF model counts
     */
    private static BigInteger countModelsByComponent(final Collection<Formula> formulas, final FormulaFactory f) {
        final Graph<Variable> constraintGraph = ConstraintGraphGenerator.generateFromFormulas(formulas);
        final Set<Set<Node<Variable>>> ccs = ConnectedComponentsComputation.compute(constraintGraph);
        final List<List<Formula>> components = ConnectedComponentsComputation.splitFormulasByComponent(formulas, ccs);
        final DnnfFactory factory = new DnnfFactory();
        BigInteger count = BigInteger.ONE;
        for (final List<Formula> component : components) {
            final Dnnf dnnf = factory.compile(f.and(component));
            count = count.multiply(dnnf.execute(DnnfModelCountFunction.get()));
        }
        return count;
    }

    /** The formulas left after simplification and the variables fixed by stand-alone literals. */
    private static final class SimplificationResult {
        private final List<Formula> simplifiedFormulas;
        private final SortedSet<Variable> backboneVariables;

        SimplificationResult(final List<Formula> simplifiedFormulas, final SortedSet<Variable> backboneVariables) {
            this.simplifiedFormulas = simplifiedFormulas;
            this.backboneVariables = backboneVariables;
        }

        /**
         * Returns the requested variables that occur neither in the simplified formulas nor
         * among the backbone variables.
         *
         * @param variables the requested variables; not modified
         * @return a new sorted set of the don't-care variables
         */
        SortedSet<Variable> getDontCareVariables(final SortedSet<Variable> variables) {
            final SortedSet<Variable> dontCareVariables = new TreeSet<>(variables);
            dontCareVariables.removeAll(FormulaHelper.variables(this.simplifiedFormulas));
            dontCareVariables.removeAll(this.backboneVariables);
            return dontCareVariables;
        }
    }
}

