package inference.equivalence;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import models.algebra.Constant;
import models.algebra.Term;
import models.algebra.Variable;
import models.dataConstraintModel.DataConstraintModel;
import models.formulas.EquationFormula;
import models.terms.EvaluatableTerm;
import models.terms.LinearRightNormalizedType;
import models.terms.RDLTerm;
import models.terms.meta.MetaEvaluatableTermVariable;
import models.terms.meta.MetaRDLTerm;
import models.terms.meta.MetaResourceVariable;
import models.terms.meta.OrderConstraint;
import models.terms.meta.OrderVariableConstraint;

/**
 * Tries to prove a semantic-equivalence conclusion from a collection of assumed equivalences.
 *
 * <p>Assumptions are linear-right-normalized and bucketed by order. {@link #proof()} then walks
 * from the highest assumption order down to the conclusion's order. At each order it applies
 * rule 4, which derives relations of order {@code n} from relations of order {@code n+1}, and
 * records derived relations and newly appearing terms. Derivation back-pointers are kept in
 * {@code proofGraph} so the chain leading to the conclusion can be printed.
 *
 * <p>Not thread-safe: {@link #proof()} mutates the instance's internal state.
 */
public class SemanticEquivalenceProofSystem {

	/** Known relations (initial assumptions plus derived facts), keyed by order; all stored relations are linear-right-normalized. */
	private final Map<Integer, Set<SemanticEquivalenceRelation>> assumptions;
	/** The equation to be proved. */
	private final EquationFormula conclusion;
	/** Derivation back-pointers: maps a relation to the relation it was obtained from. */
	private final Map<SemanticEquivalenceRelation, SemanticEquivalenceRelation> proofGraph = new HashMap<>();
	/** Highest order among the initial assumptions, or -1 when there are none. */
	int maxOrder = -1;
	
	/**
	 * Rule 1 pattern (not applied yet, see {@link #applyRule1()}): relates the RDL term built from
	 * resource variable v at order n+1 (twice) and a linear-right-normalized term t of order n to
	 * t itself, as an order-n relation.
	 */
	private static final MetaSemanticEquivalenceRelation rule1 = new MetaSemanticEquivalenceRelation(
			new MetaRDLTerm(
					new MetaResourceVariable(new Variable("v"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))),
					new MetaResourceVariable(new Variable("v"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))),
					new MetaEvaluatableTermVariable(new Variable("t"), new Variable("n"), LinearRightNormalizedType.LINEAR_RIGHT_NORMALIZED)
			),
			new MetaEvaluatableTermVariable(new Variable("t"), new Variable("n"), LinearRightNormalizedType.LINEAR_RIGHT_NORMALIZED),
			new Variable("n")
	);
	
	/** Rule 4 premise: t and t', both linear-right-normalized terms of order n+1, are related at order n+1. */
	private static final MetaSemanticEquivalenceRelation rule4_1 = new MetaSemanticEquivalenceRelation(
			new MetaEvaluatableTermVariable(new Variable("t"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1"))), LinearRightNormalizedType.LINEAR_RIGHT_NORMALIZED),
			new MetaEvaluatableTermVariable(new Variable("t'"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1"))), LinearRightNormalizedType.LINEAR_RIGHT_NORMALIZED),
			new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))
	);
	
	/**
	 * Rule 4 conclusion: the RDL term (t, v(n+1), s) is related at order n to (t', v(n+1), s),
	 * where s is constrained by {@link OrderConstraint#LT} against n+1.
	 */
	private static final MetaSemanticEquivalenceRelation rule4_2 = new MetaSemanticEquivalenceRelation(
			new MetaRDLTerm(
					new MetaEvaluatableTermVariable(new Variable("t"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))),
					new MetaResourceVariable(new Variable("v"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))),
					new MetaEvaluatableTermVariable(new Variable("s"), OrderConstraint.LT, new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1"))))
			),
			new MetaRDLTerm(
					new MetaEvaluatableTermVariable(new Variable("t'"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))),
					new MetaResourceVariable(new Variable("v"), new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1")))),
					new MetaEvaluatableTermVariable(new Variable("s"), OrderConstraint.LT, new Term(DataConstraintModel.add, List.of(new Variable("n"), new Constant("1"))))
			),
			new Variable("n")
	);
	
	/**
	 * Creates a proof system for {@code conclusion} under {@code assumptions}.
	 *
	 * @param assumptions relations assumed to hold; each is stored in linear-right-normalized form,
	 *                    bucketed by the order of that stored form
	 * @param conclusion  the equation to prove
	 */
	public SemanticEquivalenceProofSystem(Collection<SemanticEquivalenceRelation> assumptions, EquationFormula conclusion) {
		this.assumptions = new HashMap<>();
		// Assigned outside the loop so an empty assumption collection still records the conclusion.
		this.conclusion = conclusion;
		for (SemanticEquivalenceRelation assumption : assumptions) {
			SemanticEquivalenceRelation normalized = assumption.linearRightNormalized();
			// Key by the stored relation's own order, matching how applyRule4 stores derived facts.
			maxOrder = Math.max(maxOrder, normalized.getOrder());
			this.assumptions.computeIfAbsent(normalized.getOrder(), order -> new HashSet<>()).add(normalized);
		}
	}
	
	/**
	 * Runs the derivation and reports whether the conclusion was established.
	 *
	 * <p>Side effects: adds derived relations to the internal assumption buckets, records
	 * derivation back-pointers, and prints the derivation chain for the conclusion to standard
	 * output. Repeated calls accumulate on this state.
	 *
	 * @return {@code true} if the linear-right-normalized conclusion is among the known relations
	 *         at the conclusion's order after derivation
	 */
	public boolean proof() {
		EquationFormula linearRightNormalizedConclusion = conclusion.linearRightNormalized();
		int conclusionOrder = linearRightNormalizedConclusion.getLeftSideHand().getOrder();

		// Terms that rule instantiation is allowed to use.
		Set<RDLTerm> existTerms = new HashSet<>();
		existTerms.addAll(linearRightNormalizedConclusion.getLeftSideHand().getSubTerms(EvaluatableTerm.class).values());
		existTerms.addAll(linearRightNormalizedConclusion.getRightSideHand().getSubTerms(EvaluatableTerm.class).values());

		SemanticEquivalenceRelation conclusionRelation = new SemanticEquivalenceRelation(conclusion.getLeftSideHand(), conclusion.getRightSideHand(), conclusion.getLeftSideHand().getOrder());
		SemanticEquivalenceRelation linearConclusionRelation = new SemanticEquivalenceRelation(linearRightNormalizedConclusion.getLeftSideHand(), linearRightNormalizedConclusion.getRightSideHand(), conclusionOrder);
		if (! linearRightNormalizedConclusion.equals(conclusion)) {
			proofGraph.put(conclusionRelation, linearConclusionRelation);
		}
		for (Set<SemanticEquivalenceRelation> bucket : assumptions.values()) {
			for (SemanticEquivalenceRelation relation : bucket) {
				addTermsOf(relation, existTerms);
			}
		}

		// Step down one order at a time: rule 4 turns order-(n+1) facts into order-n facts.
		for (int order = maxOrder; order > conclusionOrder; order--) {
			// Snapshot: applyRule4 inserts into the buckets and must not invalidate this iteration.
			List<SemanticEquivalenceRelation> current = new ArrayList<>(assumptions.getOrDefault(order, Collections.emptySet()));
			Set<RDLTerm> appearedTerms = new HashSet<>();
			for (SemanticEquivalenceRelation relation : current) {
				appearedTerms.addAll(applyRule4(relation, existTerms));
			}
			existTerms.addAll(appearedTerms);
		}

		printProofChain(conclusionRelation);
		// The derivation loop stops at the conclusion's order, so look the conclusion up there.
		return assumptions.getOrDefault(conclusionOrder, Collections.emptySet()).contains(linearConclusionRelation);
	}
	
	/**
	 * Prints the derivation chain ending at {@code start}, in derivation order (origin first),
	 * separating consecutive relations with a line of '=' characters. Stops if the recorded
	 * back-pointers form a cycle.
	 */
	private void printProofChain(SemanticEquivalenceRelation start) {
		List<SemanticEquivalenceRelation> chain = new ArrayList<>();
		Set<SemanticEquivalenceRelation> visited = new HashSet<>();
		SemanticEquivalenceRelation current = start;
		visited.add(current);
		chain.add(current);
		while (proofGraph.containsKey(current)) {
			current = proofGraph.get(current);
			if (! visited.add(current)) {
				break; // cycle in proofGraph: stop instead of looping forever
			}
			chain.add(current);
		}
		Collections.reverse(chain);
		for (int i = 0; i < chain.size(); i++) {
			System.out.println(chain.get(i));
			if (i < chain.size() - 1) {
				System.out.println("==============================================================");
			}
		}
	}
	
	/** Placeholder for applying {@code rule1}; currently does nothing and is not called. */
	private void applyRule1() {
		// TODO: not implemented yet.
	}
	
	/**
	 * Applies rule 4 to {@code relation}. On a match, every instantiation of {@code rule4_2}
	 * over {@code existTerms} is recorded: its normalized form is added to the known relations,
	 * and back-pointers are added to {@code proofGraph}.
	 *
	 * @param relation   a known relation to match against the rule-4 premise
	 * @param existTerms terms available for instantiating the rule-4 conclusion
	 * @return the terms (both sides plus evaluatable sub-terms) of every relation derived here;
	 *         empty if the premise does not match
	 */
	private Set<RDLTerm> applyRule4(SemanticEquivalenceRelation relation, Set<RDLTerm> existTerms) {
		Map<Variable, RDLTerm> binding = new HashMap<>();
		Map<Variable, OrderVariableConstraint> orderConstraint = new HashMap<>();
		Set<RDLTerm> appearedTerms = new HashSet<>();
		if (! rule4_1.isMatchedBy(relation, binding, orderConstraint)) {
			return appearedTerms;
		}
		for (SemanticEquivalenceRelation result : rule4_2.substitute(binding, orderConstraint, existTerms)) {
			proofGraph.put(result, relation);
			SemanticEquivalenceRelation linear = result.linearRightNormalized();
			// Bucket under the order of the relation actually stored (previously the existence
			// check used result's order while insertion used linear's order).
			assumptions.computeIfAbsent(linear.getOrder(), order -> new HashSet<>()).add(linear);
			if (! linear.equals(result)) {
				proofGraph.put(linear, result);
			}
			addTermsOf(linear, appearedTerms);
		}
		return appearedTerms;
	}
	
	/** Adds both sides of {@code relation} and all of their evaluatable sub-terms to {@code target}. */
	private static void addTermsOf(SemanticEquivalenceRelation relation, Set<RDLTerm> target) {
		target.addAll(relation.getLeftSideHand().getSubTerms(EvaluatableTerm.class).values());
		target.add(relation.getLeftSideHand());
		target.addAll(relation.getRightSideHand().getSubTerms(EvaluatableTerm.class).values());
		target.add(relation.getRightSideHand());
	}
}
