#! /usr/bin/env python
# Author: Martin C. Frith 2019
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import print_function

import gzip
import itertools
import logging
import optparse
import string
import sys
from operator import itemgetter

def openFile(fileName):
    """Return a readable text stream for fileName.

    "-" selects standard input.  A name ending in ".gz" is opened via
    gzip in text mode.  Any other name is opened with the built-in open().
    """
    isGzipped = fileName.endswith(".gz")
    if not isGzipped:
        return sys.stdin if fileName == "-" else open(fileName)
    return gzip.open(fileName, "rt")  # xxx dubious for Python2

def connectedComponent(adjacencyList, isNew, i):
    """Yield vertex i and every vertex reachable from it via new vertices.

    adjacencyList[v] lists the neighbors of vertex v.  isNew[v] is true
    while v has not been reached; it is set to False here as soon as v is
    queued, so no vertex is yielded twice.  Vertices are produced in
    stack (last-queued-first) order, beginning with i.
    """
    isNew[i] = False
    pending = [i]
    while pending:
        vertex = pending.pop()
        yield vertex
        for neighbor in adjacencyList[vertex]:
            if not isNew[neighbor]:
                continue
            isNew[neighbor] = False
            pending.append(neighbor)

def connectedComponents(adjacencyList):
    """Yield each connected component as a sorted list of vertex numbers.

    adjacencyList[v] lists the neighbors of vertex v.  Components are
    produced in increasing order of their lowest vertex number.
    """
    vertexCount = len(adjacencyList)
    isNew = [True for _ in range(vertexCount)]
    for start in range(vertexCount):
        if not isNew[start]:
            continue
        members = list(connectedComponent(adjacencyList, isNew, start))
        members.sort()
        yield members

def isSeqRange(text):
    """Return True if text has the form NAME:BEG<END or NAME:BEG>END.

    The text must contain exactly one ":".  The part after it is checked
    for "<" and then for ">": the first of these that occurs exactly once
    is taken as the separator, and the text on both sides of it must
    consist of digits.
    """
    if text.count(":") != 1:
        return False
    coords = text.partition(":")[2]
    for separator in ("<", ">"):
        if coords.count(separator) == 1:
            beg, _, end = coords.partition(separator)
            return beg.isdigit() and end.isdigit()
    return False

def isAllSeqRanges(fields):
    """Return True if every item of fields passes isSeqRange.

    An empty collection of fields gives True.
    """
    for field in fields:
        if not isSeqRange(field):
            return False
    return True

def seqRangeFromText(text):
    """Parse "NAME:BEG<END" or "NAME:BEG>END" into (NAME, BEG, END).

    BEG and END are returned as ints.  Raises Exception if they are equal.
    """
    sequenceName, coords = text.split(":")
    separator = "<" if "<" in coords else ">"
    beg, end = (int(part) for part in coords.split(separator))
    if beg == end:
        raise Exception("zero-length segment: " + text)
    return sequenceName, beg, end

def rearrangementsFromLines(lines):
    """Yield (groupName, seqRanges) for each rearrangement found in lines.

    Only lines that start with "# " carry information:
      * one with exactly 2 whitespace-separated fields names a group
        (the name is its second field);
      * after a name, the first one with more than 2 fields, whose fields
        from the third onward are all sequence ranges, starts the list
        of ranges;
      * directly following ones whose fields after the "#" are all
        sequence ranges extend that list.
    The first line that does not extend the list ends the rearrangement.
    Any line with fewer than 2 fields restarts the search for a name.
    """
    seekName, seekRanges, inRanges, finished = range(4)
    state = seekName
    for line in lines:
        fields = line.split()
        fieldCount = len(fields)
        isSharp = line.startswith("# ")
        isNameLine = isSharp and fieldCount == 2

        if state == inRanges:
            if isSharp and fieldCount > 1 and isAllSeqRanges(fields[1:]):
                seqRanges.extend(seqRangeFromText(i) for i in fields[1:])
            else:
                yield groupName, seqRanges
                state = finished  # ignore lines until the next reset
        elif state == seekRanges:
            if isNameLine:
                groupName = fields[1]
            elif isSharp and fieldCount > 2 and isAllSeqRanges(fields[2:]):
                seqRanges = [seqRangeFromText(i) for i in fields[2:]]
                state = inRanges
        elif state == seekName:
            if isNameLine:
                groupName = fields[1]
                state = seekRanges

        if fieldCount < 2:
            state = seekName

    # The input may end while ranges are still being collected.
    if state == inRanges:
        yield groupName, seqRanges

def wantedRearrangements(opts, rearrangements):
    """Yield the rearrangements selected by the comma-separated opts.groups.

    A rearrangement (groupName, seqRanges) is kept if any of these is
    listed: the whole groupName; the part of groupName before its first
    "-"; or that part with its first 5 characters removed.
    """
    wanted = frozenset(opts.groups.split(","))
    for rearrangement in rearrangements:
        groupName, _ = rearrangement
        prefix = groupName.partition("-")[0]
        candidates = (groupName, prefix, prefix[5:])
        if any(candidate in wanted for candidate in candidates):
            yield rearrangement

def nodesFromRearrangements(rearrangements):
    """Yield two nodes per rearrangement, one for each of its outer ends.

    Each node is (chrom, midpoint, isLowerEnd, rearrangementNum, endNum).
    endNum 0 is taken from the first segment and endNum 1 from the last;
    midpoint is (beg + end) // 2 of that segment.  isLowerEnd is True for
    endNum 0 when the first segment has beg < end, and for endNum 1 when
    the last segment has end < beg.
    """
    for rearrangementNum, (_, segments) in enumerate(rearrangements):
        outerSegments = (segments[0], segments[-1])
        for endNum, (chrom, beg, end) in enumerate(outerSegments):
            if endNum == 0:
                isLowerEnd = beg < end
            else:
                isLowerEnd = end < beg
            yield chrom, (beg + end) // 2, isLowerEnd, rearrangementNum, endNum

def showNodesOfOneChromosome(rearrangements, sortedNodesOfOneChromosome):
    """Write a listing of one chromosome's nodes to standard error.

    The chromosome name (taken from the first node) is printed first,
    then one line per node showing its position, "[=" for a lower end or
    "=]" for an upper end, and the name of its rearrangement.  A blank
    line ends the listing.
    """
    out = sys.stderr
    chromName = sortedNodesOfOneChromosome[0][0]
    print(chromName + ":", file=out)
    for node in sortedNodesOfOneChromosome:
        _, pos, isLowerEnd, rearrangementNum = node[:4]
        endSymbol = "[=" if isLowerEnd else "=]"
        groupName = rearrangements[rearrangementNum][0]
        print("{0:9}  {1}  {2}".format(pos, endSymbol, groupName), file=out)
    print(file=out)

# See: https://en.wikipedia.org/wiki/Matching_(graph_theory)
def numOfMaximumMatchings(sortedNodesOfOneChromosome):
    """Count the largest-size ways of pairing lower ends with upper ends.

    Nodes are scanned in the given order; node[2] is true for a lower
    end.  A lower end may be paired with any upper end that comes before
    it.  Returns the number of distinct pairings that have the greatest
    attainable number of pairs (1 if no pair is possible).
    """
    upperEndCount = 0
    counts = [1]  # counts[k] = number of matchings that have k edges
    for node in sortedNodesOfOneChromosome:
        if not node[2]:
            upperEndCount += 1
            continue
        if len(counts) <= upperEndCount:
            counts.append(0)  # room for a matching with one more edge
        # Update from high k to low k, so that counts[k - 1] is still the
        # value from before this lower end was considered.
        for k in range(len(counts) - 1, 0, -1):
            counts[k] += counts[k - 1] * (upperEndCount - k + 1)
    return counts[-1]

def edgesFromNodes(sortedNodesOfOneChromosome):
    """Yield one set of links between upper ends and later lower ends.

    Nodes are scanned in the given order.  Each lower end (node[2] true)
    is linked to the most recently seen upper end that is still unlinked;
    a lower end with no such upper end stays unlinked.  Each edge is the
    pair (node[3:5] of the upper end, node[3:5] of the lower end).
    """
    unlinkedUppers = []
    for node in sortedNodesOfOneChromosome:
        if not node[2]:
            unlinkedUppers.append(node)
        elif unlinkedUppers:
            upper = unlinkedUppers.pop()
            yield upper[3:5], node[3:5]

def allEdgeSetsFromNodes(sortedNodesOfOneChromosome):
    """Yield every largest-size set of links between upper and lower ends.

    Nodes are scanned in the given order; node[2] is true for a lower
    end, which may be linked to any upper end that comes before it.  All
    matchings are enumerated, then those with the most edges are yielded,
    each as a list of (upper node[3:5], lower node[3:5]) pairs.
    """
    nodes = sortedNodesOfOneChromosome
    upperIndexes = []
    matchings = [([], [])]  # start from the matching that has no edges
    for index, node in enumerate(nodes):
        if not node[2]:
            upperIndexes.append(index)
            continue
        # Pair this lower end with each upper end that a matching has
        # not used yet, keeping the unextended matchings too.
        extended = [(uppers + [u], lowers + [index])
                    for uppers, lowers in matchings
                    for u in upperIndexes if u not in uppers]
        matchings += extended
    maxEdges = max(len(uppers) for uppers, _ in matchings)
    for uppers, lowers in matchings:
        if len(uppers) != maxEdges:
            continue
        yield [(nodes[u][3:5], nodes[v][3:5]) for u, v in zip(uppers, lowers)]

def nextNode(edges, node):
    """Return the node at the other end of the first edge touching node.

    edges is an iterable of (x, y) node pairs.  Returns None if no edge
    has node as either endpoint.
    """
    partners = (y if x == node else x
                for x, y in edges if x == node or y == node)
    return next(partners, None)

def takeLinkedNodes(edges, isUsed, rearrangementNum, endNum):
    """Yield (rearrangementNum, isFlipped) for rearrangements chained onto one end.

    Starting at end endNum of the given rearrangement, edges are followed
    repeatedly: across an edge to an end of another rearrangement, then
    on from that rearrangement's opposite end.  The walk stops where no
    edge is found, or on reaching a rearrangement already marked in
    isUsed.  Each rearrangement reached is marked in isUsed.  isFlipped
    is 1 if the rearrangement was reached at its end number endNum,
    otherwise 0.
    """
    linked = nextNode(edges, (rearrangementNum, endNum))
    while linked:
        rNum, eNum = linked
        if isUsed[rNum]:
            return
        yield rNum, int(eNum == endNum)
        isUsed[rNum] = True
        # Continue from the far end of the rearrangement just reached.
        linked = nextNode(edges, (rNum, 1 - eNum))

def rearrangementChainsFromEdges(numOfRearrangements, edges):
    """Yield chains of rearrangements that are joined end-to-end by edges.

    A chain is a list of (rearrangementNum, isFlipped) pairs.  Each
    rearrangement number from 0 to numOfRearrangements - 1 is placed in
    exactly one chain.  A chain is grown from its lowest-numbered member
    (which gets isFlipped 0): first by following links from that
    member's end 0, then from its end 1.
    """
    isUsed = [False] * numOfRearrangements
    for seed in range(numOfRearrangements):
        if isUsed[seed]:
            continue
        isUsed[seed] = True
        # End 0 is walked first, so in a circular chain every other
        # member is collected here and the end-1 walk finds nothing new.
        before = list(takeLinkedNodes(edges, isUsed, seed, 0))
        after = list(takeLinkedNodes(edges, isUsed, seed, 1))
        before.reverse()
        yield before + [(seed, 0)] + after

def isCircular(rearrangementChain, edges):
    """Tell whether the first item of the chain has a link in edges.

    The chain's first (rearrangementNum, isFlipped) item is looked up in
    edges as a node.  Returns the linked node that nextNode finds, or
    None if there is none.
    """
    headItem = rearrangementChain[0]
    return nextNode(edges, headItem)

def printSegment(chrom, beg, end):
    """Print one segment to standard output as tab-separated fields.

    The fields are chrom, beg, a direction sign, and end.  The sign is
    "<" when beg < end, otherwise ">".
    """
    if beg < end:
        sign = "<"
    else:
        sign = ">"
    fields = (chrom, beg, sign, end)
    print("\t".join(map(str, fields)))

def orientedSegments(rearrangements, rearrangementChainItem):
    """Return a rearrangement's segments in the orientation a chain uses.

    rearrangementChainItem is (rearrangementNum, isFlipped).  If isFlipped
    is false, the stored segment list itself is returned.  Otherwise a
    new list is returned, with the segments in reverse order and each
    segment's beg and end exchanged.
    """
    rearrangementNum, isFlipped = rearrangementChainItem
    _, segments = rearrangements[rearrangementNum]
    if not isFlipped:
        return segments
    return [(chrom, end, beg) for chrom, beg, end in segments[::-1]]

def isRevSegment(rearrangements, rearrangementChain, index):
    """Return True if the chain's segment at position index runs backwards.

    index selects both the chain item and, within that item's oriented
    segments, the segment (callers in this file pass 0 or -1, for the
    first or last).  A segment runs backwards when its beg exceeds its
    end.
    """
    chainItem = rearrangementChain[index]
    _, beg, end = orientedSegments(rearrangements, chainItem)[index]
    return end < beg

def flippedRearrangementChain(rearrangements, rearrangementChain):
    """Return the chain, turned around if both of its ends run backwards.

    If the chain's first and last segments both have beg > end, a new
    list is returned with the items in reverse order and each isFlipped
    value toggled.  Otherwise the chain is returned as given.
    """
    headIsRev = isRevSegment(rearrangements, rearrangementChain, 0)
    tailIsRev = isRevSegment(rearrangements, rearrangementChain, -1)
    if not (headIsRev and tailIsRev):
        return rearrangementChain
    return [(num, 1 - flip) for num, flip in rearrangementChain[::-1]]

def derivedSeqSegments(rearrangements, rearrangementChain):
    """Yield the segments obtained by joining a chain's rearrangements.

    Each chain item's oriented segments are taken in turn.  Where one
    rearrangement is followed by another, the last segment of the former
    and the first segment of the latter become a single segment: chrom
    and beg come from the former, end from the latter.
    """
    for chainIndex, chainItem in enumerate(rearrangementChain):
        oriented = orientedSegments(rearrangements, chainItem)
        for segIndex, current in enumerate(oriented):
            if segIndex > 0:
                yield pending
                pending = current
            elif chainIndex > 0:
                # Junction between two rearrangements: merge.
                pending = pending[:2] + current[2:3]
            else:
                pending = current
    yield pending

def derivedSeqFromChain(rearrangements, rearrangementChain):
    """Return (chain, segments) for one chain of rearrangements.

    The chain is first passed through flippedRearrangementChain; segments
    is the list produced by derivedSeqSegments for the resulting chain.
    """
    chain = flippedRearrangementChain(rearrangements, rearrangementChain)
    return chain, list(derivedSeqSegments(rearrangements, chain))

def derivedSeqSortKey(derivedSeq):
    """Return the sort key of a (rearrangementChain, segments) pair.

    The key is a sorted copy of the segments.
    """
    _, segments = derivedSeq
    key = list(segments)
    key.sort()
    return key  # xxx ???

def namedDerivedSeqs(derivedSeqs, edges):
    """Yield (name, rearrangementChain, segments) for each derived sequence.

    Names are "der1", "der2", ... in input order, with ":CIRCULAR"
    appended when isCircular gives a true result for the chain.
    """
    for number, (rearrangementChain, segments) in enumerate(derivedSeqs, 1):
        suffix = ":CIRCULAR" if isCircular(rearrangementChain, edges) else ""
        yield "der" + str(number) + suffix, rearrangementChain, segments

def showRearrangementOrder(rearrangements, derivedSeq):
    """Write a derived sequence's rearrangement order to standard error.

    derivedSeq is (name, rearrangementChain, segments).  The name is
    printed first, then one line per chain item: "+" or "-" (for
    isFlipped 0 or 1), a space, and the rearrangement's name.  A blank
    line ends the listing.
    """
    derName, rearrangementChain, _ = derivedSeq
    print(derName, file=sys.stderr)
    for rearrangementNum, isFlipped in rearrangementChain:
        groupName, _ = rearrangements[rearrangementNum]
        strand = "+-"[isFlipped]
        print(strand + " " + str(groupName), file=sys.stderr)
    print(file=sys.stderr)

def segmentGroups(segments, maxLen):
    """Split a run of segments into groups, cutting inside long segments.

    A segment longer than maxLen is replaced by two stubs of length
    maxLen // 3: one at its beg side, which closes the current group, and
    one at its end side, which opens the next group.  Other segments are
    kept whole.  Yields each group as a list of (chrom, beg, end).
    """
    stub = maxLen // 3  # xxx ???
    current = []
    for seg in segments:
        chrom, beg, end = seg
        if abs(end - beg) <= maxLen:
            current.append(seg)
            continue
        # Signed stub length: positive for forward segments (beg < end).
        step = stub if beg < end else -stub
        current.append((chrom, beg, beg + step))
        yield current
        current = [(chrom, end - step, end)]
    yield current

def _partSuffix(index):
    """Return a lowercase-letter suffix for a 0-based part index.

    Indexes 0 to 25 give "a" to "z"; larger indexes continue in
    spreadsheet-column fashion: "aa", "ab", ...
    """
    letters = string.ascii_lowercase
    suffix = ""
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, len(letters))
        suffix = letters[remainder] + suffix
    return suffix

def derivedSeqParts(opts, derivedSeqs):
    """Yield (name, segments) for each part of each derived sequence.

    Each derived sequence (name, rearrangementChain, segments) is split
    by segmentGroups using int(opts.maxlen).  If that gives one group, it
    keeps the sequence's name; otherwise the groups are named by
    appending "a", "b", ... to the name.  After "z" the suffixes go on
    with "aa", "ab", ..., so any number of parts can be named (indexing
    string.ascii_lowercase directly would fail on the 27th part).
    """
    maxLen = int(opts.maxlen)
    for name, rearrangementChain, segments in derivedSeqs:
        groups = list(segmentGroups(segments, maxLen))
        if len(groups) == 1:
            yield name, groups[0]
            continue
        for partIndex, group in enumerate(groups):
            yield name + _partSuffix(partIndex), group

def isNear(derivedSeqPartX, derivedSeqPartY, maxDist):
    """Return True if some segment of X is within maxDist of one of Y.

    Each part is (name, segments).  Two segments count as near when they
    have the same chromosome name and, taking each as the interval
    between its smaller and larger coordinate, the gap between the
    intervals is at most maxDist (overlapping intervals are near).
    """
    def span(segment):
        # Normalize to (chrom, lower coordinate, upper coordinate).
        chrom, beg, end = segment
        return (chrom, beg, end) if beg <= end else (chrom, end, beg)

    _, segmentsX = derivedSeqPartX
    _, segmentsY = derivedSeqPartY
    spansY = [span(segment) for segment in segmentsY]
    for chromX, loX, hiX in map(span, segmentsX):
        for chromY, loY, hiY in spansY:
            if chromX != chromY:
                continue
            if hiX + maxDist >= loY and hiY + maxDist >= loX:
                return True
    return False

def adjacencyListOfDerivedSeqParts(derParts, maxDist):
    """Yield, for each part in turn, the indexes of the other parts near it.

    Nearness is decided by isNear with the given maxDist.  A part is
    never listed as its own neighbor.
    """
    for i, part in enumerate(derParts):
        neighbors = []
        for j, other in enumerate(derParts):
            if j != i and isNear(part, other, maxDist):
                neighbors.append(j)
        yield neighbors

def doOneEdgeSet(opts, rearrangements, wayNum, edges):
    """Build and print the derived sequences for one set of edges.

    The rearrangements are chained along the edges, turned into derived
    sequences, sorted, named, and cut into parts.  Parts that are near
    each other (within int(opts.maxlen)) are grouped, and each group is
    printed to standard output under a "# PART" heading.  With opts.all,
    the part number is prefixed by wayNum and "-".  With opts.verbose,
    the rearrangement order of each derived sequence goes to standard
    error.
    """
    chains = rearrangementChainsFromEdges(len(rearrangements), edges)
    unnamedSeqs = sorted(
        (derivedSeqFromChain(rearrangements, chain) for chain in chains),
        key=derivedSeqSortKey)
    derSeqs = list(namedDerivedSeqs(unnamedSeqs, edges))

    if opts.verbose:
        for derSeq in derSeqs:
            showRearrangementOrder(rearrangements, derSeq)

    derParts = list(derivedSeqParts(opts, derSeqs))
    maxDist = int(opts.maxlen)
    adjList = list(adjacencyListOfDerivedSeqParts(derParts, maxDist))

    for partNum, component in enumerate(connectedComponents(adjList), 1):
        partName = str(partNum)
        if opts.all:
            partName = "{0}-{1}".format(wayNum, partName)
        print("# PART", partName)
        print()
        for derPartNum in component:
            name, segments = derParts[derPartNum]
            print(name)
            for segment in segments:
                printSegment(*segment)
            print()

def main(opts, args):
    """Read rearrangements from args[0], link their ends, print the result.

    opts holds the parsed command-line options (groups, verbose and all
    are read here; the functions called read others).  args[0] is the
    input file name, where "-" means standard input.
    """
    logging.basicConfig(format="%(filename)s: WARNING: %(message)s")

    inputFile = openFile(args[0])
    try:
        rearrangements = list(rearrangementsFromLines(inputFile))
    finally:
        # Close what openFile opened, but leave standard input alone.
        if inputFile is not sys.stdin:
            inputFile.close()

    if opts.groups:
        rearrangements = list(wantedRearrangements(opts, rearrangements))
    nodes = sorted(nodesFromRearrangements(rearrangements))

    # The nodes are sorted, so groupby sees each chromosome's nodes as
    # one contiguous run.
    edgeSetsPerChrom = []
    for chrom, group in itertools.groupby(nodes, itemgetter(0)):
        group = list(group)
        numOfMatchings = numOfMaximumMatchings(group)
        if numOfMatchings > 1:
            warning = "{0} equally-good ways of linking ends in {1}"
            logging.warning(warning.format(numOfMatchings, chrom))
        if opts.verbose:
            showNodesOfOneChromosome(rearrangements, group)
        if opts.all:
            edgeSets = list(allEdgeSetsFromNodes(group))
        else:
            edgeSets = [list(edgesFromNodes(group))]
        edgeSetsPerChrom.append(edgeSets)

    # Handle every combination of one edge set per chromosome.
    for i, x in enumerate(itertools.product(*edgeSetsPerChrom)):
        edges = list(itertools.chain.from_iterable(x))
        doOneEdgeSet(opts, rearrangements, i + 1, edges)

if __name__ == "__main__":
    # Command-line entry point: exactly one non-option argument (the
    # rearrangements file) is required; otherwise the help is printed.
    parser = optparse.OptionParser(
        usage="%prog [options] rearrangements-file > linked.txt",
        description="Infer links between rearranged sequences, and "
        "reconstruct derived chromosomes.")
    parser.add_option("-a", "--all", action="store_true",
                      help="show all equally-good ways of linking")
    parser.add_option("-g", "--groups", metavar="LIST",
                      help="only use these groups")
    parser.add_option("-m", "--maxlen", type="float", metavar="M",
                      default=1000000,
                      help="separate rearrangements > M bp apart "
                      "(default=%default)")
    parser.add_option("-v", "--verbose", action="count", default=0,
                      help="show more details")
    opts, args = parser.parse_args()
    if len(args) != 1:
        parser.print_help()
    else:
        main(opts, args)
