package inference;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import models.formulas.EquationFormula;
import models.formulas.Formula;
import models.terms.EvaluatableTerm;

/**
 * Static helpers that decide whether a conclusion can be derived from a set
 * of assumptions.
 *
 * <p>The only inference currently implemented is chaining of equations: every
 * {@link EquationFormula} among the assumptions is treated as an undirected
 * edge between its two sides, and an equation conclusion is accepted when its
 * two sides are connected through such edges.
 */
public class ProofSystem {

	/**
	 * Checks whether {@code conclusion} follows from {@code assumptions}.
	 *
	 * @param assumptions formulas taken as given; entries that are not
	 *                    {@link EquationFormula} instances are ignored
	 * @param conclusion  the formula to verify
	 * @return {@code true} if the conclusion is an equation whose two sides
	 *         are linked by a chain of assumed equations, {@code false}
	 *         otherwise
	 */
	public static boolean check(Collection<Formula> assumptions, Formula conclusion) {
		return equationTransitionCheck(assumptions, conclusion);
	}

	/**
	 * Breadth-first search from the left side of the conclusion towards its
	 * right side over the graph built from the assumed equations.
	 *
	 * <p>A conclusion whose two sides are already equal is accepted without
	 * any assumptions, because the start term is compared against the target
	 * before any edge is followed.
	 *
	 * @param assumptions formulas from which the equation graph is built
	 * @param conclusion  the formula to verify
	 * @return {@code true} if the right side is reachable from the left side;
	 *         {@code false} if it is not, or if the conclusion is not an
	 *         {@link EquationFormula}
	 */
	private static boolean equationTransitionCheck(Collection<Formula> assumptions, Formula conclusion) {
		if (! (conclusion instanceof EquationFormula)) {
			return false;
		}
		EquationFormula equationConclusion = (EquationFormula) conclusion;
		EvaluatableTerm start = equationConclusion.getLeftSideHand();
		EvaluatableTerm target = equationConclusion.getRightSideHand();

		Map<EvaluatableTerm, List<EvaluatableTerm>> graph = constructEquationGraph(assumptions);
		Deque<EvaluatableTerm> que = new ArrayDeque<>();
		Set<EvaluatableTerm> visited = new HashSet<>();
		// Mark the start term up front so a back-edge cannot enqueue it a second time.
		visited.add(start);
		que.add(start);
		while (! que.isEmpty()) {
			EvaluatableTerm curNode = que.pollFirst();
			if (curNode.equals(target)) {
				return true;
			}
			List<EvaluatableTerm> neighbours = graph.get(curNode);
			if (neighbours == null) {
				// Term does not occur in any assumed equation: nothing to expand.
				continue;
			}
			for (EvaluatableTerm nextNode : neighbours) {
				// Set.add returns false when the term was already seen.
				if (visited.add(nextNode)) {
					que.add(nextNode);
				}
			}
		}
		return false;
	}

	/**
	 * Builds an undirected adjacency list from the equations among the
	 * assumptions: for every assumed equation, each side is recorded as a
	 * neighbour of the other.
	 *
	 * <p>Terms are used as hash-map keys, so this relies on
	 * {@code EvaluatableTerm} providing consistent {@code equals} and
	 * {@code hashCode} implementations (not visible here; verify in the term
	 * classes).
	 *
	 * @param assumptions formulas to scan; non-equation formulas are skipped
	 * @return a map from each term occurring in an assumed equation to the
	 *         terms it is directly equated with
	 */
	private static Map<EvaluatableTerm, List<EvaluatableTerm>> constructEquationGraph(Collection<Formula> assumptions) {
		Map<EvaluatableTerm, List<EvaluatableTerm>> graph = new HashMap<>();
		for (Formula assumption : assumptions) {
			if (! (assumption instanceof EquationFormula)) {
				continue;
			}
			EquationFormula equation = (EquationFormula) assumption;
			EvaluatableTerm lsh = equation.getLeftSideHand();
			EvaluatableTerm rsh = equation.getRightSideHand();
			graph.computeIfAbsent(lsh, key -> new ArrayList<>()).add(rsh);
			graph.computeIfAbsent(rsh, key -> new ArrayList<>()).add(lsh);
		}
		return graph;
	}
}
