/*
 * Copyright (C) 2009-2010 Institute for Computational Biomedicine,
 *                    Weill Medical College of Cornell University
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.campagnelab.goby.modes;

import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import org.campagnelab.goby.algorithmic.data.Annotation;
import org.campagnelab.goby.algorithmic.data.ranges.Range;
import org.campagnelab.goby.algorithmic.data.ranges.Ranges;
import org.campagnelab.goby.alignments.AlignmentReaderImpl;
import org.campagnelab.goby.alignments.ConcatSortedAlignmentReader;
import org.campagnelab.goby.alignments.ReferenceLocation;
import edu.cornell.med.icb.identifier.DoubleIndexedIdentifier;
import it.unimi.dsi.fastutil.objects.*;
import org.apache.commons.io.IOUtils;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.*;

/**
 * Suggests how to slice one or more compact alignments by genomic position so that each slice
 * holds a roughly equal share of the alignment data. Each suggested slice is written as one
 * tab-delimited line (start and end location) to the output file, or to standard output when no
 * output file is given. Progress/diagnostic messages go to standard error so they never mix with
 * the slice output.
 *
 * @author Fabien Campagne
 */
public class SuggestPositionSlicesMode extends AbstractGobyMode {
    /**
     * The mode name.
     */
    private static final String MODE_NAME = "suggest-position-slices";

    /**
     * The mode description help text.
     */
    private static final String MODE_DESCRIPTION = "Suggest how to slice an alignment by position to yield roughly equally sized slices. ";

    /**
     * Header line of the output. Printed with println (never printf) because it contains literal '%' characters.
     */
    private static final String OUTPUT_HEADER =
            "targetIdStart\t%positionStart\tstart:(ref,pos)\ttargetIdEnd\t%positionEnd\tend:(ref,pos)";

    /**
     * Format of one slice line: start target id, start position, "id,position" of the start, then the same
     * three columns for the end of the slice.
     */
    private static final String SLICE_FORMAT = "%s\t%d\t%s,%d\t%s\t%d\t%s,%d%n";

    /**
     * The output file, or null to write to standard output.
     */
    private String outputFilename;

    /**
     * The basenames of the input compact alignments.
     */
    private String[] basenames;
    /**
     * Value passed to ConcatSortedAlignmentReader#getLocations when slicing by number of slices.
     */
    private int modulo;
    /**
     * Requested number of slices. In --number-of-bytes mode this is overwritten with the number of
     * locations returned by the reader.
     */
    private int numberOfSlices;
    private int bufferLength;
    /**
     * Optional: file name for an annotation file.
     */
    private String annotationFilename;
    /**
     * Number of bytes per slice, used when slicing with --number-of-bytes.
     */
    private int numBytesPerSlice = -1;
    /**
     * True when slicing by --number-of-slices, false when slicing by --number-of-bytes.
     */
    private boolean useModulo;
    /**
     * When this switch is active, slices that would span two chromosomes are truncated to the end of the first
     * spanned chromosome. The next slice starts at the beginning of the next chromosome.
     */
    private boolean restrictPerChromosome;
    /**
     * Number of end-of-chromosome breakpoints introduced by {@link #restrictPerChromosome}.
     */
    private int numBreakPointAdded = 0;


    @Override
    public String getModeName() {
        return MODE_NAME;
    }

    @Override
    public String getModeDescription() {
        return MODE_DESCRIPTION;
    }

    /**
     * Configure.
     *
     * @param args command line arguments
     * @return this object for chaining
     * @throws java.io.IOException                    error parsing
     * @throws com.martiansoftware.jsap.JSAPException error parsing
     */
    @Override
    public AbstractCommandLineMode configure(final String[] args)
            throws IOException, JSAPException {
        final JSAPResult jsapResult = parseJsapArguments(args);

        final String[] inputFiles = jsapResult.getStringArray("input");
        basenames = AlignmentReaderImpl.getBasenames(inputFiles);
        annotationFilename = jsapResult.getString("annotations");
        outputFilename = jsapResult.getString("output");
        modulo = jsapResult.getInt("modulo");
        numberOfSlices = jsapResult.getInt("number-of-slices");
        numBytesPerSlice = jsapResult.getInt("number-of-bytes");
        if (!(jsapResult.userSpecified("number-of-slices") ||
                jsapResult.userSpecified("number-of-bytes"))) {
            System.err.println("You must specify either --number-of-bytes or --number-of-slices");
            System.exit(1);
        }
        restrictPerChromosome = jsapResult.getBoolean("restrict-per-chromosome");
        if (jsapResult.userSpecified("number-of-slices") && jsapResult.userSpecified("number-of-bytes")) {
            System.err.println("You must select either number-of-slices or number-of-bytes, but not both. ");
            System.exit(1);
        }
        useModulo = jsapResult.userSpecified("number-of-slices");
        // Reject values that would otherwise fail later with an obscure exception (e.g. a negative array size).
        if (useModulo && numberOfSlices < 1) {
            System.err.println("--number-of-slices must be at least 1.");
            System.exit(1);
        }
        if (!useModulo && numBytesPerSlice < 1) {
            System.err.println("--number-of-bytes must be at least 1.");
            System.exit(1);
        }
        if (!useModulo) System.err.println("Splitting with " + numBytesPerSlice);
        return this;
    }


    /**
     * Suggests slices to process a large alignment file in parallel and writes them to the output.
     *
     * @throws java.io.IOException error reading / writing
     */
    @Override
    public void execute() throws IOException {
        PrintStream stream = null;
        ConcatSortedAlignmentReader input = null;

        try {
            stream = outputFilename == null ? System.out
                    : new PrintStream(new FileOutputStream(outputFilename));

            input = new ConcatSortedAlignmentReader(basenames);
            input.readHeader();

            final DoubleIndexedIdentifier ids = new DoubleIndexedIdentifier(input.getTargetIdentifiers());

            Ranges ranges = null;
            if (annotationFilename != null) {
                ranges = convertAnnotationsToRanges(annotationFilename, ids, input);
            }

            ReferenceLocation[] breakpoints = useModulo ? getReferenceLocationsWithModulo(input, ids) :
                    getReferenceLocationsWithBytes(input, ids);

            if (ranges != null) {
                adjustBreakpointsWithAnnotations(breakpoints, ranges);
            }
            breakpoints = removeDuplicates(breakpoints);
            if (restrictPerChromosome) {
                breakpoints = restrictPerChromosome(breakpoints, input);
            }
            assertPostcondition(breakpoints);

            stream.println(OUTPUT_HEADER);
            for (int i = 0; i < breakpoints.length - 1; i++) {
                final int startTargetIndex = breakpoints[i].targetIndex;
                final int endTargetIndex = breakpoints[i + 1].targetIndex;
                final int startPosition = breakpoints[i].position;
                final int endPosition = breakpoints[i + 1].position;

                // When restricting per chromosome, slices that cross a chromosome boundary are not emitted.
                if (!restrictPerChromosome || startTargetIndex == endTargetIndex) {
                    // Pass the ids as arguments, never as part of the format: an id may contain '%'.
                    stream.printf(SLICE_FORMAT,
                            ids.getId(startTargetIndex),
                            startPosition,
                            ids.getId(startTargetIndex),
                            startPosition,
                            ids.getId(endTargetIndex),
                            endPosition,
                            ids.getId(endTargetIndex),
                            endPosition);
                }
            }

        } finally {
            if (stream != null && stream != System.out) {
                IOUtils.closeQuietly(stream);
            }
            if (input != null) {
                input.close();
            }
            // Standard error: standard output may be carrying the slice table.
            System.err.println("Done");
        }
    }

    /**
     * Removes duplicate breakpoints (as determined by ReferenceLocation equals/hashCode) and returns the
     * remaining breakpoints in natural (sorted) order.
     *
     * @param breakpoints breakpoints, possibly unsorted and with duplicates
     * @return sorted breakpoints without duplicates
     */
    private ReferenceLocation[] removeDuplicates(final ReferenceLocation[] breakpoints) {
        final Set<ReferenceLocation> unique = new HashSet<ReferenceLocation>(Arrays.asList(breakpoints));
        final List<ReferenceLocation> sorted = new ArrayList<ReferenceLocation>(unique);
        Collections.sort(sorted);
        return sorted.toArray(new ReferenceLocation[sorted.size()]);
    }

    /**
     * Check that slices do not overlap. If they do, raise an assertion (only when assertions are enabled).
     *
     * @param breakpoints sorted breakpoints; consecutive pairs delimit slices
     */
    private void assertPostcondition(final ReferenceLocation[] breakpoints) {
        for (int i = 0; i < breakpoints.length - 1; i++) {
            final int startTargetIndex = breakpoints[i].targetIndex;
            final int endTargetIndex = breakpoints[i + 1].targetIndex;
            final int startPosition = breakpoints[i].position;
            final int endPosition = breakpoints[i + 1].position;

            if (startTargetIndex == endTargetIndex) {
                assert endPosition > startPosition : "Positions must increase in slices: at index=" + i
                        + " breakpoints[index]=" + startPosition + " breakpoints[index+1]=" + endPosition;
            }
        }
    }

    /**
     * Inserts breakpoints so that no emitted slice spans two chromosomes: when the breakpoints move to a new
     * chromosome, a breakpoint is added at the last position of the previous chromosome and, if needed, at
     * position 0 of the new chromosome. Chromosomes that lie between two consecutive breakpoints and carry no
     * breakpoint of their own receive a start and an end breakpoint, so they are still covered by a slice.
     *
     * @param breakpoints sorted breakpoints
     * @param reader      reader used to obtain target lengths
     * @return the breakpoints with chromosome boundaries added, still sorted
     */
    private ReferenceLocation[] restrictPerChromosome(final ReferenceLocation[] breakpoints,
                                                      final ConcatSortedAlignmentReader reader) {
        final ObjectArrayList<ReferenceLocation> result = new ObjectArrayList<ReferenceLocation>();
        int lastTargetIndex = -1;
        for (final ReferenceLocation breakpoint : breakpoints) {
            final int targetIndex = breakpoint.targetIndex;
            if (targetIndex != lastTargetIndex) {
                if (lastTargetIndex != -1) {
                    // we switch to a new chromosome, introduce a new breakpoint at the end of the previous chromosome:
                    addEndOfTargetBreakpoint(result, lastTargetIndex, reader);
                    // Chromosomes without breakpoints would otherwise be hidden inside a skipped cross-chromosome slice:
                    for (int skipped = lastTargetIndex + 1; skipped < targetIndex; skipped++) {
                        result.add(new ReferenceLocation(skipped, 0));
                        addEndOfTargetBreakpoint(result, skipped, reader);
                    }
                    if (breakpoint.position != 0) {
                        result.add(new ReferenceLocation(targetIndex, 0));
                    }
                }
                lastTargetIndex = targetIndex;
            }
            result.add(breakpoint);
        }
        return result.toArray(new ReferenceLocation[result.size()]);
    }

    /**
     * Appends a breakpoint at the last position (length - 1) of the given target, unless the list already ends
     * with a breakpoint at or beyond that position on the same target (which would create an empty or
     * backwards slice).
     */
    private void addEndOfTargetBreakpoint(final ObjectArrayList<ReferenceLocation> result, final int targetIndex,
                                          final ConcatSortedAlignmentReader reader) {
        final int endPosition = reader.getTargetLength(targetIndex) - 1;
        final ReferenceLocation previous = result.isEmpty() ? null : result.get(result.size() - 1);
        if (previous == null || previous.targetIndex != targetIndex || previous.position < endPosition) {
            result.add(new ReferenceLocation(targetIndex, endPosition));
            System.err.println("Adding breakpoint at end of " + targetIndex);
            numBreakPointAdded++;
        }
    }

    /**
     * Computes breakpoints from the reader's byte-based locations; one slice per location.
     */
    private ReferenceLocation[] getReferenceLocationsWithBytes(final ConcatSortedAlignmentReader input,
                                                               final DoubleIndexedIdentifier ids) throws IOException {
        final ObjectList<ReferenceLocation> locations = input.getLocationsByBytes(numBytesPerSlice);
        numberOfSlices = locations.size();
        return prepareBreakpoints(input, ids, locations);
    }

    /**
     * Computes breakpoints from the reader's modulo-based locations, capping the number of slices at the
     * number of available locations.
     */
    private ReferenceLocation[] getReferenceLocationsWithModulo(final ConcatSortedAlignmentReader input,
                                                                final DoubleIndexedIdentifier ids) throws IOException {
        final ObjectList<ReferenceLocation> locations = input.getLocations(modulo);

        if (locations.size() < numberOfSlices) {
            numberOfSlices = locations.size();
        }

        return prepareBreakpoints(input, ids, locations);
    }

    /**
     * Selects numberOfSlices evenly spaced locations as slice starts and appends a final breakpoint at the
     * length of the last target. Requires numberOfSlices to be at most locations.size().
     *
     * @param input     reader used to obtain the length of the last target
     * @param ids       target identifiers (only the count is used here)
     * @param locations candidate slice starts, in sorted order; not modified
     * @return numberOfSlices + 1 breakpoints
     * @throws IllegalStateException when the reader returned no locations
     */
    private ReferenceLocation[] prepareBreakpoints(final ConcatSortedAlignmentReader input,
                                                   final DoubleIndexedIdentifier ids,
                                                   final ObjectList<ReferenceLocation> locations) {
        if (locations.isEmpty() || numberOfSlices < 1) {
            throw new IllegalStateException("No alignment locations available: cannot suggest position slices.");
        }
        final int numLocations = locations.size();
        final ReferenceLocation[] breakpoints = new ReferenceLocation[numberOfSlices + 1];
        for (int i = 0; i < numberOfSlices; i++) {
            // Evenly spaced indices; long arithmetic avoids overflow of numLocations * i.
            // When numLocations == numberOfSlices (bytes mode), this selects every location.
            final int index = (int) ((long) numLocations * i / numberOfSlices);
            breakpoints[i] = locations.get(index);
        }

        // largest position in the last reference sequence:
        final int lastTargetIndex = ids.size() - 1;
        breakpoints[numberOfSlices] = new ReferenceLocation(lastTargetIndex, input.getTargetLength(lastTargetIndex));
        return breakpoints;
    }

    /**
     * Moves each breakpoint to the middle of the range returned by Ranges#findNonOverlappingRange for its
     * location, so that slices avoid cutting through annotations.
     * NOTE(review): the final (end-of-genome) breakpoint is adjusted too; confirm this cannot truncate the last slice.
     *
     * @param breakpoints breakpoints to adjust in place
     * @param ranges      annotation ranges
     */
    private void adjustBreakpointsWithAnnotations(final ReferenceLocation[] breakpoints, final Ranges ranges) {
        ranges.order();
        for (final ReferenceLocation breakpoint : breakpoints) {
            final Range nonOverlapping = ranges.findNonOverlappingRange(breakpoint.targetIndex, breakpoint.position);

            // change the breakpoint to a position in the middle of the non-overlapping segment
            // (long arithmetic: min + max may overflow an int when the range is open-ended):
            breakpoint.position = (int) (((long) nonOverlapping.min + nonOverlapping.max) / 2);
        }
    }

    /**
     * Reads annotations and converts each annotation id into one range spanning from its smallest start to
     * its largest end, on the reference index of the id's chromosome.
     *
     * @param annotationFilename annotation file, or null
     * @param ids                target identifiers used to map chromosome names to indices
     * @param input              the alignment reader (unused)
     * @return ordered ranges, or null when annotationFilename is null
     * @throws IOException error reading the annotation file
     */
    private Ranges convertAnnotationsToRanges(final String annotationFilename, final DoubleIndexedIdentifier ids,
                                              final ConcatSortedAlignmentReader input) throws IOException {
        if (annotationFilename == null) return null;
        final Object2ObjectMap<String, ObjectList<Annotation>> annotations =
                CompactAlignmentToAnnotationCountsMode.readAnnotations(annotationFilename);
        final Object2IntMap<String> starts = new Object2IntLinkedOpenHashMap<String>();
        final Object2IntMap<String> ends = new Object2IntLinkedOpenHashMap<String>();
        final Object2IntMap<String> refIndices = new Object2IntLinkedOpenHashMap<String>();
        // Neutral defaults so the first Math.min / Math.max for an id takes the annotation's value.
        starts.defaultReturnValue(Integer.MAX_VALUE);
        ends.defaultReturnValue(Integer.MIN_VALUE);

        for (final ObjectList<Annotation> value : annotations.values()) {
            for (final Annotation ann : value) {
                final String chromosome = ann.getChromosome();
                assert chromosome != null : " annotation must contain non null chromosome.";
                final String id = ann.getId();
                refIndices.put(id, ids.getIndex(chromosome));
                starts.put(id, Math.min(starts.getInt(id), ann.getStart()));
                ends.put(id, Math.max(ends.getInt(id), ann.getEnd()));
            }
        }

        final Ranges ranges = new Ranges();
        for (final String id : starts.keySet()) {
            final Range range = new Range();
            range.min = starts.getInt(id);
            range.max = ends.getInt(id);
            ranges.add(range, refIndices.getInt(id));
        }

        ranges.order();
        return ranges;
    }

    /**
     * Main method.
     *
     * @param args command line args.
     * @throws com.martiansoftware.jsap.JSAPException error parsing
     * @throws java.io.IOException                    error parsing or executing.
     */

    public static void main(final String[] args) throws JSAPException, IOException {
        new SuggestPositionSlicesMode().configure(args).execute();
    }
}