package inference;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import models.algebra.Variable;
import models.formulas.EquationFormula;
import models.formulas.Formula;
import models.formulas.meta.MetaDependencyFormula;
import models.formulas.meta.MetaEquationFormula;
import models.formulas.meta.MetaInFormula;
import models.terms.EvaluatableTerm;
import models.terms.meta.MetaEvaluatableTermVariable;
import models.terms.meta.MetaRDLTerm;
import models.terms.meta.MetaResourceVariable;
import utils.Product;

/**
 * Forward-chaining prover for formulas over evaluatable terms and resources.
 *
 * <p>{@link #check(Collection, Formula)} starts from a set of assumptions and, round after round,
 * applies every rule in {@code axioms} to all formulas known so far. It stops when the requested
 * conclusion has been derived or when a round adds nothing new.
 *
 * <p>The rules are documented below in a compact notation: {@code A, B => C} means "from assumptions
 * A and B derive C". In that notation {@code s = t} is a {@link MetaEquationFormula},
 * {@code dep(s, r)} a {@link MetaDependencyFormula}, {@code in(s, t)} a {@link MetaInFormula}, and
 * {@code RDL(...)} a {@link MetaRDLTerm} built from the same arguments.
 */
public class ProofSystem {

	/** When true, {@link #check} prints each saturation round and every rule firing to stdout. */
	private static boolean trace = true;

	// (no assumptions)  =>  te = te
	private static final InferenceRule reflexivity = new InferenceRule(
			"Reflexivity",
			List.of(),
			new MetaEquationFormula(termVar("te"), termVar("te"))
	);

	// te = se  =>  se = te
	private static final InferenceRule symmetry = new InferenceRule(
			"Symmetry",
			List.of(
					new MetaEquationFormula(termVar("te"), termVar("se"))
			),
			new MetaEquationFormula(termVar("se"), termVar("te"))
	);

	// se = te, te = ue  =>  se = ue
	private static final InferenceRule transitivity = new InferenceRule(
			"Transitivity",
			List.of(
					new MetaEquationFormula(termVar("se"), termVar("te")),
					new MetaEquationFormula(termVar("te"), termVar("ue"))
			),
			new MetaEquationFormula(termVar("se"), termVar("ue"))
	);

	// te = ue, dep(se, r), in(te, RDL(r))  =>  RDL(se, r, te) = RDL(se, r, ue)
	private static final InferenceRule rightSubstitution = new InferenceRule(
			"Right Substitution",
			List.of(
					new MetaEquationFormula(termVar("te"), termVar("ue")),
					new MetaDependencyFormula(termVar("se"), resourceVar("r")),
					new MetaInFormula(termVar("te"), new MetaRDLTerm(resourceVar("r")))
			),
			new MetaEquationFormula(
					new MetaRDLTerm(termVar("se"), resourceVar("r"), termVar("te")),
					new MetaRDLTerm(termVar("se"), resourceVar("r"), termVar("ue"))
			)
	);

	// se = te, dep(se, r), in(ue, RDL(r))  =>  RDL(se, r, ue) = RDL(te, r, ue)
	private static final InferenceRule leftSubstitution = new InferenceRule(
			"Left Substitution",
			List.of(
					new MetaEquationFormula(termVar("se"), termVar("te")),
					new MetaDependencyFormula(termVar("se"), resourceVar("r")),
					new MetaInFormula(termVar("ue"), new MetaRDLTerm(resourceVar("r")))
			),
			new MetaEquationFormula(
					new MetaRDLTerm(termVar("se"), resourceVar("r"), termVar("ue")),
					new MetaRDLTerm(termVar("te"), resourceVar("r"), termVar("ue"))
			)
	);

	// in(te, RDL(r))  =>  RDL(r, r, te) = te
	private static final InferenceRule identity = new InferenceRule(
			"Identity",
			List.of(
					new MetaInFormula(termVar("te"), new MetaRDLTerm(resourceVar("r")))
			),
			new MetaEquationFormula(
					new MetaRDLTerm(resourceVar("r"), resourceVar("r"), termVar("te")),
					termVar("te")
			)
	);

	// dep(se, r), dep(r, p), in(te, RDL(p))  =>  RDL(se, r, RDL(r, p, te)) = RDL(se, p, te)
	private static final InferenceRule mapComposition = new InferenceRule(
			"Map Composition",
			List.of(
					new MetaDependencyFormula(termVar("se"), resourceVar("r")),
					new MetaDependencyFormula(resourceVar("r"), resourceVar("p")),
					new MetaInFormula(termVar("te"), new MetaRDLTerm(resourceVar("p")))
			),
			new MetaEquationFormula(
					new MetaRDLTerm(
							termVar("se"),
							resourceVar("r"),
							new MetaRDLTerm(resourceVar("r"), resourceVar("p"), termVar("te"))
					),
					new MetaRDLTerm(termVar("se"), resourceVar("p"), termVar("te"))
			)
	);

	// in(se, ue), se = te  =>  in(te, ue)
	private static final InferenceRule memberSubstitution = new InferenceRule(
			"Member Substitution",
			List.of(
					new MetaInFormula(termVar("se"), termVar("ue")),
					new MetaEquationFormula(termVar("se"), termVar("te"))
			),
			new MetaInFormula(termVar("te"), termVar("ue"))
	);

	// in(se, te), in(te, ue)  =>  in(se, ue)
	private static final InferenceRule membershipChain = new InferenceRule(
			"Membership Chain",
			List.of(
					new MetaInFormula(termVar("se"), termVar("te")),
					new MetaInFormula(termVar("te"), termVar("ue"))
			),
			new MetaInFormula(termVar("se"), termVar("ue"))
	);

	// dep(RDL(se, r1, t1), r2), in(RDL(t1, r2, t2), RDL(RDL(r1), r2, t2))
	//   =>  in(RDL(RDL(se, r1, t1), r2, t2), RDL(RDL(se), r2, t2))
	private static final InferenceRule codomainMembership = new InferenceRule(
			"Codomain Membership",
			List.of(
					new MetaDependencyFormula(
							new MetaRDLTerm(termVar("se"), resourceVar("r1"), termVar("t1")),
							resourceVar("r2")
					),
					new MetaInFormula(
							new MetaRDLTerm(termVar("t1"), resourceVar("r2"), termVar("t2")),
							new MetaRDLTerm(
									new MetaRDLTerm(resourceVar("r1")),
									resourceVar("r2"),
									termVar("t2")
							)
					)
			),
			new MetaInFormula(
					new MetaRDLTerm(
							new MetaRDLTerm(termVar("se"), resourceVar("r1"), termVar("t1")),
							resourceVar("r2"),
							termVar("t2")
					),
					new MetaRDLTerm(
							// NOTE(review): "se" is a *term* variable in every other position of this
							// rule, but here it is wrapped as a resource variable. The mirrored
							// assumption above uses resourceVar("r1") in this slot. Confirm whether
							// "r1" was intended, or how the matcher binds same-named variables of
							// different kinds.
							new MetaRDLTerm(resourceVar("se")),
							resourceVar("r2"),
							termVar("t2")
					)
			)
	);

	/** Every rule that {@link #check} tries in each saturation round, in application order. */
	private static final List<InferenceRule> axioms = List.of(
			reflexivity,
			symmetry,
			transitivity,
			rightSubstitution,
			leftSubstitution,
			identity,
			mapComposition,
			memberSubstitution,
			membershipChain,
			codomainMembership
	);

	/** Shorthand for {@code new MetaEvaluatableTermVariable(new Variable(name))}. */
	private static MetaEvaluatableTermVariable termVar(String name) {
		return new MetaEvaluatableTermVariable(new Variable(name));
	}

	/** Shorthand for {@code new MetaResourceVariable(new Variable(name))}. */
	private static MetaResourceVariable resourceVar(String name) {
		return new MetaResourceVariable(new Variable(name));
	}

	/**
	 * Prints every rule in {@code axioms} to stdout, each followed by a separator line.
	 */
	public static void debug() {
		for (var axiom : axioms) {
			System.out.println(axiom);
			System.out.println("==================================================================================================");
		}
	}

	/**
	 * Enables or disables the diagnostic output that {@link #check} prints while it searches.
	 * Tracing is enabled by default.
	 *
	 * @param enabled {@code true} to print each round and each rule firing, {@code false} to stay silent
	 */
	public static void setTrace(boolean enabled) {
		trace = enabled;
	}

	/**
	 * Decides whether {@code conclusion} is derivable from {@code assumptions} by saturating the set of
	 * known formulas with the rules in {@code axioms}.
	 *
	 * <p>This overload places no practical limit on the number of rounds. Termination therefore relies
	 * on the rules reaching a fixpoint. Use {@link #check(Collection, Formula, int)} to cap the search.
	 *
	 * @param assumptions formulas taken as given
	 * @param conclusion the formula to derive
	 * @return {@code true} if the conclusion was derived, {@code false} if a fixpoint was reached first
	 */
	public static boolean check(Collection<Formula> assumptions, Formula conclusion) {
		return check(assumptions, conclusion, Integer.MAX_VALUE);
	}

	/**
	 * Like {@link #check(Collection, Formula)}, but gives up after {@code maxRounds} saturation rounds.
	 *
	 * @param assumptions formulas taken as given
	 * @param conclusion the formula to derive
	 * @param maxRounds maximum number of rounds in which all axioms are applied; must be non-negative
	 * @return {@code true} if the conclusion is among the assumptions or was derived within the limit;
	 *         {@code false} if a fixpoint was reached or the round limit was exhausted
	 * @throws IllegalArgumentException if {@code maxRounds} is negative
	 */
	public static boolean check(Collection<Formula> assumptions, Formula conclusion, int maxRounds) {
		if (maxRounds < 0) {
			throw new IllegalArgumentException("maxRounds must be non-negative: " + maxRounds);
		}
		Set<Formula> known = new HashSet<>(assumptions);
		for (int round = 0; !known.contains(conclusion); round++) {
			if (round >= maxRounds) {
				return false;
			}
			if (trace) {
				System.out.println("=========================================");
				System.out.println(known);
			}
			Set<Formula> derived = new HashSet<>();
			for (InferenceRule axiom : axioms) {
				derived.addAll(applyAxiom(axiom, known));
			}
			// Set.addAll reports whether anything new was added. If nothing was added, the set is
			// at a fixpoint and the conclusion can never appear.
			if (!known.addAll(derived)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Applies one rule to every combination of known formulas that match its assumption patterns.
	 *
	 * @param axiom the rule to apply
	 * @param formulas the formulas known so far; not modified
	 * @return the (possibly empty) set of formulas produced by {@code axiom.apply}
	 */
	private static Set<Formula> applyAxiom(InferenceRule axiom, Set<Formula> formulas) {
		int assumptionCount = axiom.getAssumptionSize();
		// candidates.get(i) holds every known formula that matches the i-th assumption pattern
		List<List<Formula>> candidates = new ArrayList<>(assumptionCount);
		for (int i = 0; i < assumptionCount; i++) {
			var pattern = axiom.getAssumptions().get(i);
			List<Formula> matches = new ArrayList<>();
			for (Formula formula : formulas) {
				if (pattern.isMatchedBy(formula)) {
					matches.add(formula);
				}
			}
			candidates.add(matches);
		}

		Set<Formula> result = new HashSet<>();
		for (List<Formula> combination : Product.product(candidates)) {
			result.add(axiom.apply(combination));
		}
		if (trace && !result.isEmpty()) {
			System.out.println(axiom.getName());
			System.out.println(result);
		}
		return result;
	}

	/**
	 * Checks an equation goal by searching only the equations among the assumptions. The search is a
	 * breadth-first walk over the undirected graph that links both sides of each equation.
	 *
	 * <p>Currently not called from {@link #check}; kept as a possible fast path for equation goals.
	 *
	 * @param assumptions formulas taken as given; non-equations are ignored
	 * @param conclusion the goal; anything other than an {@link EquationFormula} yields {@code false}
	 * @return {@code true} if the right-hand side is reachable from the left-hand side
	 */
	private static boolean equationTransitionCheck(Collection<Formula> assumptions, Formula conclusion) {
		if (!(conclusion instanceof EquationFormula)) {
			return false;
		}
		EquationFormula goal = (EquationFormula) conclusion;
		EvaluatableTerm start = goal.getLeftSideHand();
		EvaluatableTerm target = goal.getRightSideHand();
		Map<EvaluatableTerm, List<EvaluatableTerm>> graph = constructEquationGraph(assumptions);

		Deque<EvaluatableTerm> queue = new ArrayDeque<>();
		Set<EvaluatableTerm> visited = new HashSet<>();
		queue.add(start);
		// Mark the start node as visited too. Otherwise it is re-enqueued from its own neighbours.
		visited.add(start);
		while (!queue.isEmpty()) {
			EvaluatableTerm current = queue.pollFirst();
			if (current.equals(target)) {
				return true;
			}
			for (EvaluatableTerm next : graph.getOrDefault(current, List.of())) {
				if (visited.add(next)) {
					queue.add(next);
				}
			}
		}
		return false;
	}

	/**
	 * Builds an undirected adjacency map from the equations among {@code assumptions}. Each equation
	 * {@code l = r} adds {@code r} to the neighbours of {@code l} and {@code l} to the neighbours of
	 * {@code r}.
	 *
	 * @param assumptions formulas to scan; non-equations are ignored
	 * @return a mutable map from each term appearing in an equation to its neighbour list
	 */
	private static Map<EvaluatableTerm, List<EvaluatableTerm>> constructEquationGraph(Collection<Formula> assumptions) {
		Map<EvaluatableTerm, List<EvaluatableTerm>> graph = new HashMap<>();
		for (Formula assumption : assumptions) {
			if (!(assumption instanceof EquationFormula)) {
				continue;
			}
			EquationFormula equation = (EquationFormula) assumption;
			EvaluatableTerm lhs = equation.getLeftSideHand();
			EvaluatableTerm rhs = equation.getRightSideHand();
			graph.computeIfAbsent(lhs, key -> new ArrayList<>()).add(rhs);
			graph.computeIfAbsent(rhs, key -> new ArrayList<>()).add(lhs);
		}
		return graph;
	}
}
