#!/usr/bin/env python3

# -*- coding: utf-8 -*-
"""
Created on Wed May 10 13:05:51 2017

@author: goel

Current VCF handling tools seem to cater to VCFs generated from read alignments.
VCFs generated from assemblies have other requirements.
This package consists of methods to manipulate VCFs generated from the comparison
of haplotype-resolved assemblies
"""

import argparse
import sys


def setlogconfig(lg):
    """
    Configure the root logger through `logging.config.dictConfig`.

    The root logger gets a single `logging.StreamHandler` (no explicit stream is
    configured) and both the logger and the handler are set to level `lg`.
    Loggers that already exist are not disabled.

    :param lg: Logging level used for the root logger and its handler
    :return: None
    """
    # TODO: Setup one logging function for syri and other utilities in func.py
    from logging.config import dictConfig

    # The 'log_file' formatter is declared but no handler refers to it yet.
    formatters = {
        'log_file': {
            'format': "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s",
        },
        'stdout': {
            'format': "%(name)s - %(asctime)s - %(levelname)s - %(message)s",
        },
    }
    streamhandler = {
        'class': 'logging.StreamHandler',
        'formatter': 'stdout',
        'level': lg,
    }
    # The empty logger name addresses the root logger
    rootlogger = {
        'level': lg,
        'handlers': ['stdout'],
    }
    dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': formatters,
        'handlers': {'stdout': streamhandler},
        'loggers': {'': rootlogger},
    })
# END


def hapmerge(args):
    """
    Merges VCF files from different haplotypes of a sample.
    Current limitations:
        1) Only merges two VCFs (i.e. for a diploid genome)
        2) Only works with SNPs and INDELs
        3) Only the first ALT allele of every input record is used
    Genotypes are written as (haplotype 1, haplotype 2): (1, 0) or (0, 1) when only one
    VCF has a variant at the position, (1, 1) when both haplotypes carry the same allele
    and (1, 2) when they carry different alleles.
    :param args: Input arguments. Uses `vcf1`, `vcf2` and `out` (only their `.name` is read), `sample` (sample name added to the output) and `log` (log level).
    :return: None. The merged VCF is written to `args.out.name` using the header of VCF1; every record gets id '.', filter PASS, no INFO values and a GT value for `sample`. Positions with more than one distinct record in any VCF are filtered out in both VCFs. Records without an ALT allele are ignored.
    """
    import logging
    from collections import deque, defaultdict, Counter
    from contextlib import closing
    from pysam import VariantFile

    def readvcf(v):
        """
        Read VCF file
        :param v: VariantFile object
        :return: dict chromosome -> {position: (ref, alts)} for positions having one distinct record, deque of (chromosome, position) having more than one distinct record, deque of chromosome ids in order of first appearance
        """
        cids = deque()      # Chromosome ids
        lastchr = ''
        vd = defaultdict(set)   # A set drops records that are exact duplicates
        for var in v.fetch():
            if var.chrom != lastchr:
                # Only search the chromosome list when the chromosome changes
                lastchr = var.chrom
                if lastchr not in cids:
                    cids.append(lastchr)
            if not var.alts:
                # No ALT allele means there is nothing to merge
                continue
            # Convert bases to upper case to allow correct comparison
            vd[var.chrom].add((var.pos, var.ref.upper(), tuple(map(str.upper, var.alts))))
        gp = deque()            # garb pos
        v1d = defaultdict(dict)
        for chrom, recs in vd.items():
            poscnt = Counter(rec[0] for rec in recs)
            garb = {pos for pos, cnt in poscnt.items() if cnt > 1}
            gp.extend((chrom, pos) for pos in garb)
            v1d[chrom] = {rec[0]: rec[1:] for rec in recs if rec[0] not in garb}
        return v1d, gp, cids

    def mergealleles(ref1, alt1, ref2, alt2):
        """
        Combine the alleles of both haplotypes at one position
        :param ref1: REF allele from haplotype 1 ('' when haplotype 1 has no variant here)
        :param alt1: ALT allele from haplotype 1
        :param ref2: REF allele from haplotype 2 ('' when haplotype 2 has no variant here)
        :param alt2: ALT allele from haplotype 2
        :return: (alleles, gt) where alleles is (ref, alt) or (ref, alt1, alt2); None when the REF alleles are incompatible
        """
        if ref1 == '':
            # Heterozygous mutations present only in paternal chromosome
            return (ref2, alt2), (0, 1)
        if ref2 == '':
            # Heterozygous mutations present only in maternal chromosome
            return (ref1, alt1), (1, 0)
        if ref1 == ref2:
            ref = ref1
        elif len(ref1) > len(ref2) and ref1.startswith(ref2):
            # Deletion in haplotype 1: pad the haplotype 2 allele with the extra reference bases
            ref, alt2 = ref1, alt2 + ref1[len(ref2):]
        elif len(ref2) > len(ref1) and ref2.startswith(ref1):
            # Deletion in haplotype 2: pad the haplotype 1 allele with the extra reference bases
            ref, alt1 = ref2, alt1 + ref2[len(ref1):]
        else:
            return None
        if alt1 == alt2:
            # Homozygous mutation (also when alleles become identical after padding,
            # which would otherwise give a record with duplicated ALT alleles)
            return (ref, alt1), (1, 1)
        # Multi-allelic variation
        return (ref, alt1, alt2), (1, 2)

    # TODO: Perform the same operation on the syri.out file as well
    # vcf1: Input VCF (also BCF?) for haplotype 1
    # vcf2: Input VCF (also BCF?) for haplotype 2
    # sample: Sample name to be added to VCF
    vcf1 = args.vcf1.name
    vcf2 = args.vcf2.name
    outvcf = args.out.name
    sample = args.sample
    setlogconfig(args.log)
    logger = logging.getLogger("hapmerge")
    # closing() guarantees that the files are closed even when merging fails
    with closing(VariantFile(vcf1, 'r')) as v1, closing(VariantFile(vcf2, 'r')) as v2:
        # Read both VCFs, save chromosome order, position and variations
        v1dict, garbv1, chridsv1 = readvcf(v1)
        v2dict, garbv2, chridsv2 = readvcf(v2)
        if str(v1.header) != str(v2.header):
            logger.info(f"Input VCFs have different headers. Using the header from {vcf1} in the output file.")
        header = v1.header
        # NOTE(review): the GT FORMAT field is not added here, so it is presumably
        # already defined in the header of VCF1 - TODO confirm
        header.samples.add(sample)
        # TODO: Add IDs as well (can be added as a semicolon separated list)

        # Remove garb positions from both VCFs
        for chrom, pos in list(garbv1) + list(garbv2):
            v1dict[chrom].pop(pos, None)
            v2dict[chrom].pop(pos, None)

        # Chromosomes of VCF1 first, followed by chromosomes found only in VCF2
        chroms = list(chridsv1)
        inv1 = set(chroms)
        for chrom in chridsv2:
            if chrom in inv1:
                continue
            if chrom not in header.contigs:
                # Records can only be created for contigs known to the output header
                logger.warning(f"Chromosome {chrom} is present only in {vcf2} and is not defined in the header of {vcf1}. Skipping its variants.")
                continue
            chroms.append(chrom)

        with closing(VariantFile(outvcf, 'w', header=header)) as vo:
            # Iterate over chromosomes and print variants corresponding to each chromosome
            for chrom in chroms:
                for p in sorted(set(v1dict[chrom]).union(v2dict[chrom])):
                    ref1, ref2, alt1, alt2 = [''] * 4
                    if p in v1dict[chrom]:
                        ref1, alts = v1dict[chrom][p]
                        alt1 = alts[0]
                    if p in v2dict[chrom]:
                        ref2, alts = v2dict[chrom][p]
                        alt2 = alts[0]
                    if ref1 == '' and ref2 == '':
                        logger.error(f"Both ref are empty for position {chrom}:{p}. Error. Skipping this position.")
                        continue
                    merged = mergealleles(ref1, alt1, ref2, alt2)
                    if merged is None:
                        logger.error(f"Incompatible reference alleles at {chrom}:{p}: ref1-{ref1}, ref2-{ref2}. Skipping it.")
                        continue
                    alleles, gt = merged
                    ref = alleles[0]
                    # start is 0-based, p is the 1-based position read from the VCF
                    r = vo.new_record(contig=chrom, start=p-1, stop=p+len(ref)-1, alleles=alleles, id='.', filter='PASS')
                    r.samples[sample]['GT'] = gt
                    vo.write(r)
    logger.info(f'Finished merging. Output saved in {outvcf}')
# END


def vcfshv(args):
    """
    Copy a VCF file, leaving out records whose first ALT allele is a symbolic
    allele declared in the header (ALT lines) other than <SNP>, <INS> and <DEL>.
    Records without an ALT allele are copied as they are.
    :param args: Input arguments. Uses `vcf1` and `out` (only their `.name` is read) and `log` (log level).
    :return: 0 when the output was written, 1 when the output VCF could not be opened
    """
    # :param vcf: Input VCF file
    # :param outvcf: Output VCF file
    import logging
    from contextlib import closing
    from pysam import VariantFile
    vcf1 = args.vcf1.name
    outvcf = args.out.name
    setlogconfig(args.log)
    logger = logging.getLogger("vcfshv")
    # closing() makes sure that the input file is closed on every exit path
    with closing(VariantFile(vcf1, 'r')) as v1:
        header = v1.header
        # Symbolic alleles (as they appear in the ALT column) of records to leave out
        skipalts = {f'<{c}>' for c in header.alts.keys() if c not in {'SNP', 'INS', 'DEL'}}
        try:
            vo = VariantFile(outvcf, 'w', header=header)
        except Exception as e:
            # Stop here: continuing would only produce a VCF without header
            logger.error(f"Error in opening output VCF file. {e}")
            return 1
        # Records are written through the VariantFile object, so that header and
        # records go through the same writer
        with closing(vo):
            for var in v1.fetch():
                alts = var.alts
                if alts and alts[0] in skipalts:
                    continue
                vo.write(var)
    logger.info(f'Finished filtering SNPs/indels. Output saved in {outvcf}')
    return 0
# END


def filindsize(args):
    """
    Filter indels in a VCF file based on size
    A record is kept when it is a SNP (REF of length 1 and a single ALT of
    length 1) or when the length of REF and of every ALT allele lies within
    [min, max]. All other records are left out.
    :param args: Input arguments. Uses `vcf1` and `out` (only their `.name` is read), `min`, `max` and `log` (log level).
    :return: 0 when the filtered VCF was written, 1 when the output VCF could not be opened
    """
    # :param vcf: Input VCF file
    # :param outvcf: Output VCF file
    # :param imin: Minimum indel size
    # :param imax: Maximum indel size
    import logging
    from contextlib import closing
    from pysam import VariantFile
    vcf1 = args.vcf1.name
    outvcf = args.out.name
    imin = args.min
    imax = args.max
    setlogconfig(args.log)
    logger = logging.getLogger("filindsize")
    # closing() makes sure that the input file is closed on every exit path
    with closing(VariantFile(vcf1, 'r')) as v1:
        header = v1.header
        try:
            vo = VariantFile(outvcf, 'w', header=header)
        except Exception as e:
            # Stop here: continuing would only produce a VCF without header
            logger.error(f"Error in opening output VCF file. {e}")
            return 1
        # Records are written through the VariantFile object, so that header and
        # records go through the same writer
        with closing(vo):
            for var in v1.fetch():
                # Records without ALT allele have alts set to None
                alts = var.alts or ()
                issnp = len(var.ref) == 1 and len(alts) == 1 and len(alts[0]) == 1
                # NOTE(review): the size range is tested on every allele and not on
                # the REF/ALT length difference, so with min > 1 an indel having a
                # single-base REF or ALT allele is never selected - confirm that this is intended
                inrange = imin <= len(var.ref) <= imax and all(imin <= len(a) <= imax for a in alts)
                if issnp or inrange:
                    vo.write(var)
    logger.info(f'Finished filindsize SNPs/indels. Output saved in {outvcf}')
    return 0


# END

if __name__ == "__main__":
    # Command-line entry point: every utility is exposed as a sub-command and
    # the selected function is stored in `func` by its sub-parser.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    subparsers = parser.add_subparsers()
    # Define subparsers
    parser_vcfshv = subparsers.add_parser("vcfshv", help="Filter VCF file from syri to select SNPs and indels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser_hapmerge = subparsers.add_parser("hapmerge", help="Merge VCF files corresponding two variantions in the haploid assemblies of a diploid organism. VCF should only have SNPs and indels before merging.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser_filindsize = subparsers.add_parser("filindsize", help="Filter indels in a VCF file based on size of ref/alt allele. Currently, can only handle VCF having only SNPs and indels.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Check whether any command is selected ####################################
    if len(sys.argv[1:]) == 0:
        parser.print_help()
        sys.exit()
    # Set parameters for vcfshv ################################################
    parser_vcfshv.set_defaults(func=vcfshv)
    parser_vcfshv.add_argument("vcf1", help='Input VCF file', type=argparse.FileType('r'))
    parser_vcfshv.add_argument("out", help='Output file name', type=argparse.FileType('w'))
    # Set parameters for hapmerge ##############################################
    parser_hapmerge.set_defaults(func=hapmerge)
    parser_hapmerge.add_argument("vcf1", help='Input VCF file where haplotype 1 is used as query against the reference', type=argparse.FileType('r'))
    parser_hapmerge.add_argument("vcf2", help='Input VCF file where haplotype 2 is used as query against the reference', type=argparse.FileType('r'))
    parser_hapmerge.add_argument("out", help='Output file name', type=argparse.FileType('w'))
    # nargs='?' is required for the default of a positional argument to be used
    parser_hapmerge.add_argument("sample", help='Sample name to be added in output VCF file', type=str, nargs='?', default='sample')
    # Set parameters for filindsize ############################################
    parser_filindsize.set_defaults(func=filindsize)
    parser_filindsize.add_argument("vcf1", help='Input VCF file', type=argparse.FileType('r'))
    parser_filindsize.add_argument("out", help='Output file name', type=argparse.FileType('w'))
    parser_filindsize.add_argument("--min", help='Minimum indel size to select', type=int, default=1)
    parser_filindsize.add_argument("--max", help='Maximum indel size to select', type=int, default=100)

    # Common arguments (have to be given before the sub-command)
    parser.add_argument("--log", dest="log", help="log level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARN"])

    args = parser.parse_args()
    # Only options but no sub-command were given, so there is no function to run
    if not hasattr(args, 'func'):
        parser.print_help()
        sys.exit()
    # Run the function and use its return value (None counts as success) as exit status
    sys.exit(args.func(args))

