#!/usr/bin/env python

from __future__ import print_function


import argparse
import os
import re
import sys
import string
import math

# DEBUG: developer shortcut -- when the script is run with no arguments it runs
# a fixed example analysis instead of printing argparse's usage message.
# The argument string is split here without a shell, so "~" would otherwise be
# passed through literally and the files could not be opened; expand it.
if len(sys.argv) == 1:
    sys.argv = [os.path.expanduser(tok) for tok in
                "./stratexclude.py  --group ZUL --plink  --phe  ~/Dropbox/popsadmix/all2.phe --exclude-threshold 2 ~/Dropbox/popsadmix/pruned".split()]
    

def setup():
    parser = argparse.ArgumentParser(description='Identify individuals who are different!')

    parser.add_argument('qfname', metavar='FNAME', type=str, help='Q file')
    parser.add_argument('--phe', dest='phe', action='store', default="",help='phenotype file (default empty)')
    parser.add_argument('--fam', dest='fam', action='store', default="",\
                        help='name of fam file (default None)')

    parser.add_argument('--clumpp', dest='clumpp',\
                        action='store_true', default=False,\
                        help='is the file in clumpp format (default None)')

    parser.add_argument('--smartpca', dest='smartpca',\
                        action='store_true', default=False,\
                        help='file in smartpca format ')

    parser.add_argument('--out', dest='out',\
                        action='store', default=sys.stdout,\
                        help='name of output file ')

    parser.add_argument('--plink', dest='plink',\
                        action='store_true', default=False,\
                        help='file in plink')

    parser.add_argument('--euclidean', dest='euclidean',\
                        action='store_true', default=False,\
                        help='use Euclidean distance to remove')

    
    parser.add_argument('--weighted', dest='weighted',\
                        action='store_true', default=False,\
                        help='weight distances with Eigenvalue')

    parser.add_argument('--phe-col', dest='phe_col', action='store',type=int,\
                        default=3,\
                        help='column in phenotype file (default 2)')

    parser.add_argument('--group', dest='group', action='store',required=True,\
                        help='which group')

    parser.add_argument('--analyse', dest='analyse', action='store_true',\
                        default=False,help='do analysis')

    parser.add_argument('--num-pcs', dest='numpcs', action='store',type=int,\
                       default=0,help='Number of PCs to use (0=all)')

    parser.add_argument('--numoutlieriter', dest='numoutlieriter', action='store',type=int,\
                       default=1,help='Number of iterations of outlier removal')


    parser.add_argument('--exclude-threshold', dest='exclude_threshold', 
                        action='store',type=float,\
                        default=0,help=' threshold for exclusion, default is 0 -- just show distances')


    args = parser.parse_args()
    args.phe_col = args.phe_col -1

    return args 


def getGroup(args):
    """Map "FID:IID" individual ids to their group label.

    The phenotype file ``args.phe`` is read as whitespace-separated columns.
    The first two columns form the id, and column ``args.phe_col`` (already
    0-based: setup() subtracts one from the user's value) is the group.
    Blank lines are skipped.

    Returns an empty dict when no phenotype file was given.
    """
    group = {}
    if not args.phe:
        return group
    with open(args.phe) as phef:
        for pline in phef:
            data = pline.split()
            if not data:
                # Tolerate blank lines such as a trailing empty line.
                continue
            group["%s:%s" % (data[0], data[1])] = data[args.phe_col]
    return group


def readFam(args, q, all_ids):
    """Build the row-index <-> individual-id mappings.

    Ids are taken from, in order of preference:
      * a fam file, when --fam was given. With --plink the file read is
        ``<qfname>.fam``; otherwise it is the path given to --fam. The first
        two whitespace-separated fields of each line become "FID:IID".
      * ``all_ids``, the ids collected while reading the Q/eigenvector file.
      * otherwise the row numbers of ``q`` as strings ("0", "1", ...).

    Returns (n2id, id2n): n2id[i] is the id of row i, and id2n maps id -> row.
    If an id repeats, id2n keeps its last row.
    Exits with an error message on a fam line with fewer than two fields.
    """
    n2id = []
    if args.fam:
        fname = args.qfname + ".fam" if args.plink else args.fam
        with open(fname) as f:
            for line in f:
                fields = line.split()
                if len(fields) < 2:
                    sys.exit("illegal input in fam file <%s>\n" % line)
                n2id.append("%s:%s" % (fields[0], fields[1]))
    elif all_ids:
        n2id = list(all_ids)
    else:
        n2id = [str(i) for i in range(len(q))]
    # Build id2n uniformly for every source; the row-number fallback
    # previously left it empty, breaking later id2n[...] lookups.
    id2n = {the_id: n for n, the_id in enumerate(n2id)}
    return (n2id, id2n)
        

def getEigVs(args, f=None):
    """Return per-PC weights (eigenvalues) when --weighted is set, else False.

    * smartpca: the eigenvalues are on the first line of the Q file, which
      must start with "#". The line is read from ``f`` when an open handle is
      supplied (that line is consumed from it); otherwise ``args.qfname`` is
      opened and its first line read.
    * plink: whitespace-separated values in ``<qfname>.eigenval``.

    Exits with an error message for any other format, or when the smartpca
    header line is missing.
    """
    if not args.weighted:
        return False
    if args.smartpca:
        if f is None:
            with open(args.qfname) as qf:
                eigens = qf.readline()
        else:
            eigens = f.readline()
        if not re.search(r"^\s*#", eigens):
            sys.exit("Can't find eigen value in line <%s>" % eigens)
        return [float(v) for v in re.findall(r"([-\d.]+)", eigens)]
    if args.plink:
        with open(args.qfname + ".eigenval") as eigvf:
            # split() rather than readlines() so blank lines are ignored.
            return [float(v) for v in eigvf.read().split()]
    sys.exit("Weighted only implemented for smartpca and plink data")



def readQ(args):
    """Read one row of coordinates (Q values or PCs) per individual.

    Input formats:
      * --plink: ``<qfname>.eigenvec``; each row is "FID IID v1 v2 ...", and
        the id is recorded as "FID:IID".
      * --smartpca: each row is "<id> v1 v2 ...".
      * --clumpp: only the text after the last ":" (followed by whitespace)
        is parsed, using the default rule below.
      * default: every run of digits/dots on the row is a value.
    Blank lines and lines starting with "#" (e.g. the smartpca eigenvalue
    header) are skipped. Only the first --num-pcs values of each row are
    kept (0 = all).

    Returns (eigs, q, all_ids): the weights from getEigVs, or a list of 1s
    when that returns a falsy value; the list of float rows; and the ids in
    file order (empty for formats without ids).
    """
    qfname = args.qfname + ".eigenvec" if args.plink else args.qfname
    q = []
    all_ids = []
    N = 0  # defined even for an empty file, so the [1]*N fallback works
    with open(qfname) as f:
        eigs = getEigVs(args)
        for dline in f:
            stripped = dline.strip()
            if not stripped or stripped.startswith("#"):
                continue
            if args.clumpp:
                mm = re.search(r".*:\s+(.*)", dline)
                if not mm:
                    sys.exit("illegal input in clumpp file <%s>" % dline)
                dline = mm.group(1)
            if args.smartpca:
                mm = re.search(r"\s*(\S+)\s*(.*)", dline)
                all_ids.append(mm.group(1))
                data = re.findall(r"\S+", mm.group(2))
            elif args.plink:
                tokens = dline.split()
                data = tokens[2:]
                all_ids.append("%s:%s" % (tokens[0], tokens[1]))
            else:
                data = re.findall(r"([\d.]+)", dline)
            N = len(data) if args.numpcs == 0 else args.numpcs
            q.append([float(v) for v in data[:N]])
    if not eigs:
        eigs = [1] * N
    return (eigs, q, all_ids)


def pdiff(x, y, w=None):
    """Weighted root-mean-square difference between two coordinate vectors.

    Each squared coordinate difference is multiplied by the matching weight.
    The total is divided by the sum of the weights actually used (the first
    len(x) of them), so with unit weights this is the plain RMS difference.

    w defaults to the module-level ``weights`` list set by the main script.
    Returns 0.0 for empty vectors. Raises IndexError if y or w is shorter
    than x.
    """
    if w is None:
        w = weights
    n = len(x)
    if n == 0:
        return 0.0
    # Normalise locally; the old global sum_weights was never defined.
    total = float(sum(w[:n]))
    d = sum(w[i] * (x[i] - y[i]) ** 2 for i in range(n))
    return math.sqrt(d / total)


def euclidRemove(q, args, mems, index=None):
    """Flag group members that lie far from the group centroid.

    The centroid is the coordinate-wise mean of the members' rows in q. Each
    member's distance to it is computed with pdiff (weighted RMS). A member
    is excluded when its distance exceeds args.exclude_threshold times the
    standard deviation of the distances; a threshold of 0 excludes nobody.
    With --analyse, every member's distance is also printed.

    index maps member id -> row of q and defaults to the module-level id2n.
    Returns the set of excluded member ids (empty for an empty group).
    """
    if index is None:
        index = id2n
    if not mems:
        print("# group has no members")
        return set()
    points = [q[index[m]] for m in mems]
    ng = float(len(points))
    # zip(*points) walks the PCs column by column.
    centroid = [sum(col) / ng for col in zip(*points)]
    dist = [pdiff(p, centroid) for p in points]
    ave_dist = sum(dist) / ng
    stdev = math.sqrt(sum((d - ave_dist) ** 2 for d in dist) / ng)
    print("# ave_dist %f stdev %f" % (ave_dist, stdev))
    excluded = set()
    if args.exclude_threshold:
        limit = stdev * float(args.exclude_threshold)
        for m, d in zip(mems, dist):
            if d > limit:
                excluded.add(m)
                print("# %s" % m)
    if args.analyse:
        for m, d in zip(mems, dist):
            print(d, m)
    return excluded


def splitIntoGroups(n2id,group):
    """Partition individual ids by group label.

    Ids missing from ``group`` land in the "NONE" group, and are reported
    individually when a phenotype file was supplied (module-level args.phe).
    Prints the size of every group and returns a dict mapping each label to
    its ids, in n2id order.
    """
    members_by_label = {}
    for ind in n2id:
        label = group.get(ind, "NONE")
        if args.phe and label == "NONE":
            print("%s has group NONE" % (ind))
        members_by_label.setdefault(label, []).append(ind)
    for label in members_by_label:
        print("Group %s has %d members" % (label, len(members_by_label[label])))
    return members_by_label



def pcwiseRemove(q, args, mems, index=None):
    """Iteratively flag members that are outliers on any single PC.

    The procedure runs for args.numoutlieriter rounds. In each round, every
    PC (column of q) is examined:
      * The centre is the mean of that PC over the members that survived
        earlier rounds.
      * Each member's distance is |value - centre|.
      * A member is excluded when its distance exceeds
        args.exclude_threshold times the standard deviation of the
        surviving members' distances.
    Exclusions found in a round only affect the statistics of later rounds.
    A threshold of 0 excludes nobody; statistics are still printed.

    index maps member id -> row of q and defaults to the module-level id2n.
    Returns the set of excluded member ids.
    """
    if index is None:
        index = id2n
    remove = set()
    if not mems:
        return remove
    for numiter in range(args.numoutlieriter):
        this_round = set()
        print("Iteration ", numiter)
        kept = [i for i, m in enumerate(mems) if m not in remove]
        if not kept:
            break
        # Average over surviving members only, not len(mems).
        n_kept = float(len(kept))
        for pc in range(len(q[0])):
            print("  PC ", pc + 1)
            col = [q[index[m]][pc] for m in mems]
            centre = sum(col[i] for i in kept) / n_kept
            dist = [abs(v - centre) for v in col]
            ave = sum(dist[i] for i in kept) / n_kept
            stdev = math.sqrt(sum((dist[i] - ave) ** 2 for i in kept) / n_kept)
            print("    centre=%f, ave_dist=%f, stdev=%5.2f" % (centre, ave, stdev))
            if not args.exclude_threshold:
                continue
            limit = args.exclude_threshold * stdev
            for i, m in enumerate(mems):
                if m in remove or m in this_round:
                    continue
                if dist[i] > limit:
                    this_round.add(m)
                    print("     Exclude: %s" % m)
        remove |= this_round
    return remove

args = setup()


print("\nLog of actions\n")
(weights, q, all_ids) = readQ(args)
# pdiff() reads these module-level names. Materialise weights as a list so it
# can be indexed and summed repeatedly, and define sum_weights (previously
# commented out, which made the Euclidean path fail with NameError).
weights = list(weights)
sum_weights = sum(weights)
(n2id, id2n) = readFam(args, q, all_ids)
group = getGroup(args)
group_mems = splitIntoGroups(n2id, group)
print()
remove = set()
if args.analyse or args.exclude_threshold:
    for group_set in args.group.split("/"):
        print("Analysing group %s\n%s" % (group_set, "-" * 10))
        mems = []
        # Separate loop name so the phenotype dict `group` is not clobbered.
        for gname in group_set.split(","):
            if gname not in group_mems:
                sys.exit("Problem -- there is no group <%s> in the phenotype file" % gname)
            mems = group_mems[gname] + mems
        if args.euclidean:
            found = euclidRemove(q, args, mems)
        else:
            found = pcwiseRemove(q, args, mems)
        # Treat a None result as "nothing excluded".
        remove = remove | (found or set())
    if args.out != sys.stdout:
        with open(args.out, "w") as out:
            for name in remove:
                out.write("%s\n" % name)
    else:
        # Write to stdout but never close it; closing it would break later output.
        for name in remove:
            args.out.write("%s\n" % name)
