#!/usr/bin/env python

from __future__ import division
from optparse import OptionParser
import logging,time,random,sys
import networkx as nx
import utilities as util
import k_shortcuts as KS

# Fixed seed: ties between equally good candidate edges are broken with
# random.choice, so seeding keeps runs reproducible.
random.seed(10301949)

UNREACHABLE_PENALTY=100 # Distance given to nodes not reachable within the hop limit
DEFAULT_EDGE_WEIGHT=-1  # [0,1]: default weight for edges not in STRING (set in main)
USE_POTENTIAL=-1        # {True,False}: whether to use potential edges (set in main)
OBJ_FUNCTION=-1         # {shortcuts-x,shortcuts-x-ss}: restricted or restricted ss (set in main)
MAX_DISTANCE=5          # Maximum allowable number of hops between a s-t pair


# Pretty-print helpers: tag a node as a source "(s)" and/or target "(t)".
def pps(x, y):
    """ Return "(s)" when node x is in the source collection y, else "". """
    if x in y:
        return "(s)"
    return ""


def ppt(x, z):
    """ Return "(t)" when node x is in the target collection z, else "". """
    if x in z:
        return "(t)"
    return ""


#==============================================================================
#                                  HELPERS
#==============================================================================
def compute_cost(Ds,S,T):
    """ Cost of the objective function on the current graph (no edge added).

        shortcuts-x: sum of Ds[s][MAX_DISTANCE][t] over every (s, t) pair.
        otherwise (shortcuts-x-ss): for each target t, the distance from its
        closest source, summed over all targets.
    """
    total = 0
    if OBJ_FUNCTION == "shortcuts-x":
        for src in S:
            for tgt in T:
                total += Ds[src][MAX_DISTANCE][tgt]
    else:
        for tgt in T:
            total += min([Ds[src][MAX_DISTANCE][tgt] for src in S])
    return total


def compute_cost_uv_h4(Ds,Dt,u,v,Duv,S,T):
    """ Cost of the objective function if edge u->v (weight Duv) were added,
        considering routes through u->v of at most 4 hops in total.

        For each s-t pair the baseline is Ds[s][MAX_DISTANCE][t]; a route
        through u->v replaces it only when strictly shorter.
        NOTE(review): the baseline uses MAX_DISTANCE (5 by default), not 4 --
        presumably this variant is meant for MAX_DISTANCE == 4; confirm.

        shortcuts-x: sum over every (s, t) pair.
        otherwise (shortcuts-x-ss): for each t, the best value over all
        sources (starting from a sentinel of 1000), summed over t.
    """
    def via_uv(src, tgt):
        # Best src->tgt distance when u->v may be used once.
        # "~h~>" below means "within h hops" (Ds/Dt tables are <= h hops).
        best = Ds[src][MAX_DISTANCE][tgt]
        for cand in (Ds[src][1][u] + Duv + Dt[tgt][2][v],   # s ~1~> u -> v ~2~> t
                     Ds[src][2][u] + Duv + Dt[tgt][1][v]):  # s ~2~> u -> v ~1~> t
            if cand < best:
                best = cand
        if v == tgt and Ds[src][3][u] + Duv < best:         # s ~3~> u -> v == t
            best = Ds[src][3][u] + Duv
        if u == src and Duv + Dt[tgt][3][v] < best:         # s == u -> v ~3~> t
            best = Duv + Dt[tgt][3][v]
        return best

    total = 0
    if OBJ_FUNCTION == "shortcuts-x":
        for src in S:
            for tgt in T:
                total += via_uv(src, tgt)
        return total

    # shortcuts-x-ss: only the closest source counts for each target.
    for tgt in T:
        closest = 1000
        for src in S:
            d = via_uv(src, tgt)
            if d < closest:
                closest = d
        total += closest
    return total





def compute_cost_uv(Ds,Dt,u,v,Duv,S,T):
    """ Cost of the objective function if edge u->v (weight Duv) were added,
        considering routes through u->v of at most 5 hops in total.

        For each s-t pair the baseline is Ds[s][MAX_DISTANCE][t]; a route
        through u->v replaces it only when strictly shorter.

        shortcuts-x: sum over every (s, t) pair.
        otherwise (shortcuts-x-ss): for each t, the best value over all
        sources (starting from a sentinel of 1000), summed over t.
    """
    def via_uv(src, tgt):
        # Best src->tgt distance when u->v may be used once.
        # "~h~>" below means "within h hops" (Ds/Dt tables are <= h hops).
        best = Ds[src][MAX_DISTANCE][tgt]
        for cand in (Ds[src][1][u] + Duv + Dt[tgt][3][v],   # s ~1~> u -> v ~3~> t
                     Ds[src][2][u] + Duv + Dt[tgt][2][v],   # s ~2~> u -> v ~2~> t
                     Ds[src][3][u] + Duv + Dt[tgt][1][v]):  # s ~3~> u -> v ~1~> t
            if cand < best:
                best = cand
        if v == tgt and Ds[src][4][u] + Duv < best:         # s ~4~> u -> v == t
            best = Ds[src][4][u] + Duv
        if u == src and Duv + Dt[tgt][4][v] < best:         # s == u -> v ~4~> t
            best = Duv + Dt[tgt][4][v]
        return best

    total = 0
    if OBJ_FUNCTION == "shortcuts-x":
        for src in S:
            for tgt in T:
                total += via_uv(src, tgt)
        return total

    # shortcuts-x-ss: only the closest source counts for each target.
    for tgt in T:
        closest = 1000
        for src in S:
            d = via_uv(src, tgt)
            if d < closest:
                closest = d
        total += closest
    return total


def compute_usage_uv_h4(Ds,Dt,u,v,Duv,S,T):
    """ Usage of candidate edge u->v (weight Duv), i.e. the number of
        hop-restricted shortest paths it would improve. Used by "betweenness".
        Considers routes through u->v of at most 4 hops in total.

        shortcuts-x: number of (s, t) pairs whose Ds[s][MAX_DISTANCE][t] is
        strictly beaten by some route through u->v.
        otherwise ("singles"): number of targets t whose best distance over
        all sources strictly improves with u->v (both sides start from a
        sentinel of 1000).
    """
    count = 0

    if OBJ_FUNCTION == "shortcuts-x":
        for src in S:
            for tgt in T:
                base = Ds[src][MAX_DISTANCE][tgt]
                if (Ds[src][1][u] + Duv + Dt[tgt][2][v] < base or    # s ~1~> u -> v ~2~> t
                        Ds[src][2][u] + Duv + Dt[tgt][1][v] < base):  # s ~2~> u -> v ~1~> t
                    count += 1
                elif v == tgt:
                    # NOTE(review): because this is an elif, the u == src
                    # route below is never tried when v == tgt. Kept as-is.
                    if Ds[src][3][u] + Duv < base:                    # s ~3~> u -> v == t
                        count += 1
                elif u == src:
                    if Duv + Dt[tgt][3][v] < base:                    # s == u -> v ~3~> t
                        count += 1
        return count

    def via_uv(src, tgt):
        # Best src->tgt distance when u->v may be used once (<= 4 hops).
        best = Ds[src][MAX_DISTANCE][tgt]
        for cand in (Ds[src][1][u] + Duv + Dt[tgt][2][v],
                     Ds[src][2][u] + Duv + Dt[tgt][1][v]):
            if cand < best:
                best = cand
        if v == tgt and Ds[src][3][u] + Duv < best:
            best = Ds[src][3][u] + Duv
        if u == src and Duv + Dt[tgt][3][v] < best:
            best = Duv + Dt[tgt][3][v]
        return best

    # "singles": only count t when u->v beats its best source overall.
    for tgt in T:
        with_uv = 1000     # best over sources if u->v were added
        without_uv = 1000  # best over sources in the current graph
        for src in S:
            without_uv = min(without_uv, Ds[src][MAX_DISTANCE][tgt])
            d = via_uv(src, tgt)
            if d < with_uv:
                with_uv = d
        if with_uv < without_uv:
            count += 1
    return count


def compute_usage_uv(Ds,Dt,u,v,Duv,S,T):
    """ Usage of candidate edge u->v (weight Duv), i.e. the number of
        hop-restricted shortest paths it would improve. Used by "betweenness".
        Considers routes through u->v of at most 5 hops in total.

        shortcuts-x: number of (s, t) pairs whose Ds[s][MAX_DISTANCE][t] is
        strictly beaten by some route through u->v.
        otherwise ("singles"): number of targets t whose best distance over
        all sources strictly improves with u->v (both sides start from a
        sentinel of 1000).
    """
    count = 0

    if OBJ_FUNCTION == "shortcuts-x":
        for src in S:
            for tgt in T:
                base = Ds[src][MAX_DISTANCE][tgt]
                if (Ds[src][1][u] + Duv + Dt[tgt][3][v] < base or     # s ~1~> u -> v ~3~> t
                        Ds[src][2][u] + Duv + Dt[tgt][2][v] < base or  # s ~2~> u -> v ~2~> t
                        Ds[src][3][u] + Duv + Dt[tgt][1][v] < base):   # s ~3~> u -> v ~1~> t
                    count += 1
                elif v == tgt:
                    # NOTE(review): because this is an elif, the u == src
                    # route below is never tried when v == tgt. Kept as-is.
                    if Ds[src][4][u] + Duv < base:                     # s ~4~> u -> v == t
                        count += 1
                elif u == src:
                    if Duv + Dt[tgt][4][v] < base:                     # s == u -> v ~4~> t
                        count += 1
        return count

    def via_uv(src, tgt):
        # Best src->tgt distance when u->v may be used once (<= 5 hops).
        best = Ds[src][MAX_DISTANCE][tgt]
        for cand in (Ds[src][1][u] + Duv + Dt[tgt][3][v],
                     Ds[src][2][u] + Duv + Dt[tgt][2][v],
                     Ds[src][3][u] + Duv + Dt[tgt][1][v]):
            if cand < best:
                best = cand
        if v == tgt and Ds[src][4][u] + Duv < best:
            best = Ds[src][4][u] + Duv
        if u == src and Duv + Dt[tgt][4][v] < best:
            best = Duv + Dt[tgt][4][v]
        return best

    # "singles": only count t when u->v beats its best source overall.
    for tgt in T:
        with_uv = 1000     # best over sources if u->v were added
        without_uv = 1000  # best over sources in the current graph
        for src in S:
            without_uv = min(without_uv, Ds[src][MAX_DISTANCE][tgt])
            d = via_uv(src, tgt)
            if d < with_uv:
                with_uv = d
        if with_uv < without_uv:
            count += 1
    return count


def bellman_ford(G,S,T):
    """
        Bellman-Ford algorithm to compute hop-restricted shortest path distances.
        Returns two dictionaries, with r ranging over 0..MAX_DISTANCE:
            Ds[s][r][u] is the distance from source s to node u using <= r hops.
            Dt[t][r][u] is the distance from node u to target t using <= r hops
                (computed as the distance from t to u in the reversed graph).
        Edge lengths are the edges' 'weight' attribute. Nodes not reachable
        within r hops get UNREACHABLE_PENALTY.
    """
    r = MAX_DISTANCE

    def bf_helper(H,s):
        """ Computes hop-restricted BF distance from s to every node in H. """
        Dist = {} # Dist[h][u] = distance from s to u using <= h hops.

        for h in xrange(0,r+1): Dist[h] = {}

        for i in H: Dist[0][i] = UNREACHABLE_PENALTY
        Dist[0][s] = 0

        for h in xrange(1,r+1):
            prev = Dist[h-1]
            curr = Dist[h]

            for j in H: curr[j] = prev[j]

            # Relax every edge once, reading only the (h-1)-hop layer so each
            # round extends paths by at most one hop.
            for i,j in H.edges_iter():
                cand = prev[i] + H[i][j].get('weight')
                if cand < curr[j]:
                    curr[j] = cand

        return Dist

    Ds = {}
    Dt = {}
    for s in S: Ds[s] = bf_helper(G,s)

    # Distances *to* a target are distances *from* it in the reversed graph.
    # G.reverse() builds a copy of the whole graph, so do it once rather than
    # once per target.
    Grev = G.reverse()
    for t in T: Dt[t] = bf_helper(Grev,t)

    # Ensure forward lengths from s->t are equal to backward lengths from t->s
    for s in S:
        for t in T:
            for x in xrange(1,MAX_DISTANCE+1):
                assert abs(Ds[s][x][t]-Dt[t][x][s]) < 0.0001

    return (Ds,Dt)


def header(algorithm,k,best_cost):
    """ Print the '#'-prefixed run-configuration header to stdout.

        Emits the objective function, algorithm, k, whether potential edges
        are used (plus the default edge weight when they are not), the
        initial cost, and finally the column header for the result rows.
    """
    print("#OBJ FUNCTION=%s" % (OBJ_FUNCTION,))
    print("#ALGORITHM=%s" % (algorithm,))
    print("#K=%i" % (k,))

    potential_flag = 1 if USE_POTENTIAL else 0
    print("#POTENTIAL=%i" % (potential_flag,))
    if not USE_POTENTIAL:
        # The default weight only applies when potential edges are not used.
        print("#DEFAULT EDGE WEIGHT=%.3f" % (DEFAULT_EDGE_WEIGHT,))

    print("#INITIAL COST=%.5f" % (best_cost,))
    print("## SRC TGT TIES DUV COST")


#==============================================================================
#                                 ALGORITHMS
#==============================================================================
def run(G,S,T,k,ALG):
    """ Greedy, Betweenness, and Direct-ST algorithms to predict k edges. """
    if USE_POTENTIAL:
        assert ALG == "Greedy"
        Gp = util.read_potential_edges(G)

    i = 1
    while i <= k:
        logging.info("Computing Bellman-Ford distances...")
        Ds,Dt = bellman_ford(G,S,T)

        # Get initial distance from sources to targets.
        best_cost = compute_cost(Ds,S,T)
        logging.info("  cost=%.3f" %(best_cost))
        if not (best_cost > 0):
            logging.info("Cost is %.3f. Exiting early..." %(best_cost))
            break

        if i == 1: header(ALG,k,best_cost)
        if USE_POTENTIAL: logging.info("Trying each potential edge...")
        else: logging.info("Trying each possible edge...")

        best_pairs = []
        if ALG == "Betweenness": best_usage = 0
        for u in G:
            if ALG == "Direct-ST" and u not in S: continue
            for v in G:
                if u == v: continue
                if G.has_edge(u,v): continue
                if ALG == "Direct-ST" and v not in T: continue

                if USE_POTENTIAL:
                    if not Gp.has_edge(u,v): continue
                    Duv = Gp[u][v].get('weight')
                    #Duv = 0 # TODO: reviewers.
                else:
                    Duv = DEFAULT_EDGE_WEIGHT

                assert 0 <= Duv and Duv <= 1



                # Compute cost if u->v exists.
                if ALG == "Greedy" or ALG == "Direct-ST":
                    new_cost = compute_cost_uv(Ds,Dt,u,v,Duv,S,T)
                    if abs(new_cost-best_cost) < 0.0001:
                        best_pairs.append((u,v,Duv))
                    elif new_cost < best_cost:
                        best_cost = new_cost
                        best_pairs = [(u,v,Duv)]
                elif ALG == "Betweenness":
                    new_usage = compute_usage_uv(Ds,Dt,u,v,Duv,S,T)

                    if new_usage > best_usage:
                        best_usage = new_usage
                        best_pairs = [(u,v,Duv)]
                    elif new_usage == best_usage:
                        best_pairs.append((u,v,Duv))
                else: assert False


        logging.info("  %i. # of ties: %i" %(i,len(best_pairs)))
        best_u,best_v,Duv = random.choice(best_pairs)
        assert not G.has_edge(best_u,best_v)

        if ALG == "Greedy" or ALG == "Direct-ST":
            print "%i\t%s%s%s\t%s%s%s\t%i\t%.3f\t%.3f" %(i,best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),len(best_pairs),Duv,best_cost)
        elif ALG == "Betweenness":
            best_cost = compute_cost_uv(Ds,Dt,best_u,best_v,Duv,S,T)
            print "%i\t%s%s%s\t%s%s%s\t%i\t%.3f\t%.3f\t%i" %(i,best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),len(best_pairs),Duv,best_cost,best_usage)
        else: assert False

        logging.info("  Adding: %s%s%s->%s%s%s: %.3f; cost=%.3f" %(best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),Duv,best_cost))

        G.add_edge(best_u,best_v,weight=Duv)
        if USE_POTENTIAL: Gp.remove_edge(best_u,best_v) # remove from potential edges.
        i += 1
        Ds.clear()
        Dt.clear()



def run_global(G,S,T,k,ALG):
    """ Shorest-path and Jaccard algorithms to predict k edges.
        Global solutions; does not use S nor T.
    """
    assert not USE_POTENTIAL

    i = 1
    while i <= k:
        logging.info("Computing Bellman-Ford distances...")
        Ds,Dt = bellman_ford(G,S,T)

        if ALG == "Shortest-Paths":
            logging.info("Computing all-pairs shortest path lengths...")
            Dist = KS.apsp(G)

        curr_cost = compute_cost(Ds,S,T)
        logging.info("  cost=%.3f" %(curr_cost))
        if not (curr_cost > 0):
            logging.info("Cost is %.3f. Exiting early..." %(curr_cost))
            break

        if i == 1: header(ALG,k,curr_cost)
        if USE_POTENTIAL: logging.info("Trying each potential edge...")
        else: logging.info("Trying each possible edge...")

        best_pairs = []
        if ALG == "Shortest-Paths":  best_cost = 10 # dist between closest proteins
        elif ALG == "Weighted-Jaccard": best_cost = -1 # neighborhood similarity

        for u in G:
            for v in G:
                if u == v: continue
                if G.has_edge(u,v): continue

                Duv = DEFAULT_EDGE_WEIGHT
                assert 0 <= Duv and Duv <= 1

                if ALG == "Shortest-Paths":
                    new_cost = Dist[u][v]
                    if abs(new_cost-best_cost) < 0.00001:
                        best_pairs.append((u,v,Duv))
                    elif new_cost < best_cost:
                        best_cost = new_cost
                        best_pairs = [(u,v,Duv)]

                elif ALG == "Weighted-Jaccard":
                    # Weighted Jaccard(u,v): sum shared weights*# common neighbors /
                    #   sum of all weights for all neighbors from both u and v.
                    #   Transform w(u,v) to 1-w(u,v) so higher is better
                    Nv = set(G.neighbors(v))
                    #Wu= sum([1-G[u][x].get('weight') for x in G.neighbors(u) if x!=v])
                    #Wv= sum([1-G[v][x].get('weight') for x in G.neighbors(v) if x!=u])
                    Wu= sum([1-G[u][x].get('weight') for x in G.neighbors(u)])
                    Wv= sum([1-G[v][x].get('weight') for x in G.neighbors(v)])

                    if Wu+Wv==0: # no outgoing edges from these nodes.
                        new_cost = -1
                    else:
                        num_common = 0
                        Wuv = 0
                        for x in G.neighbors(u):
                            #if x in Nv and x!=u and x!=v:
                            if x in Nv:
                                Wuv += 1-G[u][x].get('weight')+1-G[v][x].get('weight')
                                num_common += 1

                        new_cost = (Wuv*num_common)/(Wu+Wv)

                        if abs(new_cost-best_cost) < 0.00001:
                            best_pairs.append((u,v,Duv))
                        elif new_cost > best_cost:
                            best_cost = new_cost
                            best_pairs = [(u,v,Duv)]

                else: assert False


        logging.info("  %i. # of ties: %i" %(i,len(best_pairs)))
        best_u,best_v,Duv = random.choice(best_pairs)
        assert not G.has_edge(best_u,best_v)

        G.add_edge(best_u,best_v,weight=Duv)
        if ALG == "Shortest-Paths": Dist.clear()

        curr_cost = compute_cost_uv(Ds,Dt,best_u,best_v,Duv,S,T)
        logging.info("  Adding: %s%s%s->%s%s%s: %.3f; cost=%.3f" %(best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),Duv,curr_cost))
        #print i,best_u,best_v,len(best_pairs),Duv,curr_cost,best_cost
        print "%i\t%s%s%s\t%s%s%s\t%i\t%.3f\t%.3f\t%.3f" %(i,best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),len(best_pairs),Duv,curr_cost,best_cost)

        i += 1
        del best_pairs # many ties, so this list gets huge.



#==============================================================================
#                                      MAIN
#==============================================================================
def main():
    start = time.time()
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(levelname)s: %(asctime)s -- %(message)s'
    )

    usage="usage: %prog [options]"
    parser = OptionParser(usage=usage)
    parser.add_option("-a", "--algorithm", action="store", type="string", dest="algorithm", default="greedy",help="prediction algorithm: greedy [default], betweenness, directst, sp, jaccard.")
    parser.add_option("-p", "--potential", action="store_true", default=False,dest="use_potential",help="use potential edges.")
    parser.add_option("-d", "--dew", action="store", type="string", dest="default_edge_weight", default="-1",help="default edge weight for predicted edges: [0,1]")
    parser.add_option("-k", "--k", action="store", type="int", dest="k", default="50",help="number of edges to predict.")
    parser.add_option("-o", "--objective", action="store", type="string", dest="obj_function", default="shortcuts-x",help="objective function: shortcuts-x [default] or shortcuts-x-ss.")

    (options, args) = parser.parse_args()
    # ===============================================================


    # ===============================================================
    algorithm = options.algorithm
    k = options.k
    global USE_POTENTIAL,DEFAULT_EDGE_WEIGHT,OBJ_FUNCTION
    USE_POTENTIAL = options.use_potential
    DEFAULT_EDGE_WEIGHT = float(options.default_edge_weight)
    OBJ_FUNCTION = options.obj_function

    net_file = "../data/protein.actions.v9.0.txt_YEAST_BINDING_FORMATTED_ORIENTED_3x_11targets"
    src_file = "../data/hogSources_5.txt"
    tgt_file = "../data/hogTargets_11.txt"

    logging.info("Reading input data...")
    G = util.read_network(net_file)
    S = util.read_nodes(src_file)
    T = util.read_nodes(tgt_file)

    for s in S: assert s in G
    for t in T: assert t in G

    logging.info("Oriented network: %s" %(net_file))
    logging.info("  #components=%i" %(nx.number_strongly_connected_components(G)))
    logging.info("  #nodes=%i, #edges=%i" %(G.order(),G.size()))
    logging.info("  #srcs=%i, #tgts=%i" %(len(S),len(T)))


    # ===============================================================
    # Compute stats for unreachable s-t pairs in the default network.
    num_unreachable = 0
    for s in S:
        for t in T:
            try:
                d = nx.dijkstra_path_length(G,s,t)
            except nx.exception.NetworkXNoPath:
                num_unreachable += 1
    logging.info("  #unreachable s-t pairs=%i" %(num_unreachable))

    if not (OBJ_FUNCTION == "shortcuts-x" or OBJ_FUNCTION == "shortcuts-x-ss"):
        logging.critical("INVALID OBJ FUNCTION: %s" %(OBJ_FUNCTION))
        sys.exit(1)

    logging.info("Obj function: %s" %(OBJ_FUNCTION))
    logging.info("Algorithm: %s" %(algorithm))
    logging.info("Adding k=%i edges" %(k))

    if USE_POTENTIAL:
        assert DEFAULT_EDGE_WEIGHT == -1
    else:
        assert 0<= DEFAULT_EDGE_WEIGHT <= 1
        logging.info("Default edge weight: %.3f" %(DEFAULT_EDGE_WEIGHT))

    if algorithm == "greedy": run(G,S,T,k,"Greedy")
    elif algorithm == "betweenness": run(G,S,T,k,"Betweenness")
    elif algorithm == "directst": run(G,S,T,k,"Direct-ST")
    elif algorithm == "sp": run_global(G,S,T,k, "Shortest-Paths")
    elif algorithm == "jaccard": run_global(G,S,T,k,"Weighted-Jaccard")
    else:
        logging.critical("INVALID ALGORITHM: %s" %(algorithm))
        sys.exit(1)


    # =========================== Finish ============================
    tot_time = (time.time()-start)/60
    logging.info("Time to run: %.3f (mins)" %(tot_time))
    print "#TIME TO RUN: %.3f (mins)" %(tot_time)


if __name__ == "__main__":
    # To profile a run instead: import cProfile; cProfile.run('main()')
    main()
