#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 14:17:12 2018

script to find the annotation of a region from syri's output

@author: goel
"""
import os
import argparse
import sys

def readData(args, cwdPath):
    """
    Read the annotated regions (rows whose first field is "#") from the
    given annotation files and combine them into one table.

    :param args: parsed arguments; ``args.f`` is an iterable of objects whose
        ``.name`` attribute is the name of a file to read
    :param cwdPath: directory used to resolve relative file names
    :return: DataFrame with one row per annotated region, sorted by reference
        and then query coordinates, with a 0..n-1 index. Columns are
        ``SVtype`` (the name of the file the row came from), ``refChr``,
        ``refStart``, ``refEnd``, ``qryChr``, ``qryStart``, ``qryEnd`` and,
        when present in the input, ``ctxType`` and ``Genome``. When no file
        yields any region an empty frame with all of these columns is
        returned.
    """
    colnames = ["SVtype", "refChr", "refStart", "refEnd", "qryChr", "qryStart", "qryEnd", "ctxType", "Genome"]
    # Raw column label -> output name. Raw columns 0 (the "#" marker) and 4
    # are dropped below. Naming by label rather than by position keeps the
    # result independent of how pandas orders mixed int/str column labels.
    labels = {1: "refChr", 2: "refStart", 3: "refEnd", 5: "qryChr", 6: "qryStart", 7: "qryEnd", 8: "ctxType", 9: "Genome"}
    frames = []
    for f in args.f:
        fin = f.name
        try:
            # os.path.join leaves absolute file names untouched
            finD = pd.read_table(os.path.join(cwdPath, fin), header=None, dtype=object)
        except pd.errors.EmptyDataError:
            print(fin, "Out.txt is empty. Skipping analysing it.")
            continue
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped
            print("ERROR: while trying to read ", fin, "Out.txt", e)
            continue
        annoData = finD.loc[finD[0] == "#"].copy().drop([0, 4], axis=1).dropna(axis=1, how="all")
        if annoData.empty:
            continue
        annoData = annoData.rename(columns=labels)
        annoData.insert(0, "SVtype", fin)
        frames.append(annoData)

    if not frames:
        return pd.DataFrame(columns=colnames)

    # Single concat instead of growing a frame inside the loop
    annoCoords = pd.concat(frames, sort=False, ignore_index=True)
    known = [c for c in colnames if c in annoCoords.columns]
    extra = [c for c in annoCoords.columns if c not in known]
    annoCoords = annoCoords[known + extra]
    annoCoords = annoCoords.astype({c: "int" for c in ["refStart", "refEnd", "qryStart", "qryEnd"]})
    annoCoords = annoCoords.sort_values(["refChr", "refStart", "refEnd", "qryChr", "qryStart", "qryEnd"])
    return annoCoords.reset_index(drop=True)
# END


def getRegion(args, cwdPath):
    """
    Print, as tab-separated text without header, every annotated region that
    overlaps the interval ``args.chr:args.start-args.end``.

    :param args: parsed arguments; uses ``chr``, ``start``, ``end`` and ``q``
        (when true the interval is matched against the query coordinates and
        the output is re-sorted by query position). ``args`` is also passed
        on to readData.
    :param cwdPath: passed on to readData
    :return: None
    """
    annos = readData(args, cwdPath)
    if args.q:
        overlaps = (annos.qryChr == args.chr) & (annos.qryStart <= args.end) & (annos.qryEnd >= args.start)
        regions = annos.loc[overlaps].copy()
        regions.sort_values(["qryChr", "qryStart", "qryEnd", "refChr", "refStart", "refEnd"], inplace=True)
    else:
        overlaps = (annos.refChr == args.chr) & (annos.refStart <= args.end) & (annos.refEnd >= args.start)
        regions = annos.loc[overlaps].copy()
    print(regions.to_csv(sep="\t", index=False, header=False))
# END


def getBed(args, cwdPath):
    """
    Print the annotated regions as tab-separated chromosome/start/end rows,
    optionally widened by a margin and followed by their annotation.

    :param args: parsed arguments; uses ``q`` (query instead of reference
        coordinates), ``m`` (margin added on both sides) and ``a`` (append
        the SVtype, ctxType and Genome columns). ``args`` is also passed on
        to readData.
    :param cwdPath: passed on to readData
    :return: None
    """
    annos = readData(args, cwdPath)
    margin = args.m
    side = "qry" if args.q else "ref"
    chrom, start, end = side + "Chr", side + "Start", side + "End"

    bed = annos[[chrom, start, end]].copy()
    # Shift the start by one and by the margin, never going below zero
    bed[start] = bed[start] - 1 - margin
    bed.loc[bed[start] < 0, start] = 0
    # NOTE(review): the end is shifted by -1 as well; if the output is meant
    # to be half-open BED this drops the last base -- confirm the intent.
    bed[end] = bed[end] - 1 + margin

    if args.a:
        bed[["SVType","ctxType","genome"]] = annos[["SVtype","ctxType","Genome"]]
    print(bed.to_csv(sep="\t", index=False, header=False))
# END

    
def getSnpAnno(args, cwdPath):
    """
    Print the annotated regions that contain each position of a SNP list.

    :param args: parsed arguments; ``args.fin`` is the tab-separated input
        (column 1: chromosome, column 2: position), ``args.q`` selects query
        instead of reference coordinates. ``args`` is also passed on to
        readData.
    :param cwdPath: passed on to readData
    :return: None. One tab-separated row (no header) is printed per
        (position, containing region) pair: the region columns followed by
        the input chromosome and position. Nothing is printed when no
        position falls inside any region.
    """
    # Chromosome names are read as strings so that numeric names ("1", "2")
    # compare equal to the string chromosome columns produced by readData.
    snps = pd.read_table(args.fin, header=None, dtype={0: str})
    snps.columns = ["chr", "pos"]
    snps.sort_values(["chr", "pos"], inplace=True)
    data = readData(args, cwdPath)

    side = "qry" if args.q else "ref"
    chrcol, startcol, endcol = side + "Chr", side + "Start", side + "End"

    hits = []
    chroms = []
    poses = []
    chromo = None
    chrData = None
    # snps is sorted by chromosome, so the per-chromosome subset only needs
    # to be rebuilt when the chromosome changes.
    for row in snps.itertuples():
        if chrData is None or chromo != row.chr:
            chrData = data.loc[data[chrcol] == row.chr]
            chromo = row.chr
        regions = chrData.loc[(chrData[startcol] <= row.pos) & (chrData[endcol] >= row.pos)]
        if regions.empty:
            continue
        hits.append(regions)
        chroms.extend([row.chr]*len(regions))
        poses.extend([row.pos]*len(regions))

    if not hits:
        # pd.concat raises on an empty list; there is simply nothing to report
        return

    outData = pd.concat(hits)
    outData["inChr"] = chroms
    outData["inPos"] = poses
    print(outData.to_csv(sep="\t", index=False, header=False), end="")
# END


def plotref(refidx, syrireg1, syrireg2, varpos, figout, bw=100000):
    """
    Draw every reference chromosome as a horizontal bar, colour the bar by
    the syri annotations and add a density track of the BED intervals on top.

    :param refidx: path to ref.faidx (read with readfaidxbed)
    :param syrireg1: syri regions in BEDPE format; column 7 holds the
        annotation type used for colouring (SYN, INV, TRANS or DUP)
    :param syrireg2: second set of syri regions; not used by this function yet
    :param varpos: variant positions in BED, grouped by chromosome
    :param figout: output figure path; the figure is not written when None
    :param bw: width (in bp) of the bins used for the density track
    :return: None
    """
    from matplotlib import pyplot as plt
    from matplotlib.patches import Rectangle
    from matplotlib.collections import PatchCollection
    from syri.scripts.func import readfaidxbed, mergeRanges
    import pybedtools as bt # TODO: remove this dependency
    from collections import deque, defaultdict
    import numpy as np
    import logging

    logger = logging.getLogger('plotref')

    def _readbed(vp, refbed):
        """
        Bin the intervals of BED file `vp` per chromosome.

        Returns a mapping chromosome -> deque of (bin midpoint, covered
        positions in the bin divided by bw). Chromosomes that are absent from
        `refbed` are skipped; the file must be grouped by chromosome.
        """
        _chrs = set([r[0] for r in refbed])
        bincnt = defaultdict(deque)
        skipchrs = []
        added_chrs = set()

        def _binchr(chrom, pos):
            # Merge the intervals, expand them to single positions and
            # histogram those positions into bins of width bw.
            if len(pos) > 1:
                rngs = mergeRanges(np.array(pos))
            else:
                rngs = pos
            chrpos = np.array(list(set([i for r in rngs for i in range(r[0], r[1])])))
            # Bin breakpoints for the chromosome: 0, bw, 2*bw, ..., chromosome end
            ends = [r[2] for r in refbed if r[0] == chrom]
            bins = np.concatenate((np.arange(0, ends[0], bw), np.array(ends)), axis=0)
            binval = np.histogram(chrpos, bins)[0]
            return deque([((bins[i] + bins[i+1])/2, binval[i]/bw) for i in range(len(binval))])

        curchr = ''
        pos = deque()
        with open(vp, 'r') as fin:
            for line in fin:
                line = line.strip().split()
                if len(line) < 3:
                    logger.warning("Incomplete information in BED file at line: {}. Skipping it.".format("\t".join(line)))
                    continue
                if line[0] not in _chrs:
                    if line[0] not in skipchrs:
                        logger.info("Chromosome in BED is not present in FASTA or not selected for plotting. Skipping it. BED line: {}".format("\t".join(line)))
                        skipchrs.append(line[0])
                    continue
                if curchr == '' or curchr == line[0]:
                    curchr = line[0]
                    pos.append([int(line[1]), int(line[2])])
                    continue
                # A new chromosome starts: it must not have been seen before
                if line[0] in added_chrs:
                    logger.error("BED file: {} is not sorted. For plotting tracks, sorted bed file is required. Exiting.".format(vp))
                    sys.exit()
                bincnt[curchr] = _binchr(curchr, pos)
                added_chrs.add(curchr)
                # Set the new chromosome
                curchr = line[0]
                pos = deque([[int(line[1]), int(line[2])]])
        # Flush the last chromosome; curchr stays '' when the file held no
        # usable line, in which case there is nothing to bin.
        if curchr != '':
            bincnt[curchr] = _binchr(curchr, pos)
        return bincnt
    # END

    chr_height = 0.3
    th = 0.6    # Track height for plotting bedfile data

    refbed = readfaidxbed(refidx)
    # TODO: remove dependency on pybedtools
    # Read the regions handed in by the caller (previously a fixed local path)
    syrioutbed = bt.BedTool(syrireg1).sort()

    bedbin = _readbed(varpos, refbed)
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.set_ylim([0, len(refbed)+1])
    ax.set_xlim([0, max([r[2] for r in refbed])])
    colors = {'SYN': 'lightgrey', 'INV': '#FFA500', 'TRANS': '#9ACD32', 'DUP': '#00BBFF'}
    for i, r in enumerate(refbed):
        # First chromosome is drawn at the top
        y = len(refbed) - i
        ax.add_patch(Rectangle((r[1], y), r[2], chr_height, fill=False, linewidth=0.5))
        bed = syrioutbed.filter(lambda b: b.chrom == r[0]).saveas()
        patches = [Rectangle((b.start, y), b.stop-b.start, chr_height, color=colors[b[6]], linewidth=0) for b in bed]
        ax.add_collection(PatchCollection(patches, match_original=True))

        # Density track; skipped for chromosomes without any BED interval
        track = bedbin.get(r[0], ())
        if not track:
            continue
        chrpos = [k[0] for k in track]
        tpos = [k[1] for k in track]
        tposmax = max(tpos)
        if tposmax == 0:
            continue
        y0 = y + chr_height
        # Scale the track so that its highest bin is th tall
        ypos = [(t*th/tposmax)+y0 for t in tpos]
        ax.fill_between(chrpos, ypos, y0, color='blue', lw=0.1, zorder=2)

    if figout is not None:
        fig.savefig(figout)
    return
# END

    
if __name__ == "__main__":
    # Command-line entry point: one sub-command per task, each bound to its
    # handler through set_defaults(func=...) and called as func(args, cwdPath).
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    parser_region = subparsers.add_parser("region", help="get annotation for the region")
    parser_getbed = subparsers.add_parser("getbed", help="get bed file for the regions")
    parser_snpanno = subparsers.add_parser("snpanno", help="get annotation for SNPs list")
    parser_syri2bedpe = subparsers.add_parser("syri2bedpe", help="Converts the syri.out to BEDPE format")

    if len(sys.argv[1:]) == 0:
        parser.print_help()
        sys.exit()

    # syri2bedpe is not defined in this file. Look it up at run time so that
    # its absence only disables this sub-command instead of raising a
    # NameError for every invocation of the script.
    _syri2bedpe = globals().get("syri2bedpe")
    if _syri2bedpe is None:
        _syri2bedpe = lambda args, cwdPath: parser_syri2bedpe.error("syri2bedpe is not available in this script")
    parser_syri2bedpe.set_defaults(func=_syri2bedpe)
    parser_syri2bedpe.add_argument('input', help="syri.out file", type=argparse.FileType('r'))
    parser_syri2bedpe.add_argument('output', help="output BEDPE file", type=argparse.FileType('w'))
    parser_syri2bedpe.add_argument('-f', help="Annotation to filter out. Use multiple times to filter more than one annotation types.", type=str, action='append')

    parser_region.set_defaults(func=getRegion)
    parser_region.add_argument("chr", help="Chromosome ID", type=str)
    parser_region.add_argument("start", help="region start", type=int)
    parser_region.add_argument("end", help="region end", type=int)
    parser_region.add_argument("-p", help="prefix", type=str)
    parser_region.add_argument("-q", help="search region in query genome", action="store_true", default=False)
    parser_region.add_argument("-f", help="files to search", type=argparse.FileType('r'), nargs="+", default=None)

    parser_getbed.set_defaults(func=getBed)
    parser_getbed.add_argument("-q", help="search region in query genome", action="store_true", default=False)
    parser_getbed.add_argument("-f", help="files to search", type=argparse.FileType('r'), nargs="+", default=None)
    parser_getbed.add_argument("-m", help="margin to add on the regions", type=int, default=0)
    parser_getbed.add_argument("-a", help="output region annotation", action="store_true", default=False)

    parser_snpanno.set_defaults(func=getSnpAnno)
    parser_snpanno.add_argument("fin", help="input file in tab separated format. Column 1: Chr. Column 2: position", type=argparse.FileType('r'), default=None)
    parser_snpanno.add_argument("-q", help="search annotation in query genome", action="store_true", default=False)
    parser_snpanno.add_argument("-f", help="files to search. by default all files will be searched", type=argparse.FileType('r'), nargs="+", default=None)

    cwdPath = os.getcwd()+os.sep
    args = parser.parse_args()
    # Only these sub-commands use -f as a list of annotation files; for
    # syri2bedpe -f is an annotation filter and must not be replaced.
    if args.func in (getRegion, getBed, getSnpAnno) and args.f is None:
        default_files = ['synOut.txt', 'invOut.txt', 'TLOut.txt', 'invTLOut.txt', 'dupOut.txt', 'invDupOut.txt', 'ctxOut.txt']
        # readData only reads the .name of each entry, so the defaults are
        # passed by name: nothing is opened here and a missing default file is
        # reported and skipped by readData instead of aborting the script.
        args.f = [argparse.Namespace(name=fn) for fn in default_files]

    # Imported late so that printing the help does not pay for loading pandas;
    # the functions above use these names as module globals.
    import pandas as pd
    from collections import deque
    args.func(args, cwdPath)
    
