/**
 * <copyright>
 *
 * Copyright (c) 2011 modelevolution.org
 * All rights reserved. This program and the accompanying materials are
 * made available under the terms of the Eclipse Public License v1.0 which
 * accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * </copyright>
 */
package org.modelevolution.multiview.diff.encoding.engine.impl;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.logging.Logger;

import org.eclipse.emf.common.util.BasicEList;
import org.eclipse.emf.common.util.EList;
import org.eclipse.emf.compare.match.metamodel.Side;
import org.eclipse.emf.ecore.EObject;
import org.modelevolution.multiview.Lifeline;
import org.modelevolution.multiview.Message;
import org.modelevolution.multiview.MultiviewModel;
import org.modelevolution.multiview.ReceiveEvent;
import org.modelevolution.multiview.Region;
import org.modelevolution.multiview.SendEvent;
import org.modelevolution.multiview.State;
import org.modelevolution.multiview.Transition;
import org.modelevolution.multiview.conflictreport.ConflictFragment;
import org.modelevolution.multiview.conflictreport.ConflictReport;
import org.modelevolution.multiview.conflictreport.ConflictReportFactory;
import org.modelevolution.multiview.diff.encoding.engine.ICnfParser;
import org.modelevolution.multiview.diff.encoding.engine.ILogObservable;
import org.modelevolution.multiview.merge.MergePosition;
import org.modelevolution.multiview.merge.MergeType;
import org.modelevolution.multiview.util.MessageComparator;
import org.modelversioning.conflicts.detection.impl.ThreeWayDiffProvider;

/**
 * @author <a href="mailto:brosch@big.tuwien.ac.at">Petra Brosch</a>
 * @author <a href="mailto:widl@big.tuwien.ac.at">Magdalena Widl</a>
 * 
 */
public class Sequence2SATDIMACSEncodingEngine extends SequenceEncodingEngine
		implements ICnfParser, ILogObservable {
	private static final Logger logger = Logger
			.getLogger("org.modelevolution.multiview.diff.encoding.engine.impl");

	/**
	 * The messages to merge
	 */
	protected TreeMap<MergePosition, SortedSet<Message>> messages;

	protected LinkedHashMap<Message, MergeType> messageToMergeType;

	protected LinkedHashMap<Message, HashSet<Integer>> messageToPositions;
	protected LinkedHashMap<Integer, HashSet<Message>> positionToMessages;

	protected LinkedHashMap<Message, HashSet<State>> messageToSourceStates;
	protected LinkedHashMap<Message, HashSet<State>> messageToTargetStates;

	protected int startPos;
	protected int endPos;

	public static String VAR_EXT_MESSAGE = "M";
	public static String VAR_EXT_STATE = "S";

	private LinkedHashMap<String, Integer> varNames;
	private LinkedHashMap<Integer, String> variableMap;
	private int noClauses = 0;

	public Sequence2SATDIMACSEncodingEngine() {
		super();
		varNames = new LinkedHashMap<String, Integer>();
	}

	private int getVarNumber(String var) {
		
		if (!varNames.containsKey(var)) {
			varNames.put(var, varNames.size()+1);
		}
		return varNames.get(var);
	}
	

	private int getTseitinVar() {
		int tseitin = varNames.size()+1;
		String var = tseitin+"";
		varNames.put(var+"", tseitin);
		
		return varNames.get(var);
	}

	private void setVariableMap() {
		variableMap = new LinkedHashMap<Integer, String>();
		for (String key : varNames.keySet()) {
			variableMap.put(varNames.get(key), key);
		}
	}

	@Override
	public String getVariable(int var) {
		return variableMap.get(var);
	}

	@Override
	public void generateEncoding(ThreeWayDiffProvider threeWayDiff, File output)
			throws IOException {
		super.generateEncoding(threeWayDiff, output);

		long startTime = System.currentTimeMillis();
		initMessages(threeWayDiff);
		logObservations.put(LogObservation.DIFF_POST, System.currentTimeMillis() - startTime);
		
		startTime = System.currentTimeMillis();
		FileWriter fstream = new FileWriter(output.getCanonicalPath() + ".tmp");
		BufferedWriter out = new BufferedWriter(fstream);

		 out.write("c Possible positions\n");
		printPossiblePositions(out);

		 out.write("c Relative positions\n");
		printRelativePositions(out);
		
		 out.write("c Positions\n");
		printPositions(out);

		 out.write("c State machines\n");
		printStateMachines(out);

		 out.write("c State machines XOR \n");
		printStatesXOR(out);

		 out.write("c Each index - source/target one state \n");
		printConstraintsStates(out);

		 out.write("c Path constraints \n");
		printConstraintsPaths(out);

		out.close();

		FileWriter fstream2 = new FileWriter(output.getCanonicalPath());
		BufferedWriter out2 = new BufferedWriter(fstream2);

		printMapping(out2);
		out2.write("p cnf " + (varNames.size()+1) + " " + noClauses + "\n");

		FileReader f = new FileReader(output.getCanonicalPath() + ".tmp");
		BufferedReader in = new BufferedReader(f);

		String line;

		while ((line = in.readLine()) != null) {
			out2.write(line + "\n");
		}

		out2.close();
		in.close();
		(new File(output.getCanonicalPath() + ".tmp")).delete();

		for (Lifeline l : lifelines) {
			l.unsetDummyStatemachine();
		}

		setVariableMap();
		logObservations.put(LogObservation.ENC, System.currentTimeMillis() - startTime);
		
		logObservations.put(LogObservation.CNF_VAR_COUNT, new Long(varNames.size()));
		logObservations.put(LogObservation.CNF_CLAUSE_COUNT, new Long(noClauses));
	}


	/*
	 * Constraint: The resulting sequence of messages must represent a path in
	 * each state machine
	 */
	private void printConstraintsPaths(BufferedWriter out) throws IOException {
		int tseitin;
		// out.write("# Path constraints \n");

		for (int position = startPos; position < endPos; position++) {

			for (Lifeline ll : lifelines) {

				Region sm = ll.getClass_().getStatemachine();

				if (sm.getStates().size() > 1) {

					/* Case next one */
					out.write("c Case next one\n");
					for (State stateStart : sm.getStates()) {
						tseitin = getTseitinVar();
						out.write("c tseitin " +tseitin+"\n");

						out.write("-"+getVarNumber(getVariable(ll, stateStart, "t",
										position))+" "+ tseitin + endClause());
		

						for (State stateIllegal : sm.getStates()) {
							if (!stateStart.equals(stateIllegal)) {
								out.write("-"+tseitin + " -"+getVarNumber(getVariable(ll,
												stateIllegal, "s", position + 1)) + endClause());
							}
						}

						/* Case next n */
						for (int j = position + 1; j < endPos; j++) {

							out.write("-"
									+ getVarNumber(getVariable(ll, stateStart,
											"t", position)));

							for (int k = position + 1; k <= j; k++) {

								out.write(" "
										+ getVarNumber(getVariable(ll,
												stateStart, "s", k)));

							}
							tseitin = getTseitinVar();
							out.write(" " + tseitin + endClause());

							for (State stateIllegal : sm.getStates()) {
								if (!stateStart.equals(stateIllegal)) {

									out.write("-" + tseitin+" -"
											+ getVarNumber(getVariable(ll,
													stateIllegal, "s", j + 1)) + endClause());
								}
							}

						}
					}
				}

			}
		}
	}

	private String endClause() {
		noClauses++;
		return " 0"+"\n";
	}

	/*
	 * Constraint: Before/after message receipt, some state machine must be in
	 * some state
	 */
	private void printConstraintsStates(BufferedWriter out) throws IOException {

		for (int position : positionToMessages.keySet()) {

			for (Lifeline ll : lifelines) {
				Region sm = ll.getClass_().getStatemachine();

				for (State state : sm.getStates()) {
					out.write(getVarNumber(getVariable(ll, state, "s", position))
							+ " ");
				}
			}
			out.write(endClause());

			for (Lifeline ll : lifelines) {
				Region sm = ll.getClass_().getStatemachine();

				for (State state : sm.getStates()) {
					out.write(getVarNumber(getVariable(ll, state, "t", position))
							+ " ");
				}
			}
			out.write(endClause());
		}

	}

	/* State machines */
	private void printStateMachines(BufferedWriter out) throws IOException {

		for (Message m : messageToPositions.keySet()) {

			for (int position : messageToPositions.get(m)) {

				Set<TseitinState> tseitinVars = new HashSet<TseitinState>();

				
				out.write("-"+getVarNumber(getVariable(
								messageToMergeType.get(m), m, position))+ " ");
	
				
				// Target and source receive
//				for (Lifeline lifeline : lifelines){	
					Lifeline lifeline = m.getReceiver().getLifeline();
					Region sm = lifeline.getClass_().getStatemachine();
					for (State state : sm.getStates()){	
						for (Transition transition : state.getOutgoing()){
							
							if(transition.getTrigger().getName().equals(m.getBody().getName())){
								
								int tseitin = getTseitinVar();	
								
								tseitinVars.add(new TseitinState(tseitin,getVarNumber(getVariable(lifeline, transition.getSource(), "s", position))));
								tseitinVars.add(new TseitinState(tseitin,getVarNumber(getVariable(lifeline, transition.getTarget(), "t", position))));

								out.write(tseitin+" ");

							}	
						}
					}
//				}
				out.write(endClause());
				
				for (TseitinState tseitin : tseitinVars){
					out.write("-"+tseitin.getTseitin()+" "+tseitin.getState()+endClause());
				}
			}
		}

	}
	
	/* Each source/target exactly one state */
	private void printStatesXOR(BufferedWriter out) throws IOException {

		// out.write("# State machines\n");

		for (int position = startPos; position <= endPos; position++) {

			for (Lifeline ll1 : lifelines) {
				
				Region sm1 = ll1.getClass_().getStatemachine();

				for (State state1 : sm1.getStates()) {
					for (Lifeline ll2 : lifelines) {

						Region sm2 = ll2.getClass_().getStatemachine();

						for (State state2 : sm2.getStates()) {

							if (!(ll1.equals(ll2) && state1.equals(state2))) {
								
								out.write("-"
										+ getVarNumber(getVariable(ll1, state1,
												"t", position))
										+ " -"
										+ getVarNumber(getVariable(ll2, state2,
												"t", position)) + endClause());
							}
						}
					}
				}
			}

			for (Lifeline ll1 : lifelines) {
				Region sm1 = ll1.getClass_().getStatemachine();

				for (State state1 : sm1.getStates()) {
					for (Lifeline ll2 : lifelines) {

						Region sm2 = ll2.getClass_().getStatemachine();

						for (State state2 : sm2.getStates()) {

							if (!(ll1.equals(ll2) && state1.equals(state2))) {

								out.write("-"
										+ getVarNumber(getVariable(ll1, state1,
												"s", position))
										+ " -"
										+ getVarNumber(getVariable(ll2, state2,
												"s", position)) + endClause());
							}
						}
					}
				}
			}
		}

	}

	/* Positioning part (1) (possible positions for each message) */
	private void printPossiblePositions(BufferedWriter out) throws IOException {

		// out.write("# Possible positions for each message\n");

		for (Message m : messageToPositions.keySet()) {
			MergeType mergeType = messageToMergeType.get(m);

			// System.out.print(m.getBody().getName() + "\t");

			for (int index : messageToPositions.get(m)) {

				// System.out.print("index, ");

				out.write(getVarNumber(getVariable(mergeType, m, index)) + " ");
			}

			// System.out.println();
			out.write(endClause());

		}
	}

	/* Positioning part (2) (relative positions of messages) */
	private void printRelativePositions(BufferedWriter out) throws IOException {

		// out.write("# Relative positions of messages\n");

		for (Message m : messageToPositions.keySet()) {
			MergeType mergeType = messageToMergeType.get(m);

			if (messageToPositions.get(m).size() > 1) {

				for (int startIndex : messageToPositions.get(m)) {

					for (Message n : messageToPositions.keySet()) {
						if ((n.getPosition() > m.getPosition())
								&& messageToMergeType.get(m) == messageToMergeType
										.get(n)
								&& messageToPositions.get(n).contains(
										startIndex)) {

							// HashSet<String> vars = new HashSet<String>();

							out.write("-"
									+ getVarNumber(getVariable(mergeType, m,
											startIndex)) + " ");
							for (int followIndex : messageToPositions.get(n)) {
								if (followIndex > startIndex)
									// vars.add(getVariable(mergeType,n.getPosition(),followIndex));
									out.write(getVarNumber(getVariable(
											mergeType, n, followIndex)) + " ");
							}
							out.write(endClause());
						

						}
					}
				}
			}
		}
	}

	/* Positioning part (3) (each position exactly one message) */
	private void printPositions(BufferedWriter out) throws IOException {

		for (int position : positionToMessages.keySet()) {

			EList<Message> currMessages = new BasicEList<Message>();

			for (Message m : positionToMessages.get(position)) {
				out.write(getVarNumber(getVariable(messageToMergeType.get(m),
						m, position)) + " ");
				currMessages.add(m);
			}
			out.write(endClause());

			for (int i = 0; i < currMessages.size(); i++) {
				for (int j = i; j < currMessages.size(); j++) {

					// FIXME this should be !m.equals(n)
					if (currMessages.get(i).getPosition() != currMessages
							.get(j).getPosition()
							|| messageToMergeType.get(currMessages.get(i)) != messageToMergeType
									.get(currMessages.get(j))) {

						out.write("-"
								+ getVarNumber(getVariable(messageToMergeType
										.get(currMessages.get(i)), currMessages
										.get(i), position))
								+ " -"
								+ getVarNumber(getVariable(messageToMergeType
										.get(currMessages.get(j)), currMessages
										.get(j), position)) + endClause());
					}

				}
			}
		}
	}

	private String getVariable(MergeType mergeType, Message m, int newIndex) {

		return newIndex + "-" + mergeType + "-" + m.getPosition() + "-"
				+ m.getBody().getName() + "_" + VAR_EXT_MESSAGE;
	}

	private String getVariable(Lifeline lifeline, State state, String p,
			int index) {
		
		return index + "-" + p + "-" + lifeline.getName() + "-"
				+ state.getName() + "_" + VAR_EXT_STATE;
	}

	/**
	 * Analyzes differences and initializes the list of messages to merge
	 */
	private void initMessages(ThreeWayDiffProvider threeWayDiff) {
		messages = new TreeMap<MergePosition, SortedSet<Message>>();
		messageToMergeType = new LinkedHashMap<Message, MergeType>();
		messageToPositions = new LinkedHashMap<Message, HashSet<Integer>>();
		positionToMessages = new LinkedHashMap<Integer, HashSet<Message>>();

		messageToSourceStates = new LinkedHashMap<Message, HashSet<State>>();
		messageToTargetStates = new LinkedHashMap<Message, HashSet<State>>();

		LinkedHashMap<Message, SortedSet<Message>> leftMergeGroups = new LinkedHashMap<Message, SortedSet<Message>>();
		LinkedHashMap<Message, SortedSet<Message>> rightMergeGroups = new LinkedHashMap<Message, SortedSet<Message>>();

		EList<EObject> leftAddedElements = threeWayDiff.getAddedEObjects(
				Side.LEFT, true);
		EList<EObject> rightAddedElements = threeWayDiff.getAddedEObjects(
				Side.RIGHT, true);

		initMergeGroups(leftAddedElements, leftMergeGroups, threeWayDiff,
				Side.LEFT);
		initMergeGroups(rightAddedElements, rightMergeGroups, threeWayDiff,
				Side.RIGHT);

		// calculate merge positions
		EList<Message> originMessages = ((MultiviewModel) threeWayDiff
				.getOriginModel().get(0)).getSequenceview()
				.getOrderedMessages();
		int shift = 1;

		shift += getMergeFragment(leftMergeGroups, rightMergeGroups, null, 0, 0);

		if (shift > 1 && conflictReport != null) {
			if (originMessages.size() > 0) {
				addConflictFragment(null, originMessages.get(0));
			} else if (originMessages.size() == 0) {
				addConflictFragment(null, null);
			}
		}

		for (int i = 0; i < originMessages.size(); i++) {
			Message originMessage = originMessages.get(i);
			SortedSet<Message> m = new TreeSet<Message>();
			m.add(originMessage);
			messages.put(new MergePosition(i + shift, MergeType.O), m);

			int oldShift = shift;
			shift = getMergeFragment(leftMergeGroups, rightMergeGroups,
					originMessage, shift, i);

			if (shift != oldShift && conflictReport != null) {
				Message nextOriginMsg = (i < originMessages.size() - 1) ? originMessages
						.get(i + 1) : null;
				addConflictFragment(originMessage, nextOriginMsg);
			}
		}

		for (MergePosition mp : messages.keySet()) {

			for (Message m : messages.get(mp)) {
				if (messageToPositions.containsKey(m)) {
					messageToPositions.get(m).add(mp.getIndex());
				} else {
					HashSet<Integer> positions = new HashSet<Integer>();
					positions.add(mp.getIndex());
					messageToPositions.put(m, positions);
				}

				messageToMergeType.put(m, mp.getType());
			}
		}

		for (MergePosition mp : messages.keySet()) {

			for (Message m : messages.get(mp)) {

				if (positionToMessages.containsKey(mp.getIndex())) {
					positionToMessages.get(mp.getIndex()).add(m);
				} else {
					positionToMessages.put(mp.getIndex(),
							new HashSet<Message>());
					positionToMessages.get(mp.getIndex()).add(m);
				}
			}
		}

		for (Message m : messageToPositions.keySet()) {

			HashSet<State> sourceStates = new HashSet<State>();
			HashSet<State> targetStates = new HashSet<State>();

			Lifeline ll = m.getReceiver().getLifeline();
			if (ll.getClass_() == null
					|| ll.getClass_().getStatemachine() == null) {
				ll.initDummyStatemachine();
			}

			Region smReceiver = ll.getClass_().getStatemachine();

			for (Transition t : m.getBody().getTriggers()) {
				if (smReceiver.getStates().contains(t.getSource())) {
					sourceStates.add(t.getSource());
				}
				if (smReceiver.getStates().contains(t.getTarget())) {
					targetStates.add(t.getTarget());
				}
			}

			messageToSourceStates.put(m, sourceStates);
			messageToTargetStates.put(m, targetStates);
		}

		startPos = 1;
		endPos = positionToMessages.size();

	}

	/**
	 * Adds a {@link ConflictFragment} to the {@link ConflictReport}
	 * 
	 * @param lastOriginMsg
	 *            The last origin message before the conflict
	 * @param nextOriginMsg
	 *            The next origin message after the conflict
	 */
	private void addConflictFragment(Message lastOriginMsg,
			Message nextOriginMsg) {
		if (conflictReport != null) {
			ConflictFragment confFrag = ConflictReportFactory.eINSTANCE
					.createConflictFragment();
			confFrag.setLastOrigin(lastOriginMsg);
			confFrag.setNextOrigin(nextOriginMsg);

			conflictReport.getConflicts().add(confFrag);
		}
	}

	/**
	 * Adds a merge fragment to the messages to merge
	 * 
	 * @param leftMergeGroups
	 * @param rightMergeGroups
	 * @param originMessage
	 * @param shift
	 * @return shift
	 */
	private int getMergeFragment(
			LinkedHashMap<Message, SortedSet<Message>> leftMergeGroups,
			LinkedHashMap<Message, SortedSet<Message>> rightMergeGroups,
			Message originMessage, int shift, int pos) {
		if (leftMergeGroups.get(originMessage) != null
				|| rightMergeGroups.get(originMessage) != null) {
			int leftGroupsSize = 0;
			int rightGroupsSize = 0;

			if (leftMergeGroups.get(originMessage) != null) {
				leftGroupsSize = leftMergeGroups.get(originMessage).size();
			}

			if (rightMergeGroups.get(originMessage) != null) {
				rightGroupsSize = rightMergeGroups.get(originMessage).size();
			}

			int mergeGroupsSize = leftGroupsSize + rightGroupsSize;

			for (int i = 1; i <= mergeGroupsSize; i++) {
				shift++;
				if (leftMergeGroups.get(originMessage) != null) {
					messages.put(new MergePosition(shift + pos, MergeType.L),
							leftMergeGroups.get(originMessage));
				}
				if (rightMergeGroups.get(originMessage) != null) {
					messages.put(new MergePosition(shift + pos, MergeType.R),
							rightMergeGroups.get(originMessage));
				}
			}
		}
		return shift;
	}

	/**
	 * Collects added messages to merge groups
	 * 
	 * @param addedElements
	 * @param mergeGroups
	 * @param threeWayDiff
	 */
	private void initMergeGroups(EList<EObject> addedElements,
			LinkedHashMap<Message, SortedSet<Message>> mergeGroups,
			ThreeWayDiffProvider threeWayDiff, Side side) {
		for (EObject addElement : addedElements) {
			Message addedMessage = null;

			// get the change's messages to find previous origin message
			if (addElement instanceof SendEvent) {
				addedMessage = ((SendEvent) addElement).getMessage();
			} else if (addElement instanceof ReceiveEvent) {
				addedMessage = ((ReceiveEvent) addElement).getMessage();
			} else {
				continue;
			}

			Message lastOrigin = getLastOrigin(addedMessage, side, threeWayDiff);

			if (mergeGroups.containsKey(lastOrigin)) {
				mergeGroups.get(lastOrigin).add(addedMessage);
			} else {
				SortedSet<Message> addedMessages = new TreeSet<Message>(
						new MessageComparator());
				addedMessages.add(addedMessage);
				mergeGroups.put(lastOrigin, addedMessages);
			}
		}
	}

	/**
	 * Determines the last previous original message
	 * 
	 * @param message
	 * @param side
	 * @param threeWayDiff
	 * @return originMessage
	 */
	private Message getLastOrigin(Message message, Side side,
			ThreeWayDiffProvider threeWayDiff) {
		Message previous = message.getPrevious();
		if (previous == null)
			return null;

		EObject originPrevious = threeWayDiff.getMatchingEObject(previous,
				side, true);
		if (originPrevious == null) {
			return getLastOrigin(previous, side, threeWayDiff);
		}

		else
			return (Message) originPrevious;
	}

	private SortedSet<Message> messagesAtIndexType(int index, MergeType mt) {
		MergePosition origin = new MergePosition(index, mt);
		return messages.get(origin);
	}
	

	private void printMapping(BufferedWriter out) {

		for (String v : varNames.keySet()) {
			try {
				out.write("c " + v + "\t\t" + (varNames.get(v)) + "\n");
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
		}
	}


	@Override
	public Long getLogObservation(LogObservation observation) {
		return logObservations.get(observation);
	}
}
