#!/usr/bin/env python

from __future__ import division
from optparse import OptionParser
import logging,time,random,sys
import networkx as nx
import utilities as util

# Fixed seed so that the random tie-breaking among equally good edges
# (random.choice in run/run_global) is reproducible between runs.
random.seed(10301949)

UNREACHABLE_PENALTY = 100  # Distance apsp() assigns to pairs with no path in the graph.
# The three settings below are placeholders (-1); main() overwrites them from
# the command-line options before any algorithm runs.
DEFAULT_EDGE_WEIGHT = -1   # [0,1]: default weight for edges not in STRING
USE_POTENTIAL = -1         # {True,False}: whether to use potential edges
OBJ_FUNCTION = -1          # {shortcuts,shortcuts-ss}: all s-t pairs or single source


def pps(x, y):
    """ Pretty-print tag: "(s)" when node x is in the source collection y, else "". """
    if x in y:
        return "(s)"
    return ""


def ppt(x, z):
    """ Pretty-print tag: "(t)" when node x is in the target collection z, else "". """
    if x in z:
        return "(t)"
    return ""


#==============================================================================
#                                 HELPERS
#==============================================================================
def apsp(G):
    """ Computes APSP for nodes in G.

        Returns double dictionary: u -> v -> length(u,v). (E.g. length[1][4])
        If a pair is unreachable, set its distance to UNREACHABLE_PENALTY.
        Every ordered pair (u,v), including u == v, has an entry on return.
    """
    # networkx 1.x returns a dict here, networkx 2.x+ an iterator of
    # (source, {target: length}) pairs; dict() accepts both forms.
    Dist = dict(nx.all_pairs_dijkstra_path_length(G))

    # Set default distance for unreachable pairs. This is not strictly necessary,
    # but makes the code simpler below. Essentially, the cost of taking a link
    # between an unreachable pair is so high that it will never be used.
    # Of course, in the algorithm, an edge can be added between this pair.
    num_reachable = num_unreachable = 0
    for u in Dist:
        row = Dist[u]
        for v in G:
            if v in row:
                num_reachable += 1
            else:
                row[v] = UNREACHABLE_PENALTY
                num_unreachable += 1

    # Ensure all n*n ordered pairs (both directions + self pairs) are covered.
    n = len(G)
    assert len(Dist) == G.order()
    assert sum(len(Dist[u]) for u in Dist) == n*n
    assert num_unreachable + num_reachable == n*n

    return Dist


def compute_cost(Dist,S,T):
    """ Evaluates the objective function on a precomputed distance table.

        "shortcuts": sum of Dist[s][t] over every source/target pair.
        Any other OBJ_FUNCTION value (single source): for each target, the
        distance from its nearest source, summed over all targets.
    """
    if OBJ_FUNCTION != "shortcuts":
        total = 0
        for t in T:
            total += min(Dist[s][t] for s in S)
        return total
    return sum(Dist[s][t] for s in S for t in T)


def compute_cost_g(G,S,T):
    """ Computes the cost of the objective function from G.

        Same objective as compute_cost(), but each s->t distance is obtained by
        running Dijkstra on G instead of reading a precomputed table.
        Unreachable pairs are charged UNREACHABLE_PENALTY, the same convention
        apsp() uses, instead of letting NetworkXNoPath propagate.
    """
    def dist(s, t):
        # Shortest-path length s->t, or the penalty when no path exists.
        try:
            return nx.dijkstra_path_length(G,s,t)
        except nx.exception.NetworkXNoPath:
            return UNREACHABLE_PENALTY

    if OBJ_FUNCTION == "shortcuts":
        return sum(dist(s,t) for s in S for t in T)
    else:
        return sum(min(dist(s,t) for s in S) for t in T)


def compute_cost_uv(Dist,u,v,Duv,S,T):
    """ Computes the cost of the objective function with edge u->v.

        Same objective as compute_cost(), evaluated as if an edge u->v of
        length Duv were added to the graph: every s->t distance becomes
        min(Dist[s][u] + Duv + Dist[v][t], Dist[s][t]). Dist is not modified.
    """
    if OBJ_FUNCTION == "shortcuts":
        return sum(min(Dist[s][u]+Duv+Dist[v][t],Dist[s][t]) for s in S for t in T)
    else:
        return sum(min(min(Dist[s][u]+Duv+Dist[v][t],Dist[s][t]) for s in S) for t in T)


def compute_usage_uv(Dist,u,v,Duv,S,T):
    """ Computes the usage of the edge u->v, defined as the number of shortest paths
        in which u->v is beneficial. Used by "betweenness".

        "shortcuts": number of (s,t) pairs whose distance strictly drops when
        the path s -> u -> v -> t (with u->v of length Duv) becomes available.
        Otherwise (single source): number of targets t whose distance from the
        nearest source strictly drops.
    """
    if OBJ_FUNCTION == "shortcuts":
        return sum(1 for s in S for t in T
                   if Dist[s][u]+Duv+Dist[v][t] < Dist[s][t])

    # For each target, compare nearest-source distance before and after u->v.
    new_usage = 0
    for t in T:
        # Start at infinity rather than a fixed sentinel so that no real
        # distance can exceed the starting value.
        lowest_s = float('inf')     # nearest-source distance with u->v added
        curr_lowest = float('inf')  # nearest-source distance without any new edge
        for s in S:
            curr_lowest = min(curr_lowest, Dist[s][t])
            lowest_s = min(lowest_s, Dist[s][u]+Duv+Dist[v][t], Dist[s][t])

        if lowest_s < curr_lowest:
            new_usage += 1

    return new_usage


def header(algorithm,k,best_cost):
    """ Writes the '#'-prefixed header lines of the output file to stdout.

        Records the objective function, algorithm, k, whether potential edges
        are used (and, if not, the default weight given to new edges), the
        initial cost, and the column legend for the per-edge lines that follow.
    """
    print("#OBJ FUNCTION=%s" % OBJ_FUNCTION)
    print("#ALGORITHM=%s" % algorithm)
    print("#K=%i" % k)
    if not USE_POTENTIAL:
        print("#POTENTIAL=0")
        print("#DEFAULT EDGE WEIGHT=%.3f" % DEFAULT_EDGE_WEIGHT)
    else:
        print("#POTENTIAL=1")
    print("#INITIAL COST=%.5f" % best_cost)
    print("## SRC TGT TIES DUV COST")

#==============================================================================
#                                 ALGORITHMS
#==============================================================================
def run(G,S,T,k,ALG):
    """ Greedy, Betweenness, and Direct-ST algorithms to predict k edges.

        Adds up to k edges to G in place, one per round. Each round recomputes
        all-pairs shortest paths, scores every candidate edge u->v that is not
        already in G, picks one of the best-scoring ties uniformly at random,
        prints a tab-separated result line to stdout, and adds that edge to G.

        ALG:
          "Greedy"      -- minimize compute_cost_uv() over all candidate edges.
          "Direct-ST"   -- as Greedy, but only source->target edges.
          "Betweenness" -- maximize compute_usage_uv() over all candidate edges.

        Candidates get weight DEFAULT_EDGE_WEIGHT. When USE_POTENTIAL is set
        (Greedy only), only edges present in the potential-edge graph are
        candidates, and each uses that potential edge's weight.

        Stops early when the cost is no longer positive or no candidate edge
        remains.
    """
    if USE_POTENTIAL:
        assert ALG == "Greedy"
        Gp = util.read_potential_edges(G)

    # Built once for O(1) membership tests in the Direct-ST candidate filter.
    src_set = set(S)
    tgt_set = set(T)

    i = 1
    while i <= k:
        logging.info("Computing all-pairs shortest path lengths...")
        Dist = apsp(G)

        # Get initial distance from sources to targets.
        best_cost = compute_cost(Dist,S,T)
        logging.info("  cost=%.3f" %(best_cost))
        if not (best_cost > 0):
            logging.info("Cost is %.3f. Exiting early..." %(best_cost))
            break

        if i == 1: header(ALG,k,best_cost)
        if USE_POTENTIAL: logging.info("Trying each potential edge...")
        else: logging.info("Trying each possible edge...")

        best_pairs = []
        if ALG == "Betweenness": best_usage = 0
        for u in G:
            if ALG == "Direct-ST" and u not in src_set: continue
            for v in G:
                if u == v: continue
                if G.has_edge(u,v): continue
                if ALG == "Direct-ST" and v not in tgt_set: continue

                if USE_POTENTIAL:
                    if not Gp.has_edge(u,v): continue
                    Duv = Gp[u][v].get('weight')
                    #Duv = 0.0 # TODO: reviewer, use potential, but not its weight.
                else:
                    Duv = DEFAULT_EDGE_WEIGHT

                assert 0 <= Duv and Duv <= 1

                # Compute cost if u->v exists.
                if ALG == "Greedy" or ALG == "Direct-ST":
                    # best_cost starts at the current cost; costs within 1e-5
                    # of the best so far are recorded as ties.
                    new_cost = compute_cost_uv(Dist,u,v,Duv,S,T)
                    if abs(new_cost-best_cost) < 0.00001:
                        best_pairs.append((u,v,Duv))
                    elif new_cost < best_cost:
                        best_cost = new_cost
                        best_pairs = [(u,v,Duv)]

                elif ALG == "Betweenness":
                    new_usage = compute_usage_uv(Dist,u,v,Duv,S,T)

                    if new_usage > best_usage:
                        best_usage = new_usage
                        best_pairs = [(u,v,Duv)]
                    elif new_usage == best_usage:
                        best_pairs.append((u,v,Duv))
                else: assert False

        logging.info("  %i. # of ties: %i" %(i,len(best_pairs)))
        if not best_pairs:
            # No candidate edge is left (every allowed pair is already in G, or
            # every potential edge has been used); random.choice([]) would raise.
            logging.info("No candidate edges left. Exiting early...")
            break
        best_u,best_v,Duv = random.choice(best_pairs)
        assert not G.has_edge(best_u,best_v)

        if ALG == "Greedy" or ALG == "Direct-ST":
            print("%i\t%s%s%s\t%s%s%s\t%i\t%.3f\t%.3f" %(i,best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),len(best_pairs),Duv,best_cost))
        elif ALG == "Betweenness":
            # Betweenness ranks by usage, so compute the resulting cost for reporting.
            best_cost = compute_cost_uv(Dist,best_u,best_v,Duv,S,T)
            print("%i\t%s%s%s\t%s%s%s\t%i\t%.3f\t%.3f\t%i" %(i,best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),len(best_pairs),Duv,best_cost,best_usage))
        else: assert False

        logging.info("  Adding: %s%s%s->%s%s%s: %.3f; cost=%.3f" %(best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),Duv,best_cost))

        G.add_edge(best_u,best_v,weight=Duv)
        if USE_POTENTIAL: Gp.remove_edge(best_u,best_v) # remove from potential edges
        i += 1
        Dist.clear()



def run_global(G,S,T,k,ALG):
    """ Shortest-path and Jaccard algorithms to predict k edges.
        Global solutions: candidate edges are scored without using S or T;
        S and T are only used to report the objective cost.

        Adds up to k edges to G in place, one per round, choosing uniformly at
        random among the best-scoring non-adjacent pairs u->v:
          "Shortest-Paths"   -- smallest current distance Dist[u][v].
          "Weighted-Jaccard" -- largest weighted neighborhood similarity.
        New edges get weight DEFAULT_EDGE_WEIGHT. Prints one tab-separated line
        per added edge to stdout. Stops early when the cost is no longer
        positive or no candidate pair remains.
    """
    assert not USE_POTENTIAL
    assert ALG in ("Shortest-Paths", "Weighted-Jaccard")

    i = 1
    while i <= k:
        if ALG == "Shortest-Paths":
            logging.info("Computing all-pairs shortest path lengths...")
            Dist = apsp(G)
            curr_cost = compute_cost(Dist,S,T)
        else:
            curr_cost = compute_cost_g(G,S,T)

        logging.info("  cost=%.3f" %(curr_cost))
        if not (curr_cost > 0):
            logging.info("Cost is %.3f. Exiting early..." %(curr_cost))
            break

        if i == 1: header(ALG,k,curr_cost)
        logging.info("Trying each possible edge...")

        Duv = DEFAULT_EDGE_WEIGHT
        assert 0 <= Duv and Duv <= 1

        best_pairs = []
        if ALG == "Shortest-Paths":
            # Smallest distance wins; start above every possible distance.
            best_cost = float('inf')
        else:
            # Largest similarity wins; -1 is the starting sentinel (scores are
            # >= 0 when edge weights lie in [0,1]).
            best_cost = -1
            # Per-node data reused by every pair this round: neighbor list,
            # neighbor set, and the total transformed weight sum(1-w).
            nbrs = {x: list(G.neighbors(x)) for x in G}
            nbr_set = {x: set(nbrs[x]) for x in G}
            wsum = {x: sum([1-G[x][y].get('weight') for y in nbrs[x]]) for x in G}

        for u in G:
            for v in G:
                if u == v: continue
                if G.has_edge(u,v): continue

                if ALG == "Shortest-Paths":
                    new_cost = Dist[u][v]
                    if abs(new_cost-best_cost) < 0.00001:
                        best_pairs.append((u,v,Duv))
                    elif new_cost < best_cost:
                        best_cost = new_cost
                        best_pairs = [(u,v,Duv)]
                    continue

                # Weighted Jaccard(u,v), with each weight w transformed to 1-w
                # so that higher is better:
                #   Wuv * #common / (Wu + Wv)
                # where Wuv sums (1-w(u,x)) + (1-w(v,x)) over common neighbors
                # x, and Wu, Wv sum 1-w over all neighbors of u and of v.
                Wu = wsum[u]
                Wv = wsum[v]
                if Wu+Wv == 0: # nothing to normalize by (e.g. no outgoing edges)
                    continue

                Nv = nbr_set[v]
                num_common = 0
                Wuv = 0
                for x in nbrs[u]:
                    if x in Nv:
                        Wuv += 1-G[u][x].get('weight')+1-G[v][x].get('weight')
                        num_common += 1
                new_cost = (Wuv*num_common)/(Wu+Wv)

                if abs(new_cost-best_cost) < 0.00001:
                    best_pairs.append((u,v,Duv))
                elif new_cost > best_cost:
                    best_cost = new_cost
                    best_pairs = [(u,v,Duv)]

        logging.info("  %i. # of ties: %i" %(i,len(best_pairs)))
        if not best_pairs:
            # No scorable non-adjacent pair is left; random.choice([]) would raise.
            logging.info("No candidate edges left. Exiting early...")
            break
        best_u,best_v,Duv = random.choice(best_pairs)
        assert not G.has_edge(best_u,best_v)

        G.add_edge(best_u,best_v,weight=Duv)
        if ALG == "Shortest-Paths":
            # Dist predates the new edge; compute_cost_uv accounts for it.
            curr_cost = compute_cost_uv(Dist,best_u,best_v,Duv,S,T)
            Dist.clear()
        else:
            curr_cost = compute_cost_g(G,S,T)

        logging.info("  Adding: %s%s%s->%s%s%s: %.3f; cost=%.3f" %(best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),Duv,curr_cost))
        print("%i\t%s%s%s\t%s%s%s\t%i\t%.3f\t%.3f\t%.3f" %(i,best_u,pps(best_u,S),ppt(best_u,T),best_v,pps(best_v,S),ppt(best_v,T),len(best_pairs),Duv,curr_cost,best_cost))

        i += 1
        del best_pairs # many ties, so this list gets huge.


#==============================================================================
#                                      MAIN
#==============================================================================
def main():
    """ Command-line entry point.

        Parses options, validates them, loads the oriented network and the
        source/target node lists, logs basic statistics, and runs the selected
        edge-prediction algorithm. Results go to stdout; progress to the log.
        Exits with status 1 on invalid options or inputs.
    """
    # Module-level settings read by the helper and algorithm functions.
    global USE_POTENTIAL,DEFAULT_EDGE_WEIGHT,OBJ_FUNCTION

    start = time.time()
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(levelname)s: %(asctime)s -- %(message)s'
    )

    usage="usage: %prog [options]"
    parser = OptionParser(usage=usage)
    parser.add_option("-a", "--algorithm", action="store", type="string", dest="algorithm", default="greedy",help="prediction algorithm: greedy [default], betweenness, sp, jaccard or directst.")
    parser.add_option("-p", "--potential", action="store_true", default=False,dest="use_potential",help="use potential edges.")
    parser.add_option("-d", "--dew", action="store", type="string", dest="default_edge_weight", default="-1",help="default edge weight for predicted edges: [0,1]")
    parser.add_option("-k", "--k", action="store", type="int", dest="k", default=50,help="number of edges to predict.")
    parser.add_option("-o", "--objective", action="store", type="string", dest="obj_function", default="shortcuts",help="objective function: shortcuts [default] or shortcuts-ss.")
    parser.add_option("--network", action="store", type="string", dest="net_file", default="../data/protein.actions.v9.0.txt_YEAST_BINDING_FORMATTED_ORIENTED_3x_11targets",help="oriented network file [default: %default].")
    parser.add_option("--sources", action="store", type="string", dest="src_file", default="../data/hogSources_5.txt",help="source nodes file [default: %default].")
    parser.add_option("--targets", action="store", type="string", dest="tgt_file", default="../data/hogTargets_11.txt",help="target nodes file [default: %default].")

    (options, args) = parser.parse_args()

    # ===============================================================
    algorithm = options.algorithm
    k = options.k
    USE_POTENTIAL = options.use_potential
    DEFAULT_EDGE_WEIGHT = float(options.default_edge_weight)
    OBJ_FUNCTION = options.obj_function
    net_file = options.net_file
    src_file = options.src_file
    tgt_file = options.tgt_file

    # Validate every option before the (slow) data loading.
    algorithms = {
        "greedy": (run, "Greedy"),
        "betweenness": (run, "Betweenness"),
        "directst": (run, "Direct-ST"),
        "sp": (run_global, "Shortest-Paths"),
        "jaccard": (run_global, "Weighted-Jaccard"),
    }
    if algorithm not in algorithms:
        logging.critical("INVALID ALGORITHM: %s" %(algorithm))
        sys.exit(1)
    if not (OBJ_FUNCTION == "shortcuts" or OBJ_FUNCTION == "shortcuts-ss"):
        logging.critical("INVALID OBJ FUNCTION: %s" %(OBJ_FUNCTION))
        sys.exit(1)
    if USE_POTENTIAL:
        # Potential edges carry their own weights and only run() Greedy uses them.
        if DEFAULT_EDGE_WEIGHT != -1:
            logging.critical("--dew cannot be combined with --potential")
            sys.exit(1)
        if algorithm != "greedy":
            logging.critical("--potential is only supported by the greedy algorithm")
            sys.exit(1)
    elif not (0 <= DEFAULT_EDGE_WEIGHT <= 1):
        logging.critical("INVALID DEFAULT EDGE WEIGHT: %s (must be in [0,1])" %(options.default_edge_weight))
        sys.exit(1)

    logging.info("Reading input data...")
    G = util.read_network(net_file)
    S = util.read_nodes(src_file)
    T = util.read_nodes(tgt_file)

    missing = [x for x in list(S) + list(T) if x not in G]
    if missing:
        logging.critical("Source/target nodes not in network: %s" %(", ".join(str(x) for x in missing)))
        sys.exit(1)

    logging.info("Oriented network: %s" %(net_file))
    logging.info("  #components=%i" %(nx.number_strongly_connected_components(G)))
    logging.info("  #nodes=%i, #edges=%i" %(G.order(),G.size()))
    logging.info("  #srcs=%i, #tgts=%i" %(len(S),len(T)))

    # ===============================================================
    # Compute stats for unreachable s-t pairs in the default network:
    # one single-source search per source instead of one per s-t pair.
    num_unreachable = 0
    for s in S:
        reachable = nx.single_source_dijkstra_path_length(G,s)
        num_unreachable += sum(1 for t in T if t not in reachable)
    logging.info("  #unreachable s-t pairs=%i" %(num_unreachable))

    logging.info("Obj function: %s" %(OBJ_FUNCTION))
    logging.info("Algorithm: %s" %(algorithm))
    logging.info("Adding k=%i edges" %(k))
    if not USE_POTENTIAL:
        logging.info("Default edge weight: %.3f" %(DEFAULT_EDGE_WEIGHT))

    run_fn, alg_name = algorithms[algorithm]
    run_fn(G,S,T,k,alg_name)

    # =========================== Finish ============================
    tot_time = (time.time()-start)/60
    logging.info("Time to run: %.3f (mins)" %(tot_time))
    print("#TIME TO RUN: %.3f (mins)" %(tot_time))


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
