#!/usr/bin/env python

''' 
 * All rights Reserved, Designed By HIT-Bioinformatics   
 * @Title: cuteSV 
 * @author: tjiang & sqcao & zdzhang
 * @date: Nov. 29th 2023
 * @version V2.1.0
'''

import pysam
import cigar
from cuteSV.cuteSV_Description import parseArgs
from multiprocessing import Pool,Manager,Queue, current_process
from cuteSV.CommandRunner import *
# from resolution_type import * 
from cuteSV.cuteSV_resolveINV import run_inv
from cuteSV.cuteSV_resolveTRA import run_tra
from cuteSV.cuteSV_resolveINDEL import run_ins, run_del
from cuteSV.cuteSV_resolveDUP import run_dup
from cuteSV.cuteSV_genotype import generate_output, generate_pvcf, load_valuable_chr, load_bed, Generation_VCF_header
from cuteSV.cuteSV_forcecalling import force_calling_chrom
import os
import argparse
import logging
import sys
import time
import gc
import pickle
import atexit

# Strand symbol for the two primary-alignment signal codes produced by `signal`.
dic_starnd = {1: '+', 2: '-'}

# Set of numeric operation codes 0, 2, 7 and 8.
# NOTE(review): presumably CIGAR op codes that advance the reference - confirm.
RefChangeOp = {0, 2, 7, 8}

# SAM flag value -> internal signal code.
signal = {
    4: 0,     # 1 << 2: unmapped read
    0: 1,     # 1 >> 1 (== 0): normal forward read
    16: 2,    # 1 << 4: reverse-complement read
    2048: 3,  # 1 << 11: supplementary alignment read
    2064: 4,  # 1 << 4 | 1 << 11: supplementary alignment, reverse-complement
}
def detect_flag(Flag):
    """Translate a SAM flag value into the internal signal code.

    Flags that are not keys of the module-level ``signal`` table yield 0.
    """
    return signal.get(Flag, 0)

def analysis_inv(ele_1, ele_2, read_name, candidate, SV_size):
    """Append inversion signatures implied by two opposite-strand segments.

    Each segment is laid out as
    [read_start, read_end, ref_start, ref_end, chr, strand].

    When ``ele_1`` is on the '+' strand the reference *end* coordinates
    (column 3) of both segments are compared and a "++" (head-to-head,
    5'->5') signature is produced; otherwise the reference *start*
    coordinates (column 2) are compared and a "--" (tail-to-tail, 3'->3')
    signature is produced.  A signature is appended to ``candidate`` as
    (orientation, low, high, read_name, "INV", chr of ele_1) when
    ``high - low >= SV_size`` and the read-space overlap of the two
    segments does not exceed half of that span.
    """
    if ele_1[5] == '+':
        orientation, col = "++", 3
        # Same evaluation order as the two original '+' checks.
        ordered_pairs = ((ele_2[col], ele_1[col]), (ele_1[col], ele_2[col]))
    else:
        orientation, col = "--", 2
        # Same evaluation order as the two original '-' checks.
        ordered_pairs = ((ele_1[col], ele_2[col]), (ele_2[col], ele_1[col]))

    for low, high in ordered_pairs:
        span = high - low
        if span >= SV_size and ele_2[0] + 0.5 * span >= ele_1[1]:
            candidate.append((orientation, low, high, read_name, "INV", ele_1[4]))


def analysis_bnd(ele_1, ele_2, read_name, candidate):
    '''
    Append a translocation (TRA) breakend signature for two adjacent
    split segments, each laid out as
    [read_start, read_end, ref_start, ref_end, chr, strand].

    *********Description*********
    *	TYPE A:		N[chr:pos[	*
    *	TYPE B:		N]chr:pos]	*
    *	TYPE C:		[chr:pos[N	*
    *	TYPE D:		]chr:pos]N	*
    *****************************

    Nothing is appended when the read-space gap between the segments
    (ele_2 read_start - ele_1 read_end) exceeds 100.  Otherwise one tuple
    (type, pos_a, chr_b, pos_b, read_name, "TRA", chr_a) is appended, where
    side "a" is the segment whose chromosome compares lower.
    '''
    if ele_2[0] - ele_1[1] > 100:
        return

    first_forward = ele_1[5] == '+'
    second_forward = ele_2[5] == '+'

    # The first segment contributes its ref_end when on '+', else its
    # ref_start; the second segment contributes the opposite column.
    breakpoint_1 = ele_1[3] if first_forward else ele_1[2]
    breakpoint_2 = ele_2[2] if second_forward else ele_2[3]

    # Strand combinations in the order ++, +-, -+, --.
    combo = (0 if first_forward else 2) + (0 if second_forward else 1)

    if ele_1[4] < ele_2[4]:
        bnd_type = "ABCD"[combo]
        record = (bnd_type, breakpoint_1, ele_2[4], breakpoint_2,
                  read_name, "TRA", ele_1[4])
    else:
        # Chromosome order is swapped, so the two sides trade places and
        # the ++ / -- types mirror (A <-> D).
        bnd_type = "DBCA"[combo]
        record = (bnd_type, breakpoint_2, ele_1[4], breakpoint_1,
                  read_name, "TRA", ele_2[4])
    candidate.append(record)

def analysis_split_read(split_read, SV_size, RLength, read_name, candidate, MaxSize, query):
    '''
    Derive SV signatures from the split alignments of a single read.

    Each entry of split_read is a list laid out as:
    read_start	read_end	ref_start	ref_end	chr	strand
    #0			#1			#2			#3		#4	#5

    Parameters:
    split_read -- segments of one read; sorted here by read_start.
    SV_size    -- minimum size a signature must reach to be reported.
    RLength    -- read length, used to mirror read coordinates of
                  '-' strand segments (RLength - coordinate).
    read_name  -- stored in every appended signature.
    candidate  -- dict with "INS", "DEL", "DUP", "INV" and "TRA" lists;
                  signatures are appended in place.
    MaxSize    -- upper bound for INS/DEL lengths; -1 disables the bound.
    query      -- read sequence; sliced to obtain inserted sequences.

    Returns None; all results are delivered through `candidate`.
    '''
    SP_list = sorted(split_read, key = lambda x:x[0])

    # detect INS involved in a translocation
    trigger_INS_TRA = 0	

    # Store Strands of INV

    if len(SP_list) == 2:
        # Exactly two segments: classify the single junction between them.
        ele_1 = SP_list[0]
        ele_2 = SP_list[1]
        if ele_1[4] == ele_2[4]:
            if ele_1[5] != ele_2[5]:
                # Same chromosome, opposite strands -> inversion.
                analysis_inv(ele_1, 
                                ele_2, 
                                read_name, 
                                candidate["INV"],
                                SV_size)

            else:
                # dup & ins & del 
                a = 0
                if ele_1[5] == '-':
                    # Mirror read coordinates and swap segment order for '-' strand reads.
                    ele_1 = [RLength-SP_list[a+1][1], RLength-SP_list[a+1][0]]+SP_list[a+1][2:]
                    ele_2 = [RLength-SP_list[a][1], RLength-SP_list[a][0]]+SP_list[a][2:]
                    # NOTE(review): plain reversal (not reverse-complement) of the read sequence - confirm intended.
                    query = query[::-1]

                # Reference overlap between the two segments of at least SV_size.
                if ele_1[3] - ele_2[2] >= SV_size:
                    # if ele_2[1] - ele_1[1] >= ele_1[3] - ele_2[2]:
                    if ele_2[0] - ele_1[1] >= ele_1[3] - ele_2[2]:
                        # The read gap is at least as large as the reference overlap -> INS.
                        candidate["INS"].append(((ele_1[3]+ele_2[2])/2, 
                                        ele_2[0]+ele_1[3]-ele_2[2]-ele_1[1], 
                                        read_name,
                                        str(query[ele_1[1]+int((ele_1[3]-ele_2[2])/2):ele_2[0]-int((ele_1[3]-ele_2[2])/2)]),
                                        "INS",
                                        ele_2[4]))
                    else:
                        candidate["DUP"].append((ele_2[2], 
                                            ele_1[3], 
                                            read_name,
                                            "DUP",
                                            ele_2[4]))

                # Read gap minus reference gap: positive when the read carries extra sequence.
                delta_length = ele_2[0] + ele_1[3] - ele_2[2] - ele_1[1]
                if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                    if ele_2[2] - ele_1[3] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                        candidate["INS"].append(((ele_2[2]+ele_1[3])/2, 
                                            delta_length, 
                                            read_name,
                                            str(query[ele_1[1]+int((ele_2[2]-ele_1[3])/2):ele_2[0]-int((ele_2[2]-ele_1[3])/2)]),
                                            "INS",
                                            ele_2[4]))
                # Reference gap minus read gap: positive when reference sequence is missing from the read.
                delta_length = ele_2[2] - ele_2[0] + ele_1[1] - ele_1[3]
                if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                    if ele_2[0] - ele_1[1] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                        candidate["DEL"].append((ele_1[3], 
                                            delta_length, 
                                            read_name,
                                            "DEL",
                                            ele_2[4]))
        else:
            # Segments on different chromosomes -> translocation breakend.
            trigger_INS_TRA = 1
            analysis_bnd(ele_1, ele_2, read_name, candidate["TRA"])

    else:
        # over three splits
        # Slide a window of three consecutive segments over the sorted list.
        for a in range(len(SP_list[1:-1])):
            ele_1 = SP_list[a]
            ele_2 = SP_list[a+1]
            ele_3 = SP_list[a+2]

            if ele_1[4] == ele_2[4]:
                if ele_2[4] == ele_3[4]:
                    # All three segments on one chromosome; outer strands agree, middle differs.
                    if ele_1[5] == ele_3[5] and ele_1[5] != ele_2[5]:
                        if ele_2[5] == '-':
                            # +-+
                            if ele_2[0] + 0.5 * (ele_3[2] - ele_1[3]) >= ele_1[1] and ele_3[0] + 0.5 * (ele_3[2] - ele_1[3]) >= ele_2[1]:
                                # No overlaps in split reads

                                if ele_2[2] >= ele_1[3] and ele_3[2] >= ele_2[3]:
                                    candidate["INV"].append(("++", 
                                                        ele_1[3], 
                                                        ele_2[3], 
                                                        read_name,
                                                        "INV",
                                                        ele_1[4]))
                                    # head-to-head
                                    # 5'->5'
                                    candidate["INV"].append(("--", 
                                                        ele_2[2], 
                                                        ele_3[2], 
                                                        read_name,
                                                        "INV",
                                                        ele_1[4]))
                                    # tail-to-tail
                                    # 3'->3'
                        else:
                            # -+-
                            if ele_1[1] <= ele_2[0] + 0.5 * (ele_1[2] - ele_3[3]) and ele_3[0] + 0.5 * (ele_1[2] - ele_3[3]) >= ele_2[1]:
                                # No overlaps in split reads

                                # Up to 50 bases of reference overlap are tolerated here.
                                if ele_2[2] - ele_3[3] >= -50 and ele_1[2] - ele_2[3] >= -50:
                                    candidate["INV"].append(("++", 
                                                        ele_3[3], 
                                                        ele_2[3], 
                                                        read_name,
                                                        "INV",
                                                        ele_1[4]))
                                    # head-to-head
                                    # 5'->5'
                                    candidate["INV"].append(("--", 
                                                        ele_2[2], 
                                                        ele_1[2], 
                                                        read_name,
                                                        "INV",
                                                        ele_1[4]))
                                    # tail-to-tail
                                    # 3'->3'	

                    # Last window only: outer strands differ, so one junction is an inversion.
                    if len(SP_list) - 3 == a:
                        if ele_1[5] != ele_3[5]:
                            if ele_2[5] == ele_1[5]:
                                # ++-/--+
                                analysis_inv(ele_2, 
                                                ele_3, 
                                                read_name, 
                                                candidate["INV"], 
                                                SV_size)
                            else:
                                # +--/-++
                                analysis_inv(ele_1, 
                                                ele_2, 
                                                read_name, 
                                                candidate["INV"], 
                                                SV_size)

                    if ele_1[5] == ele_3[5] and ele_1[5] == ele_2[5]:
                        # dup & ins & del 
                        if ele_1[5] == '-':
                            # Mirror read coordinates and reverse the order of the three segments.
                            ele_1 = [RLength-SP_list[a+2][1], RLength-SP_list[a+2][0]]+SP_list[a+2][2:]
                            ele_2 = [RLength-SP_list[a+1][1], RLength-SP_list[a+1][0]]+SP_list[a+1][2:]
                            ele_3 = [RLength-SP_list[a][1], RLength-SP_list[a][0]]+SP_list[a][2:]
                            # NOTE(review): `query` is rebound for the rest of the call and this
                            # statement can execute on several loop iterations - confirm intended.
                            query = query[::-1]

                        if ele_2[3] - ele_3[2] >= SV_size and ele_2[2] < ele_3[3]:
                            candidate["DUP"].append((ele_3[2], 
                                                ele_2[3], 
                                                read_name,
                                                "DUP",
                                                ele_2[4]))

                        # The first junction is only examined for DUP in the first window.
                        if a == 0:
                            if ele_1[3] - ele_2[2] >= SV_size:
                                candidate["DUP"].append((ele_2[2], 
                                                    ele_1[3], 
                                                    read_name,
                                                    "DUP",
                                                    ele_2[4]))

                        # INS / DEL tests on the junction between ele_1 and ele_2.
                        delta_length = ele_2[0] + ele_1[3] - ele_2[2] - ele_1[1]
                        if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                            if ele_2[2] - ele_1[3] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                                if ele_3[2] >= ele_2[3]:
                                    candidate["INS"].append(((ele_2[2]+ele_1[3])/2, 
                                                        delta_length, 
                                                        read_name,
                                                        str(query[ele_1[1]+int((ele_2[2]-ele_1[3])/2):ele_2[0]-int((ele_2[2]-ele_1[3])/2)]),
                                                        "INS",
                                                        ele_2[4]))
                        delta_length = ele_2[2] - ele_2[0] + ele_1[1] - ele_1[3]
                        if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                            if ele_2[0] - ele_1[1] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                                if ele_3[2] >= ele_2[3]:
                                    candidate["DEL"].append((ele_1[3], 
                                                        delta_length, 
                                                        read_name,
                                                        "DEL",
                                                        ele_2[4]))
                        
                        # Last window: repeat the INS / DEL tests on the final junction (ele_2, ele_3).
                        if len(SP_list) - 3 == a:
                            ele_1 = ele_2
                            ele_2 = ele_3

                            delta_length = ele_2[0] + ele_1[3] - ele_2[2] - ele_1[1]
                            if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                                if ele_2[2] - ele_1[3] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                                    candidate["INS"].append(((ele_2[2]+ele_1[3])/2, 
                                                        delta_length, 
                                                        read_name,
                                                        str(query[ele_1[1]+int((ele_2[2]-ele_1[3])/2):ele_2[0]-int((ele_2[2]-ele_1[3])/2)]),
                                                        "INS",
                                                        ele_2[4]))

                            delta_length = ele_2[2] - ele_2[0] + ele_1[1] - ele_1[3]
                            if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and ele_2[2] - ele_2[0] + ele_1[1] - ele_1[3] >= SV_size:
                                if ele_2[0] - ele_1[1] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                                    candidate["DEL"].append((ele_1[3], 
                                                        delta_length, 
                                                        read_name,
                                                        "DEL",
                                                        ele_2[4]))

                    # Last window with strands like +-- / -++: shift so the same-strand pair is (ele_1, ele_2).
                    if len(SP_list) - 3 == a and ele_1[5] != ele_2[5] and ele_2[5] == ele_3[5]:
                        ele_1 = ele_2
                        ele_2 = ele_3
                        ele_3 = None
                    # Same-strand pair followed by a strand switch (or the shifted pair from above).
                    if ele_3 == None or (ele_1[5] == ele_2[5] and ele_2[5] != ele_3[5]):
                        if ele_1[5] == '-':
                            # NOTE(review): the mirrored pair is always rebuilt from SP_list[a] and
                            # SP_list[a+1], also after the shift above - confirm intended.
                            ele_1 = [RLength-SP_list[a+1][1], RLength-SP_list[a+1][0]]+SP_list[a+1][2:]
                            ele_2 = [RLength-SP_list[a][1], RLength-SP_list[a][0]]+SP_list[a][2:]
                            query = query[::-1]
                        delta_length = ele_2[0] + ele_1[3] - ele_2[2] - ele_1[1]
                        if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                            if ele_2[2] - ele_1[3] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                                candidate["INS"].append(((ele_2[2]+ele_1[3])/2, 
                                                    delta_length, 
                                                    read_name,
                                                    str(query[ele_1[1]+int((ele_2[2]-ele_1[3])/2):ele_2[0]-int((ele_2[2]-ele_1[3])/2)]),
                                                    "INS",
                                                    ele_2[4]))

                        delta_length = ele_2[2] - ele_2[0] + ele_1[1] - ele_1[3]
                        if ele_1[3] - ele_2[2] < max(SV_size, delta_length/5) and delta_length >= SV_size:
                            if ele_2[0] - ele_1[1] <= max(100, delta_length/5) and (delta_length <= MaxSize or MaxSize == -1):
                                candidate["DEL"].append((ele_1[3], 
                                                    delta_length, 
                                                    read_name,
                                                    "DEL",
                                                    ele_2[4]))

            else:
                # First two segments of the window lie on different chromosomes.
                trigger_INS_TRA = 1
                analysis_bnd(ele_1, ele_2, read_name, candidate["TRA"])

                # Last window: also test the final junction (ele_2, ele_3).
                if len(SP_list) - 3 == a:
                    if ele_2[4] != ele_3[4]:
                        analysis_bnd(ele_2, ele_3, read_name, candidate["TRA"])

    # A translocation junction was seen: compare the first and last segments directly
    # to catch an INS / DUP that spans the segments in between.
    if len(SP_list) >= 3 and trigger_INS_TRA == 1:
        if SP_list[0][4] == SP_list[-1][4]:
            if SP_list[0][5] != SP_list[-1][5]:
                pass
            else:
                if SP_list[0][5] == '+':
                    ele_1 = SP_list[0]
                    ele_2 = SP_list[-1]
                else:
                    ele_1 = [RLength-SP_list[-1][1], RLength-SP_list[-1][0]]+SP_list[-1][2:]
                    ele_2 = [RLength-SP_list[0][1],RLength-SP_list[0][0]]+SP_list[0][2:]
                    query = query[::-1]
                # Gap between the two outer segments on the reference and on the read.
                dis_ref = ele_2[2] - ele_1[3]
                dis_read = ele_2[0] - ele_1[1]
                if dis_ref < 100 and dis_read - dis_ref >= SV_size and (dis_read - dis_ref <= MaxSize or MaxSize == -1):
                    candidate["INS"].append((min(ele_2[2], ele_1[3]), 
                                        dis_read - dis_ref, 
                                        read_name,
                                        str(query[ele_1[1]+int(dis_ref/2):ele_2[0]-int(dis_ref/2)]),
                                        "INS",
                                        ele_2[4]))	

                # Outer segments overlap on the reference by at least SV_size.
                if dis_ref <= -SV_size:
                    candidate["DUP"].append((ele_2[2], 
                                        ele_1[3], 
                                        read_name,
                                        "DUP",
                                        ele_2[4]))

def acquire_clip_pos(deal_cigar):
    """Summarise a CIGAR string.

    Returns ``[first_pos, last_pos, bias]`` where ``first_pos`` /
    ``last_pos`` are the lengths of a leading / trailing 'S' operation
    (0 when the string does not start / end with 'S') and ``bias`` is the
    summed length of all 'M', 'D', '=' and 'X' operations.
    """
    operations = list(cigar.Cigar(deal_cigar).items())

    leading = operations[0]
    trailing = operations[-1]
    first_pos = leading[0] if leading[1] == 'S' else 0
    last_pos = trailing[0] if trailing[1] == 'S' else 0

    bias = sum(item[0] for item in operations if item[1] in ('M', 'D', '=', 'X'))
    return [first_pos, last_pos, bias]

def organize_split_signal(primary_info, Supplementary_info, total_L, SV_size, 
    min_mapq, max_split_parts, read_name, candidate, MaxSize, query):
    """Collect all alignment segments of a read and analyse them together.

    Parameters:
    primary_info       -- segment list of the primary alignment
                          ([read_start, read_end, ref_start, ref_end, chr,
                          strand]) or an empty list when it is unusable.
    Supplementary_info -- SA-tag entries; each is a comma-separated string
                          read as chr, start, strand, CIGAR, mapq, ...
    total_L            -- read length used to turn clip lengths into read
                          coordinates.
    SV_size, MaxSize, read_name, candidate, query
                       -- forwarded unchanged to analysis_split_read().
    min_mapq           -- minimum mapping quality for supplementary
                          entries (see note below).
    max_split_parts    -- maximum number of segments to analyse; -1 means
                          no limit.

    Returns None; results are appended to the lists inside `candidate`.
    """
    split_read = []
    if len(primary_info) > 0:
        split_read.append(primary_info)
        # Once a primary segment is present, supplementary entries are no
        # longer filtered by mapping quality.
        min_mapq = 0

    for entry in Supplementary_info:
        fields = entry.split(',')
        local_chr = fields[0]
        local_start = int(fields[1])
        local_strand = fields[2]
        local_cigar = fields[3]
        local_mapq = int(fields[4])
        if local_mapq < min_mapq:
            continue

        clip_first, clip_last, ref_span = acquire_clip_pos(local_cigar)
        # On the '-' strand the trailing clip marks the start of the segment
        # in read coordinates, so the two clip lengths trade places.
        if local_strand == '+':
            read_start, read_end = clip_first, total_L - clip_last
        else:
            read_start, read_end = clip_last, total_L - clip_first
        # The former bare ``except: pass`` around this append only existed on
        # the '-' branch and could hide any error (even KeyboardInterrupt);
        # both strands are now handled identically and errors propagate.
        split_read.append([read_start, read_end, local_start,
                           local_start + ref_span, local_chr, local_strand])

    if len(split_read) <= max_split_parts or max_split_parts == -1:
        analysis_split_read(split_read, SV_size, total_L, read_name, candidate, MaxSize, query)

def generate_combine_sigs(sigs, Chr_name, read_name, svtype, candidate, merge_dis):
    """Merge neighbouring intra-read signals and append them to `candidate`.

    Parameters:
    sigs       -- position-ordered signals of one read.  For svtype "INS"
                  each is [pos, length, sequence]; otherwise [pos, length].
                  The input lists are not modified.
    Chr_name   -- chromosome stored in each emitted signature.
    read_name  -- read name stored in each emitted signature.
    svtype     -- "INS" selects the insertion layout; any other value is
                  handled as a reference-spanning signal (e.g. "DEL").
    candidate  -- list that receives the signatures:
                  INS:   (pos, length, read_name, sequence, svtype, Chr_name)
                  other: (pos, length, read_name, svtype, Chr_name)
    merge_dis  -- a signal is merged into the running one when the distance
                  from the running signal's anchor to its own position is
                  <= merge_dis.  The anchor is the position of the last
                  merged signal for INS, and the reference end
                  (pos + length) of the last merged signal otherwise.

    Returns None.
    """
    if not sigs:
        return

    is_ins = svtype == "INS"

    def anchor_of(sig):
        # Insertions occupy no reference span; other signals end at pos + length.
        return sig[0] if is_ins else sig[0] + sig[1]

    def emit(start, length, seq):
        if is_ins:
            candidate.append((start, length, read_name, seq, svtype, Chr_name))
        else:
            candidate.append((start, length, read_name, svtype, Chr_name))

    first = sigs[0]
    cur_start, cur_len = first[0], first[1]
    cur_seq = first[2] if is_ins else None
    cur_anchor = anchor_of(first)

    for sig in sigs[1:]:
        if sig[0] - cur_anchor <= merge_dis:
            cur_len += sig[1]
            if is_ins:
                cur_seq += sig[2]
        else:
            emit(cur_start, cur_len, cur_seq)
            cur_start, cur_len = sig[0], sig[1]
            cur_seq = sig[2] if is_ins else None
        # Always re-anchor on the signal just consumed.  Previously the
        # non-INS path re-anchored on the *start* of a signal that opened a
        # new group, so the next gap was measured from the wrong coordinate.
        cur_anchor = anchor_of(sig)

    emit(cur_start, cur_len, cur_seq)

# Every CIGAR operation code exposed by pysam.
OPLIST = [
    pysam.CBACK,
    pysam.CDEL,
    pysam.CDIFF,
    pysam.CEQUAL,
    pysam.CHARD_CLIP,
    pysam.CINS,
    pysam.CMATCH,
    pysam.CPAD,
    pysam.CREF_SKIP,
    pysam.CSOFT_CLIP,
]
RefChangeOp = {0, 2, 7, 8}

# operation -> (QUERY CHANGE, REF CHANGE)
CHANGETABLE = {
    pysam.CMATCH: (True, True),
    pysam.CINS: (True, False),
    pysam.CDEL: (False, True),
    pysam.CREF_SKIP: (False, True),
    pysam.CPAD: (False, False),
    pysam.CEQUAL: (True, True),
    pysam.CDIFF: (True, True),
}

# Lookup lists indexed by operation code, covering 0 .. max(OPLIST).
# Codes missing from CHANGETABLE count as changing neither query nor reference.
CHANGEOP = [CHANGETABLE.get(code, (False, False)) for code in range(max(OPLIST) + 1)]
REFCHANGEOP = [CHANGETABLE.get(code, (False, False))[1] for code in range(max(OPLIST) + 1)]
INDELOP = [code in (pysam.CDEL, pysam.CINS) for code in range(max(OPLIST) + 1)]

def parse_read(read, candidate, Chr_name, SV_size, min_mapq, max_split_parts, min_read_len, min_siglength, merge_del_threshold, merge_ins_threshold, MaxSize):
    """Extract SV signatures from one aligned read.

    Two sources are used:
      * intra-alignment signals: CIGAR insertions / deletions of at least
        `min_siglength`, merged per read via generate_combine_sigs();
      * split-read signals: for reads whose flag maps to signal code 1 or 2
        (flag 0 / 16), every "SA" tag is handed to organize_split_signal().

    Parameters:
    read                 -- alignment record (pysam.AlignedSegment-like).
    candidate            -- dict with "INS", "DEL", "DUP", "INV", "TRA"
                            lists; signatures are appended in place.
    Chr_name             -- chromosome the read was fetched from.
    SV_size, MaxSize, max_split_parts
                         -- forwarded to organize_split_signal().
    min_mapq             -- minimum mapping quality for using the read's own
                            alignment.
    min_read_len         -- reads with a shorter query_length are ignored.
    min_siglength        -- minimum CIGAR I/D length to record.
    merge_del_threshold, merge_ins_threshold
                         -- merge distances for intra-read DEL / INS signals.

    Returns [] for reads shorter than `min_read_len`, otherwise `candidate`.
    """
    if read.query_length < min_read_len:
        return []

    # CIGAR operations that consume bases of the stored query sequence
    # (M, I, S, =, X in the SAM specification).  Hard clips, deletions,
    # reference skips and padding do not.
    query_consuming = (pysam.CMATCH, pysam.CINS, pysam.CSOFT_CLIP,
                       pysam.CEQUAL, pysam.CDIFF)
    clip_ops = (pysam.CSOFT_CLIP, pysam.CHARD_CLIP)

    Combine_sig_in_same_read_ins = []
    Combine_sig_in_same_read_del = []

    process_signal = detect_flag(read.flag)
    cigartuples = read.cigartuples or []
    # A record without CIGAR carries no usable alignment; treat it like a
    # read failing the mapping-quality filter instead of raising IndexError.
    alignment_usable = read.mapq >= min_mapq and len(cigartuples) > 0

    if alignment_usable:
        pos_start = read.reference_start  # 0-based
        pos_end = read.reference_end

        # A leading / trailing soft or hard clip is treated the same way.
        first_op, first_len = cigartuples[0]
        last_op, last_len = cigartuples[-1]
        softclip_left = first_len if first_op in clip_ops else 0
        softclip_right = last_len if last_op in clip_ops else 0

        sig_start = pos_start  # current reference position
        query_offset = 0       # bases of query_sequence consumed so far
        for op, oplen in cigartuples:
            # Only query-consuming operations advance the offset; this keeps
            # the inserted-sequence slice correct when N or P operations
            # (or hard clips) are present.
            if op in query_consuming:
                query_offset += oplen
            if oplen >= min_siglength and INDELOP[op]:
                if op == pysam.CDEL:
                    Combine_sig_in_same_read_del.append([sig_start, oplen])
                    sig_start += oplen
                else:
                    Combine_sig_in_same_read_ins.append([
                        sig_start, oplen,
                        str(read.query_sequence[query_offset - oplen:query_offset])])
            elif REFCHANGEOP[op]:
                sig_start += oplen

    # ************Combine signals in same read********************
    generate_combine_sigs(Combine_sig_in_same_read_ins, Chr_name, read.query_name, "INS", candidate["INS"], merge_ins_threshold)
    generate_combine_sigs(Combine_sig_in_same_read_del, Chr_name, read.query_name, "DEL", candidate["DEL"], merge_del_threshold)

    if process_signal == 1 or process_signal == 2:  # flag 0 / 16
        if alignment_usable:
            # Read coordinates are expressed in the orientation of the
            # original read, so the clips swap for reverse-strand reads.
            if process_signal == 1:
                primary_info = [softclip_left, read.query_length - softclip_right, pos_start,
                                pos_end, Chr_name, dic_starnd[process_signal]]
            else:
                primary_info = [softclip_right, read.query_length - softclip_left, pos_start,
                                pos_end, Chr_name, dic_starnd[process_signal]]
        else:
            primary_info = []

        for tag in read.get_tags():
            if tag[0] == 'SA':
                # Entries are ';'-terminated, so the last split item is dropped.
                Supplementary_info = tag[1].split(';')[:-1]
                organize_split_signal(primary_info, Supplementary_info, read.query_length,
                    SV_size, min_mapq, max_split_parts, read.query_name, candidate, MaxSize, read.query_sequence)
    return candidate

def init_reading_process(sam_path):
    """Open `sam_path` with pysam and store the handle in the module-level `samfile`."""
    global samfile
    samfile=pysam.AlignmentFile(sam_path)

def cleanup():
    """Close the module-level alignment file, if one is open, and reset it to None."""
    global samfile
    if samfile is None:
        return
    samfile.close()
    samfile = None

# Signature categories; single_pipe() writes one pickle file per entry, in this order.
SVTYPES = [
    "DEL",
    "INS",
    "DUP",
    "INV",
    "TRA",
]
# Alignment file handle shared by the functions of this module within one
# process; assigned by init_reading_process() and cleared by cleanup().
samfile = None
def single_pipe(sam_path, min_length, min_mapq, max_split_parts, min_read_len, temp_dir, 
                task, min_siglength, merge_del_threshold, merge_ins_threshold, MaxSize, bed_regions):
    """Extract SV signatures from the reads of one genomic task.

    Reads are fetched from the per-process ``samfile`` handle for the region
    ``task = [chromosome, start, end]``. A read is handled by this task only
    if it starts at or after ``start`` and, when ``bed_regions`` is given,
    overlaps at least one ``(begin, end)`` region in it. Reads whose flag is
    256 or 272 are ignored.

    The collected signatures and read records are appended to the worker's
    own pickle files below ``<temp_dir>signatures/``; nothing is returned.

    Args:
        sam_path: not used in this function (the file is already open).
        min_length, min_mapq, max_split_parts, min_read_len, min_siglength,
        merge_del_threshold, merge_ins_threshold, MaxSize: forwarded to
            parse_read(); ``min_mapq`` also gates the read records kept here.
        temp_dir: work directory prefix, used as ``"%ssignatures/..."``.
        task: ``[chromosome, start, end]``.
        bed_regions: iterable of ``(begin, end)`` regions, or None for no
            restriction.
    """
    global samfile
    Chr_name = task[0]
    region_start = task[1]
    candidate = {sv_type: list() for sv_type in ("DEL", "INS", "DUP", "INV", "TRA")}
    reads_info_list = list()

    for read in samfile.fetch(Chr_name, region_start, task[2]):
        if read.flag in (256, 272):
            continue
        pos_start = read.reference_start # 0-based
        pos_end = read.reference_end

        if bed_regions is None:
            in_bed = True
        else:
            # A read is inside the BED as soon as it overlaps one region.
            in_bed = any(pos_end > region[0] and pos_start < region[1]
                         for region in bed_regions)

        # Reads that begin before the task start are not handled by this task.
        if read.reference_start < region_start or not in_bed:
            continue

        parse_read(read, candidate, Chr_name, min_length, min_mapq, max_split_parts,
                   min_read_len, min_siglength, merge_del_threshold,
                   merge_ins_threshold, MaxSize)
        if read.mapq >= min_mapq:
            is_primary = 1 if read.flag in (0, 16) else 0
            reads_info_list.append((pos_start, pos_end, is_primary, read.query_name, Chr_name))

    # One file per worker process and type; successive tasks append to it.
    worker_pid = current_process().pid
    for sv_type in SVTYPES:
        with open("%ssignatures/%s%s.pickle"%(temp_dir, worker_pid, sv_type), "ab") as sig_file:
            pickle.dump(candidate[sv_type], sig_file)
    with open("%ssignatures/%sreads.pickle"%(temp_dir, worker_pid), "ab") as reads_file:
        pickle.dump(reads_info_list, reads_file)

    logging.info("Finished %s:%d-%d."%(Chr_name, task[1], task[2]))
    gc.collect()
    return None

def multi_run_wrapper(args):
    """Unpack one argument tuple and call single_pipe() with it.

    main_ctrl() submits this function through ``Pool.map_async``, which
    passes each work item as a single argument.
    """
    # logging.info(args)
    return single_pipe(*args)

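# Field order of one line in the legacy "<type>.sigs" text files, given as indices
# into the in-memory signature tuple:
#   DEL: -2,-1,0,1,2 | INS: -2,-1,0,1,2,3 | DUP: -2,-1,0,1,2 | INV: -2,-1,0,1,2,3
#   TRA: -2,-1,0,1,2,3,4 | reads: -1,0,1,2,3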
# Sort key per signature type. The last tuple field (the chromosome) always
# comes first so that records of one chromosome end up contiguous.
_SIG_SORT_KEYS = {
    "DEL": lambda x: (x[-1], int(x[0]), x[1], x[2]),
    "INS": lambda x: (x[-1], int(x[0]), x[1], x[2], x[3]),
    "DUP": lambda x: (x[-1], int(x[0]), int(x[1]), x[2]),
    "INV": lambda x: (x[-1], x[0], int(x[1]), x[2], x[3]),
    "TRA": lambda x: (x[-1], x[2], x[0], int(x[1]), x[3], x[4], x[5]),
    "reads": lambda x: (x[-1]),
}

# (format string, tuple indices) of one line of the legacy "<type>.sigs" file.
_SIG_LINE_LAYOUTS = {
    "DEL": ("%s\t%s\t%d\t%d\t%s\n", (-2, -1, 0, 1, 2)),
    "INS": ("%s\t%s\t%d\t%d\t%s\t%s\n", (-2, -1, 0, 1, 2, 3)),
    "DUP": ("%s\t%s\t%d\t%d\t%s\n", (-2, -1, 0, 1, 2)),
    "INV": ("%s\t%s\t%s\t%d\t%d\t%s\n", (-2, -1, 0, 1, 2, 3)),
    "TRA": ("%s\t%s\t%s\t%d\t%s\t%d\t%s\n", (-2, -1, 0, 1, 2, 3, 4)),
    "reads": ("%s\t%d\t%d\t%d\t%s\n", (-1, 0, 1, 2, 3)),
}


def _load_worker_signatures(temporary_dir, pids, sv_type):
    """Concatenate every list pickled into the per-worker files of sv_type."""
    merged = []
    for pid in pids:
        path = "%ssignatures/%s%s.pickle"%(temporary_dir, pid, sv_type)
        try:
            handle = open(path, "rb")
        except FileNotFoundError:
            # single_pipe() creates the file on its first task, so a worker
            # that never received a task has nothing to contribute.
            continue
        with handle:
            # Each task appended one pickled list; read until the file ends.
            while True:
                try:
                    merged.extend(pickle.load(handle))
                except EOFError:
                    break
    return merged


def _dump_by_chromosome(records, handle, with_counts):
    """Pickle each run of records sharing a chromosome as one list.

    Returns ``(offsets, counts)``: ``offsets`` maps a chromosome to the byte
    position of its pickled list inside ``handle``; ``counts`` maps it to the
    number of records and is only filled when ``with_counts`` is true.
    """
    offsets = {}
    counts = {}
    if len(records) == 0:
        return offsets, counts

    run_start = 0
    byte_offset = 0
    current = records[0][-1]
    for pos, record in enumerate(records):
        if record[-1] == current:
            continue
        blob = pickle.dumps(records[run_start:pos])
        handle.write(blob)
        offsets[current] = byte_offset
        if with_counts:
            counts[current] = pos - run_start
        byte_offset += len(blob)
        current = record[-1]
        run_start = pos

    # The last run is still pending when the loop ends.
    pickle.dump(records[run_start:], handle)
    offsets[current] = byte_offset
    if with_counts:
        counts[current] = len(records) - run_start
    return offsets, counts


def process_process_sigs_type(args):
    """Merge the per-worker signature files of one type into one indexed file.

    Args:
        args: tuple ``(sv_type, temporary_dir, pids, write_old_sigs)``.
            ``sv_type`` is one of DEL, INS, DUP, INV, TRA or "reads";
            ``temporary_dir`` is the work directory prefix (with trailing
            slash); ``pids`` are the worker ids whose
            ``signatures/<pid><sv_type>.pickle`` files are merged (missing
            files are skipped); ``write_old_sigs`` additionally writes the
            text file ``<sv_type>.sigs``.

    Returns:
        ``(sv_type, index, reads_count)``. ``index`` maps each chromosome to
        the byte offset of its pickled record list in
        ``<temporary_dir>/<sv_type>.pickle``. ``reads_count`` maps each
        chromosome to its number of records and is empty unless ``sv_type``
        is "reads".
    """
    sv_type, temporary_dir, pids, write_old_sigs = args
    type_candidates = _load_worker_signatures(temporary_dir, pids, sv_type)

    if sv_type in _SIG_SORT_KEYS:
        type_candidates.sort(key=_SIG_SORT_KEYS[sv_type])
        # reads file not deduped, might be resolved later
        if sv_type != "reads":
            type_candidates = remove_duplicates_sorted(type_candidates)
        if write_old_sigs:
            template, columns = _SIG_LINE_LAYOUTS[sv_type]
            with open("%s/%s.sigs"%(temporary_dir, sv_type), "w") as f:
                f.writelines(template % tuple(ele[col] for col in columns)
                             for ele in type_candidates)

    with open("%s/%s.pickle"%(temporary_dir, sv_type), "wb") as f:
        index, reads_count = _dump_by_chromosome(type_candidates, f, sv_type == "reads")
    return (sv_type, index, reads_count)

def write_sigs(temporary_dir, candidates, reads_info_list, prefix=""):
    """Write signatures as tab-separated text files plus a chromosome index.

    For every type in DEL, INS, DUP, INV, TRA the records of
    ``candidates[type]`` are written to ``<temporary_dir>/<prefix><type>.sigs``
    in the order given. ``reads_info_list`` goes to ``<prefix>reads.sigs``.

    Args:
        temporary_dir: directory receiving the files.
        candidates: dict mapping each type to its list of signature tuples.
        reads_info_list: read records, written without an index.
        prefix: optional file-name prefix for the ``.sigs`` files.

    Returns:
        ``{type: {chromosome: offset}}`` where ``offset`` is the number of
        characters written before the first line of that chromosome's run.
        The same dict is pickled to ``<temporary_dir>/sigindex.pickle``.
    """
    # (type, line template, tuple indices feeding the template)
    layouts = (
        ("DEL", "%s\t%s\t%d\t%d\t%s\n", (-2, -1, 0, 1, 2)),
        ("INS", "%s\t%s\t%d\t%d\t%s\t%s\n", (-2, -1, 0, 1, 2, 3)),
        ("DUP", "%s\t%s\t%d\t%d\t%s\n", (-2, -1, 0, 1, 2)),
        ("INV", "%s\t%s\t%s\t%d\t%d\t%s\n", (-2, -1, 0, 1, 2, 3)),
        ("TRA", "%s\t%s\t%s\t%d\t%s\t%d\t%s\n", (-2, -1, 0, 1, 2, 3, 4)),
    )

    index = {}
    for sv_type, template, columns in layouts:
        with open("%s/%s%s.sigs"%(temporary_dir, prefix, sv_type), "w") as out:
            offsets = {}
            index[sv_type] = offsets
            records = candidates[sv_type]
            if len(records) == 0:
                continue
            current = None      # chromosome of the run being written
            run_offset = 0      # characters written before that run
            written = 0         # characters written so far
            for ele in records:
                line = template % tuple(ele[col] for col in columns)
                if ele[-1] != current:
                    if current is not None:
                        # Close the finished run and start the next one here.
                        offsets[current] = run_offset
                        run_offset = written
                    current = ele[-1]
                written += len(line)
                out.write(line)
            offsets[current] = run_offset

    with open("%s/%sreads.sigs"%(temporary_dir, prefix), "w") as out:
        for ele in reads_info_list:
            out.write("%s\t%d\t%d\t%d\t%s\n"%(ele[-1], ele[0], ele[1], ele[2], ele[3]))

    with open("%s/sigindex.pickle"%temporary_dir, "wb") as out:
        pickle.dump(index, out)
    return index

def remove_duplicates_sorted(sorted):
    """Drop consecutive duplicates from an already sorted list.

    The surviving items are compacted towards the front of the given list
    in place, and a new list holding just those items is returned. An empty
    input yields a new empty list.
    """
    if len(sorted) == 0:
        return []
    last_kept = 0
    # Writes only ever land at or before the position being read, so it is
    # safe to iterate over the list while compacting it.
    for item in sorted:
        if item != sorted[last_kept]:
            last_kept += 1
            sorted[last_kept] = item
    return sorted[:last_kept + 1]

def remove_duplicates(data):
    """Return the distinct items of ``data`` as a list, first occurrence first.

    Items must be hashable; the input is not modified.
    """
    # dict keys are unique and keep insertion order.
    return list(dict.fromkeys(data))

def add_candidates(result):
    """Append the output of one task to the module-level ``candidates`` dict.

    Args:
        result: sequence whose first element is a pair
            ``(signatures_by_type, reads_info)``. The per-type lists extend
            ``candidates[type]`` and ``reads_info`` extends
            ``candidates["reads_info"]``.
    """
    global candidates
    payload = result[0]
    found = payload[0]
    for sv_type in ("DEL", "INS", "DUP", "INV", "TRA"):
        candidates[sv_type].extend(found[sv_type])
    candidates["reads_info"].extend(payload[1])

# candidates={}
def main_ctrl(args, argv):
    """Run the whole calling workflow for one alignment file.

    Stages, in order:
      1. validate the reference, the work directory and the absence of
         left-over ``.sigs`` / ``.pickle`` files;
      2. split every contig into tasks and extract signatures in a worker
         pool (single_pipe);
      3. merge the per-worker signature files per type and write
         ``sigindex.pickle`` (process_process_sigs_type);
      4. either force-call the variants of ``args.Ivcf`` or cluster the
         signatures per chromosome and type;
      5. write the VCF to ``args.output``;
      6. remove the temporary files unless ``args.retain_work_dir`` is set.

    Args:
        args: parsed command-line options (see parseArgs).
        argv: raw argument list, forwarded to Generation_VCF_header.

    Raises:
        FileNotFoundError: the reference file or work directory is missing.
        FileExistsError: the work directory already holds signature files.
    """
    import shlex  # only needed to quote paths in the clean-up command

    if args.work_dir.endswith('/'):
        temporary_dir = args.work_dir
    else:
        temporary_dir = args.work_dir+'/'
    # check the temporary files
    if not os.path.isfile(args.reference):
        raise FileNotFoundError("[Errno 2] No such file: '%s'"%args.reference)
    if not os.path.exists(args.work_dir):
        raise FileNotFoundError("[Errno 2] No such directory: '%s'"%args.work_dir)
    for item in SVTYPES:
        if os.path.exists(temporary_dir + item + '.sigs'):
            raise FileExistsError("[Errno 2] File exists: '%s'"%(temporary_dir+item+'.sigs'))
        if os.path.exists(temporary_dir + item + '.pickle'):
            raise FileExistsError("[Errno 2] File exists: '%s'"%(temporary_dir+item+'.pickle'))

    # ---- build the task list -------------------------------------------
    Task_list = list()
    contigINFO = list()
    # The main process only needs the file for contig statistics; the
    # workers open their own handle in init_reading_process().
    with pysam.AlignmentFile(args.input) as bam:
        ref_ = bam.get_index_statistics()
        logging.info("The total number of chromsomes: %d"%(len(ref_)))

        total_mapped = sum(i[1] for i in ref_)
        mapped_unit = total_mapped/args.threads/10
        for i in ref_:
            local_ref_len = bam.get_reference_length(i[0])
            contigINFO.append([i[0], local_ref_len])
            if total_mapped==0 or i[1]<=mapped_unit:
                batch_size = args.batches
            else:
                # Keep task boundaries integral: they are passed to fetch()
                # and compared with read start positions in single_pipe().
                batch_size = max(1, local_ref_len // (int(i[1]/mapped_unit)+1))
            if local_ref_len < batch_size:
                Task_list.append([i[0], 0, local_ref_len])
            else:
                pos = 0
                for _ in range(int(local_ref_len/batch_size)):
                    Task_list.append([i[0], pos, pos+batch_size])
                    pos += batch_size
                if pos < local_ref_len:
                    Task_list.append([i[0], pos, local_ref_len])
    bed_regions = load_bed(args.include_bed, Task_list)

    # ---- stage 1: signature extraction ---------------------------------
    def report_task_failure(exc):
        # Without a callback a failed task would vanish without a trace.
        logging.error("A signature extraction task failed: %r", exc)

    atexit.register(cleanup)
    analysis_pools = Pool(processes=int(args.threads), initializer=init_reading_process, initargs=(args.input,))
    os.mkdir("%ssignatures"%temporary_dir)
    for i, task in enumerate(Task_list):
        # one submission per task: faster than building one long paras list
        paras = [(args.input, 
                    args.min_size, 
                    args.min_mapq, 
                    args.max_split_parts, 
                    args.min_read_len, 
                    temporary_dir, 
                    task, 
                    args.min_siglength, 
                    args.merge_del_threshold, 
                    args.merge_ins_threshold, 
                    args.max_size,
                    None if bed_regions is None else bed_regions[i])]
        analysis_pools.map_async(multi_run_wrapper, paras, error_callback=report_task_failure)
    # The workers name their output files after their pid.
    pids = [process.pid for process in analysis_pools._pool]
    analysis_pools.close()
    analysis_pools.join()

    # ---- stage 2: merge per-worker files -------------------------------
    logging.info("Rebuilding signatures of structural variants.")

    analysis_pools = Pool(processes=int(args.threads))
    paras = [(sv_type, temporary_dir, pids, args.write_old_sigs) for sv_type in SVTYPES]
    paras.append(("reads", temporary_dir, pids, args.write_old_sigs))
    merged = analysis_pools.map_async(process_process_sigs_type, paras)
    analysis_pools.close()
    analysis_pools.join()
    sigs_index = {}
    for r in merged.get():
        if r is not None:
            sigs_index[r[0]] = r[1]
            if r[0]=="reads":
                sigs_index["reads_count"] = r[2]
    with open("%s/sigindex.pickle"%temporary_dir,"wb") as f:
        pickle.dump(sigs_index,f)
    del merged
    gc.collect()
    logging.info("Rebuilding signatures completed.")

    # ---- stage 3: calling ----------------------------------------------
    if args.Ivcf is not None:
        # force calling
        max_cluster_bias_dict = dict()
        max_cluster_bias_dict['INS'] = args.max_cluster_bias_INS
        max_cluster_bias_dict['DEL'] = args.max_cluster_bias_DEL
        max_cluster_bias_dict['DUP'] = args.max_cluster_bias_DUP
        max_cluster_bias_dict['INV'] = args.max_cluster_bias_INV
        max_cluster_bias_dict['TRA'] = args.max_cluster_bias_TRA
        threshold_gloab_dict = dict()
        threshold_gloab_dict['INS'] = args.diff_ratio_merging_INS
        threshold_gloab_dict['DEL'] = args.diff_ratio_merging_DEL

        forced_calls = force_calling_chrom(args.Ivcf, temporary_dir,
                         max_cluster_bias_dict, threshold_gloab_dict, args.gt_round, args.read_range, args.threads, sigs_index)

    else:

        logging.info("Clustering structural variants.")
        analysis_pools = Pool(processes=int(args.threads))
        pending = list()

        # +++++DEL+++++
        for chrom in sigs_index["DEL"]:
            para = [(temporary_dir, 
                    chrom, 
                    "DEL", 
                    args.min_support,
                    args.diff_ratio_merging_DEL, 
                    args.max_cluster_bias_DEL, 
                    # args.diff_ratio_filtering_DEL, 
                    min(args.min_support, 5), 
                    args.input, 
                    args.genotype,
                    args.gt_round,
                    args.remain_reads_ratio,
                    sigs_index)]
            pending.append(analysis_pools.map_async(run_del, para))

        # +++++INS+++++
        for chrom in sigs_index["INS"]:
            para = [(temporary_dir, 
                    chrom, 
                    "INS", 
                    args.min_support, 
                    args.diff_ratio_merging_INS, 
                    args.max_cluster_bias_INS, 
                    # args.diff_ratio_filtering_INS, 
                    min(args.min_support, 5), 
                    args.input, 
                    args.genotype,
                    args.gt_round,
                    args.remain_reads_ratio,
                    sigs_index)]
            pending.append(analysis_pools.map_async(run_ins, para))

        # +++++INV+++++
        for chrom in sigs_index["INV"]:
            para = [(temporary_dir, 
                    chrom, 
                    "INV", 
                    args.min_support, 
                    args.max_cluster_bias_INV, 
                    args.min_size, 
                    args.input, 
                    args.genotype, 
                    args.max_size,
                    args.gt_round,
                    sigs_index)]
            pending.append(analysis_pools.map_async(run_inv, para))

        # +++++DUP+++++
        for chrom in sigs_index["DUP"]:
            para = [(temporary_dir, 
                    chrom, 
                    args.min_support, 
                    args.max_cluster_bias_DUP,
                    args.min_size, 
                    args.input, 
                    args.genotype, 
                    args.max_size,
                    args.gt_round,
                    sigs_index)]
            pending.append(analysis_pools.map_async(run_dup, para))

        # +++++TRA+++++
        for chrom in sigs_index["TRA"]:
            para = [(temporary_dir, 
                    chrom, 
                    args.min_support, 
                    args.diff_ratio_filtering_TRA, 
                    args.max_cluster_bias_TRA, 
                    args.input, 
                    args.genotype,
                    args.gt_round,
                    sigs_index)]
            pending.append(analysis_pools.map_async(run_tra, para))

        analysis_pools.close()
        analysis_pools.join()

        svs_by_chrom = {}
        for res in pending:
            try:
                chrom, svs = res.get()[0]
                if chrom not in svs_by_chrom:
                    svs_by_chrom[chrom] = []
                svs_by_chrom[chrom].extend(svs)
            except Exception as exc:
                # Best effort, as before: a task without a usable result is
                # skipped. Only the reason is now visible at debug level.
                logging.debug("Skipping a clustering task without a usable result: %r", exc)

    # ---- stage 4: write the VCF ----------------------------------------
    logging.info("Writing to your output file.")

    if args.Ivcf is not None:
        with open(args.output, 'w') as file:
            Generation_VCF_header(file, contigINFO, args.sample, argv)
            file.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s\n"%(args.sample))
            analysis_pools = Pool(processes=int(args.threads))
            paras = [(args, forced_calls[chrom], args.reference, chrom)
                     for chrom in sorted(forced_calls.keys())]
            rendered = analysis_pools.starmap_async(generate_pvcf, paras)
            analysis_pools.close()
            analysis_pools.join()
            for r in rendered.get():
                for line in r:
                    file.write(line)

    else:
        svid = {"INS": 0, "DEL": 0, "BND": 0, "DUP": 0, "INV": 0}
        chroms = sorted(svs_by_chrom.keys())
        os.mkdir("%sresults"%temporary_dir)
        with open(args.output, 'w') as file:
            Generation_VCF_header(file, contigINFO, args.sample, argv)
            file.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s\n"%(args.sample))
            analysis_pools = Pool(processes=int(args.threads))
            paras = [(args, svs_by_chrom[chrom], args.reference, chrom, temporary_dir)
                     for chrom in chroms]
            analysis_pools.starmap_async(generate_output, paras)
            analysis_pools.close()
            analysis_pools.join()
            # The per-chromosome records are read back from
            # results/<chrom>.pickle; the "<SVID>" placeholder in each line
            # is replaced by a per-type running number.
            for chrom in chroms:
                with open("%sresults/%s.pickle"%(temporary_dir,chrom), "rb") as f:
                    while True:
                        try:
                            lines = pickle.load(f)
                        except EOFError:
                            break
                        for svtype, line in lines:
                            file.write(line.replace("<SVID>",str(svid[svtype])))
                            svid[svtype]+=1

    # ---- stage 5: clean up ---------------------------------------------
    if not args.retain_work_dir:
        logging.info("Cleaning temporary files.")
        # Quote the directory part so that a work dir containing spaces or
        # shell metacharacters cannot redirect the removal; the glob suffix
        # stays unquoted so the shell still expands it.
        quoted_dir = shlex.quote(temporary_dir)
        doomed = [shlex.quote(temporary_dir + "signatures")]
        if args.Ivcf is None:
            doomed.append(shlex.quote(temporary_dir + "results"))
        doomed.append(quoted_dir + "*.sigs")
        doomed.append(quoted_dir + "*.pickle")
        cmd_remove_tempfile = "rm -r " + " ".join(doomed)
        exe(cmd_remove_tempfile)

def setupLogging(debug=False):
    """Configure root logging to stderr and log the command line.

    Args:
        debug: use level DEBUG when true, INFO otherwise.
    """
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(
        stream=sys.stderr,
        level=level,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )
    command_line = " ".join(sys.argv)
    logging.info("Running %s" % command_line)


def run(argv):
    """Parse ``argv``, run main_ctrl() and log the elapsed wall-clock time."""
    args = parseArgs(argv)
    setupLogging(False)
    began = time.time()
    main_ctrl(args, argv)
    elapsed = time.time() - began
    logging.info("Finished in %0.2f seconds."%(elapsed))

# Script entry point: forward the command-line arguments (without the
# program name) to run().
if __name__ == '__main__':
    run(sys.argv[1:])
