#!/usr/bin/env python

import os
import sys
import re
import shutil
import uuid
import subprocess
import argparse
import logging

# Values accepted by --log-level.
LOGLEVELS = ['INFO', 'WARN', 'DEBUG', 'ERROR']
# Extensions of the three files that make up one binary plink fileset.
suffixes = ["bed", "bim", "fam"]

def parseArguments():
    """Parse the command line, configure logging and refuse to clobber output.

    Exits with an error message if ``--log-level`` is not in LOGLEVELS, and
    with status 18 if any ``<dest>.<suffix>`` file (suffix from ``suffixes``)
    already exists.

    Returns:
        argparse.Namespace with attributes ``bases``, ``log``, ``loglevel``
        and ``dest``.
    """
    ap = argparse.ArgumentParser(description='Merge a set of plink files.')
    ap.add_argument('bases', type=str, nargs='+', help='source bases')
    ap.add_argument('--log', dest='log', action='store', default="pmerge.log",
                    help='log file (def: pmerge.log)')
    ap.add_argument('--log-level', dest='loglevel', action='store',
                    default="INFO", help='log level (def: INFO)')
    ap.add_argument('-d', '--dest', dest='dest', action='store', required=True,
                    help='output name')
    opts = ap.parse_args()

    level = opts.loglevel
    if level not in LOGLEVELS:
        sys.exit("Unknown log level <%s>: choose from %s" % (level, LOGLEVELS))

    # filemode='a': each run appends to the existing log file.
    logging.basicConfig(filename=opts.log,
                        filemode='a',
                        level=logging.getLevelName(level),
                        format='%(asctime)s %(levelname)s %(message)s')

    # Stop at the first output file that already exists (checked in
    # ``suffixes`` order); only that one is reported.
    candidates = ("%s.%s" % (opts.dest, ext) for ext in suffixes)
    clash = next((path for path in candidates if os.path.exists(path)), None)
    if clash is not None:
        complaint = "Destination file <%s> already exists" % clash
        logging.error(complaint)
        print(complaint)
        sys.exit(18)
    return opts



# Parse the command line; this also configures logging and checks that the
# destination fileset does not already exist.
args = parseArguments()

dest = args.dest

# Short per-run tag appended to the /tmp scratch names below, presumably so
# separate runs do not reuse each other's scratch files.
session = uuid.uuid1().hex[:7]

# Scratch locations for this run.
merge_dst = "/tmp/mrg" + session
# Checked after each merge attempt; presumably the .missnp list plink writes
# for ``--out merge_dst`` when a merge fails -- confirm against plink docs.
merge_miss = "%s-merge.missnp" % merge_dst
missf = "/tmp/miss-" + session
flipped = "/tmp/flip-" + session
snps = "/tmp/snps-" + session

# plink command templates.
# Merge: keys source (base prefix), merge (prefix to merge in), out (output).
mc = ("plink --bfile %(source)s "
      "--bmerge %(merge)s.bed %(merge)s.bim %(merge)s.fam "
      "--make-bed --out %(out)s")
# Flip: positional (input prefix, SNP list file, output prefix).
flip = "plink --bfile %s --flip %s --make-bed --out %s"
# Exclude: positional (input prefix, SNP list file, output prefix).
remove = "plink --bfile %s --exclude %s --make-bed --out %s"


def move_plink(x,y):
    """Rename the plink fileset with prefix ``x`` to prefix ``y``.

    For every extension in ``suffixes`` moves ``x.<ext>`` to ``y.<ext>`` with
    shutil.move, logging each rename at DEBUG level.
    """
    renames = [("%s.%s" % (x, ext), "%s.%s" % (y, ext)) for ext in suffixes]
    for old, new in renames:
        logging.debug('  rename %s to %s' % (old, new))
        shutil.move(old, new)

def show_num_snps(sofar,msg):
    """Log the number of lines in file ``sofar`` at INFO level.

    Args:
        sofar: path of the file to count (e.g. a .bim or SNP list file).
        msg: format string with two ``%s`` placeholders, filled with the line
            count and ``sofar``.

    Returns:
        The line count as an int (newline characters, the same as ``wc -l``).
    """
    # Count in-process instead of shelling out to ``wc -l``. wc prints
    # "N path", which duplicated the path in every message, came back as
    # bytes on Python 3, and put an unquoted path into a shell command.
    num = 0
    with open(sofar, "rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 16), b""):
            num += chunk.count(b"\n")
    logging.info(msg % (num, sofar))
    return num

def do_cmd(x):
    """Run shell command ``x`` with os.system, logging it at DEBUG level.

    A non-zero status is logged as a warning but not raised. merge_file
    expects some plink merge attempts to fail and inspects the files left
    behind instead of the exit status.

    Returns:
        The status returned by os.system (0 on success).
    """
    logging.debug(x)
    status = os.system(x)
    if status != 0:
        logging.warning("Command returned status %s: %s" % (status, x))
    return status

def merge_file(sofar,newf):
    """Merge plink fileset ``newf`` into ``sofar``, replacing ``sofar`` in place.

    Steps:
      1. Run a plain merge (``mc``) into the scratch prefix ``merge_dst``.
      2. If ``merge_miss`` exists afterwards, flip the listed SNPs in
         ``sofar`` and merge again.
      3. If ``merge_miss`` exists again, exclude the listed SNPs from
         ``sofar`` and merge a third time.
    Finally the merged ``merge_dst`` fileset is moved over ``sofar``.

    Exits the process with status 19 if any of ``merge_dst``.{bed,bim,fam}
    is missing after the last attempt.
    """
    # Scratch prefix holding ``sofar`` while the excluded copy is written
    # back to ``sofar``. This used to be an undefined global (NameError).
    temp_plink = "%s-pre" % merge_dst
    merge_cmd = mc % {'source': sofar, 'merge': newf, 'out': merge_dst}

    logging.info("\nOpening %s\n" % newf)
    show_num_snps("%s.bim" % sofar, "There are %s SNPs in %s\n")
    do_cmd(merge_cmd)
    if os.path.exists(merge_miss):
        show_num_snps(merge_miss, "Flipping %s SNPs in %s\n")
        # Move the list aside so that, after the retry, its presence means
        # the retry itself failed.
        shutil.move(merge_miss, missf)
        do_cmd(flip % (sofar, missf, flipped))
        move_plink(flipped, sofar)
        do_cmd(merge_cmd)
        if os.path.exists(merge_miss):
            # Move this list aside too. Leaving it in place made the next
            # call see a stale file and flip SNPs that had not failed.
            shutil.move(merge_miss, missf)
            with open(missf) as fh:
                num = sum(1 for _ in fh)
            logging.info("Removing %s SNPs in %s\n" % (num, sofar))
            move_plink(sofar, temp_plink)
            do_cmd(remove % (temp_plink, missf, sofar))
            show_num_snps("%s.bim" % sofar, "There are %s SNPs in %s\n")
            do_cmd(merge_cmd)
    if not all(os.path.exists("%s.%s" % (merge_dst, suf)) for suf in suffixes):
        logging.error("Missing one of the %s.{bed,bim,fam} files\n" % merge_dst)
        sys.exit(19)
    show_num_snps("%s.bim" % newf, "There are %s SNPs in %s\n")
    # The merged result lives at merge_dst (previously referenced via the
    # undefined name ``temp``).
    move_plink(merge_dst, sofar)
   

def cleanFiles(orig):
    """Write a copy of each input plink fileset with duplicated SNP IDs excluded.

    For each prefix in ``orig``:
      - IDs that appear more than once in column 2 of ``<prefix>.bim`` are
        written to the scratch file ``snps``;
      - plink ``--exclude`` (the ``remove`` template) writes the filtered
        fileset to a new prefix under ./ptemp.

    Args:
        orig: list of input plink prefixes (each with .bed/.bim/.fam).

    Returns:
        List of the cleaned prefixes under ./ptemp, in the order of ``orig``.
    """
    logging.info("Cleaning <%s>" % (orig,))
    clean = []
    for idx, fn in enumerate(orig):
        logging.info("Cleaning %s" % fn)
        dup_cmd = "cut -f 2 %s.bim | sort | uniq -d > %s" % (fn, snps)
        logging.debug(dup_cmd)
        subprocess.check_output(dup_cmd, shell=True)
        # Build the output name from the basename prefixed by position.
        # os.path.join("ptemp", fn) returned fn itself for absolute paths,
        # so plink overwrote the input. For nested relative paths it pointed
        # into a non-existent directory. The index keeps two inputs that
        # share a basename apart.
        curr = os.path.join("ptemp", "%d-%s" % (idx, os.path.basename(fn)))
        do_cmd(remove % (fn, snps, curr))
        show_num_snps(snps, "%s SNPs removed %s\n")
        clean.append(curr)
    return clean

# Cleaned copies of the inputs are written under ./ptemp.
if not os.path.exists("ptemp"):
    os.mkdir("ptemp")

clean = cleanFiles(args.bases)

# Seed the destination with the first cleaned fileset, then merge the rest
# into it one at a time.
for suffix in suffixes:
    shutil.copyfile("%s.%s" % (clean[0], suffix), "%s.%s" % (dest, suffix))

for fn in clean[1:]:
    merge_file(dest, fn)

#os.system("/bin/rm -rf ptemp/*")
#os.system("/bin/rm -rf %s.* %s.* %s.*"%(temp,missf,temp_plink))
    
