# coding=utf-8
import argparse

import numpy as np
import pandas as pd

pd.set_option('display.max_rows', None)

import sys

import tqdm
import pickle
import os
import random
import copy
import json
import torch
import torch.autograd as autograd
import torch.optim as optim
from torch.distributions import Categorical, Normal, MixtureSameFamily
from torch.distributions.uniform import Uniform
import matplotlib.pyplot as plt
import multiprocessing as mp

from computeSimilarity import cartesian_to_polar_matrix, \
    polar_to_cartesian_matrix, cos_simc_matrix, euclidean_simc_matrix


# Seed every RNG the script draws from so that runs are reproducible:
# numpy (initial angles, negative sampling), torch (SVGD batch permutations)
# and Python's `random` (mini-batch and dropout sampling via random.sample).
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)



def load_height_file(file):
    r'''
    Load per-node hierarchy statistics and convert them into radii in [1, 2].

    ``file`` is a path (or file-like object) accepted by ``pd.read_csv``: a
    tab-separated table without a header whose columns are
    id, height, descendants, depth, ancestors.

    An example of graphs and values.

    [Graph]

           0
        ___|           5
       /   |           |
      1    2           6
      |    |           |
      3----4           7
           |
           10
         __|__
        /     \
       8      9


    [Values]

    | node | height | descendants | depth | ancestors |
    |:----:|:------:|:-----------:|:-----:|:---------:|
    |   0  |    5   |      7      |   0   |     0     |
    |   1  |    4   |      5      |   1   |     1     |
    |   2  |    3   |      4      |   1   |     1     |
    |   3  |    3   |      4      |   2   |     2     |
    |   4  |    2   |      3      |   3   |     4     |
    |   5  |    2   |      2      |   0   |     0     |
    |   6  |    1   |      1      |   1   |     1     |
    |   7  |    0   |      0      |   2   |     2     |
    |   8  |    0   |      0      |   5   |     6     |
    |   9  |    0   |      0      |   5   |     6     |
    |  10  |    1   |      2      |   4   |     5     |

    Returns:
        np.array with one radius per row, in file order.
    '''
    columns = ['id', 'height', 'descendants', 'depth', 'ancestors']
    df = pd.read_csv(file, delimiter='\t', names=columns)

    def radius1():
        ''' r = 2 - minmax_scale(h + log(|d| + 1))

            h is height and |d| is number of descendants.
            The value range is [1, 2]: the highest-scoring node gets r = 1
            (the +1 offset avoids a zero norm) and the lowest gets r = 2.
            If every node has the same score, all radii are 2.

        '''
        scores = df.height + np.log1p(df.descendants)

        mx = scores.max()
        mn = scores.min()
        span = mx - mn

        # A constant score column would make the min-max scaling 0/0 (NaN radii);
        # in that case (scores - mn) is already all zeros.
        if span > 0:
            scaled_scores = (scores - mn) / span
        else:
            scaled_scores = scores - mn

        return np.array(2 - scaled_scores)

    return radius1()


class RBF(torch.nn.Module):
    """Gaussian (RBF) kernel over a set of scalar particles.

    The bandwidth is always derived from the data with a median heuristic;
    ``svgd_size`` is accepted but ignored, and ``sigma`` is only stored on the
    instance (``forward`` does not read it).
    """

    def __init__(self,  svgd_size, sigma=None):
        super().__init__()
        self.sigma = sigma

    def forward(self, X):
        """Return the (n, n) kernel matrix for the n values in ``X``.

        K[i, j] = exp(-(x_i - x_j)^2 / (h + 1e-8)) with
        h = median(all squared pairwise distances) / (2 * log(n + 1)).
        """
        particles = X.view(-1, 1)
        sq_dists = torch.cdist(particles, particles) ** 2
        n_particles = particles.size(0)
        # Bandwidth is a constant w.r.t. autograd (computed in numpy).
        median_sq = np.median(sq_dists.detach().cpu().numpy())
        bandwidth = median_sq / (2 * np.log(n_particles + 1))
        return torch.exp(-sq_dists / (bandwidth + 1e-8))


class SVGD:
    """Stein Variational Gradient Descent over a single angle dimension.

    The particles are the values of one angle coordinate for every word; they
    are moved toward the density ``P`` using the kernel ``K``.

    Args:
        P: a ``torch.distributions`` object providing ``log_prob``.
        K: callable mapping a 1-D particle tensor to a (batch, batch) kernel matrix.
        size: normaliser applied to every SVGD update.
    """

    def __init__(self, P, K, size):
        self.P = P
        self.K = K
        self.size = size
        # Kept for backward compatibility; nothing in this class reads it.
        self.is_uniform = isinstance(P, Uniform)

    def grad(self, X):
        """Return the SVGD update direction phi for the 1-D particles ``X``.

        phi = (K @ score + repulsive) / size, where score is d log p / dx at
        each particle and repulsive is minus the gradient of sum(K) w.r.t. X.
        """
        X = X.detach().requires_grad_(True)
        log_prob = self.P.log_prob(X)
        K_XX = self.K(X)

        # Driving term: gradient of log p at every particle.
        score_func = autograd.grad(log_prob.sum(), X)[0]
        # Repulsive term keeps particles apart.
        grad_K = -autograd.grad(K_XX.sum(), X)[0]

        return (K_XX.detach().matmul(score_func) + grad_K) / self.size

    def fit(self, angles, d, sampling, probs=None, n_batch=32, n_iter=500, alpha=1.0,
            loss_min=1.0, validation_interval=10, earlyout='validation',
            validation_set=None, validation_earlyout_threashold=0.95):
        """Run SVGD on the particles ``angles`` (one angle dimension).

        Each iteration is a full sweep over all particles in mini-batches of
        ``n_batch``, following one permutation fixed for the whole call.

        angles: np.array of shape (n,), values of angle dimension ``d``
        d: int, index of that dimension (progress bar label, and the column
            overwritten in a copy of the full angle matrix for validation)
        sampling: str, only 'bruteforce' is implemented
        probs, loss_min, validation_interval: unused, kept for interface compatibility
        n_batch: int, batch size
        n_iter: int, maximum number of sweeps
        alpha: float, step size
        earlyout: 'validation' enables early stopping; other values disable it
        validation_set: [angles, rs, vx1s, vx2s, vy, sim_pms] as accepted by
            validation_with_multi_params; required when earlyout == 'validation'
        validation_earlyout_threashold: float, stop once the best F-value drops
            below this fraction of the F-value measured before SVGD

        Returns the updated particles as a numpy array (torch default dtype).
        NOTE(review): on early stop, the returned particles are those of the
        sweep that crossed the threshold, not the last sweep above it.

        Raises NotImplementedError for any sampling other than 'bruteforce'.
        """
        if sampling != 'bruteforce':
            # 'uniform' / 'multinomial' / 'log_prob' were never finished.
            raise NotImplementedError(
                f'SVGD sampling strategy {sampling!r} is not implemented')

        angles_pt = torch.Tensor(angles)
        n = angles_pt.shape[0]
        # SVGD needs the same batch combination, otherwise it cannot be converged.
        permutation = torch.randperm(n)

        use_validation = earlyout == 'validation'
        desc = '[SVGD] Dimension {} (Oirg F:{:.4f}, Earlyout F:{:.4f}, Current F:{:.4f})'

        with tqdm.trange(n_iter, desc=f'[SVGD] Dimension {d}', leave=False) as t:
            if use_validation and n_iter > 0:
                # validation_set[0] is never modified here, so the baseline F is
                # the same for every sweep: compute it once, not once per sweep.
                _, _, fs, _ = validation_with_multi_params(*validation_set)
                original_f = np.max(fs)
                eo_threshold = original_f * validation_earlyout_threashold
                # Working copy of the full angle matrix; only column d is replaced.
                angles_np = validation_set[0].copy()
                t.set_description(desc.format(d, original_f, eo_threshold, original_f))
                t.refresh()

            for _ in t:
                for bid in range(0, n, n_batch):
                    indices = permutation[bid:bid + n_batch]
                    g = self.grad(angles_pt[indices])
                    angles_pt[indices] += alpha * g

                if not use_validation:
                    continue

                angles_np[:, d] = angles_pt.cpu().numpy()
                _, _, cur_f, _ = \
                    validation_with_multi_params(angles_np, *validation_set[1:])
                cur_max_f = np.max(cur_f)
                t.set_description(desc.format(d, original_f, eo_threshold, cur_max_f))
                t.refresh()
                if cur_max_f < eo_threshold:
                    break

        return angles_pt.cpu().numpy()


def svgd_mp_wrapper(args):
    """Pool entry point: unpack one task tuple and run ``svgd.fit`` on it.

    args: (svgd, angles, d, opt, w_prob, v_set). The SVGD hyper-parameters come
    from ``opt``; ``w_prob`` is forwarded as ``probs`` only for the
    'multinomial' sampling strategy. Early-out is always 'validation'.
    Returns whatever ``svgd.fit`` returns.
    """
    svgd, angles, d, opt, w_prob, v_set = args
    sampling = opt.svgd_sampling
    fit_kwargs = dict(
        d=d,
        sampling=sampling,
        probs=w_prob if sampling == 'multinomial' else None,
        n_batch=opt.svgd_batch,
        n_iter=opt.svgd_epoch,
        alpha=opt.svgd_alpha,
        loss_min=opt.svgd_loss_min,
        validation_interval=opt.svgd_validation_interval,
        earlyout='validation',
        validation_set=v_set,
        validation_earlyout_threashold=opt.svgd_eo_vth,
    )
    return svgd.fit(angles, **fit_kwargs)

def compute_grad_batch(target_angles, context_angles, p_or_n,
                       l2loss=None, welsch_pm=None):
    '''Per-element gradient of the angular loss w.r.t. the target angles.

    Angles appear to be stored in units of pi (the periodic dimension on
    [0, 2), the others on [0, 1]). Differences with magnitude > 1 are wrapped:
    the code stores sign(d) * (2 - |d|) = -(d - 2 * sign(d)) and folds the
    matching sign flip into ``directions``, so the product is the gradient of
    the shorter arc. Subtracting the result moves the target toward positive
    contexts and away from negative ones.

    target_angles: Angles of a target word
        np.array of shape (batch_size, angle_dim)
    context_angles: Angles of context words
        np.array of shape (batch_size, angle_dim)
    p_or_n: Polarity of context words
        str, 'p' if positive, otherwise 'n'
    l2loss: bool, use the L2 loss gradient 2*d instead of the Welsch loss.
        Defaults to the module-level ``opt.l2loss`` when None.
    welsch_pm: float, Welsch loss parameter c. Defaults to the module-level
        ``opt.welsch_pm`` when None (only read when the Welsch loss is used).

    Returns np.array of shape (batch_size, angle_dim).
    '''
    assert p_or_n in ['p', 'n']

    # Backward compatibility: fall back to the script-level `opt` namespace.
    if l2loss is None:
        l2loss = opt.l2loss
    if not l2loss and welsch_pm is None:
        welsch_pm = opt.welsch_pm

    angle_diffs = target_angles - context_angles  # (batch, dim)

    # |diff| < 1 (i.e. < pi): the direct difference is already the shorter arc.
    if p_or_n == 'p':
        directions = np.where(np.abs(angle_diffs) < 1, 1, -1)  # (batch, dim)
    else:
        directions = np.where(np.abs(angle_diffs) < 1, -1, 1)

    angle_diffs = np.where(np.abs(angle_diffs) > 1,
                           np.sign(angle_diffs) * (2 - np.abs(angle_diffs)),
                           angle_diffs)

    if l2loss:
        return directions * 2 * angle_diffs
    # Derivative of the Welsch loss c^2/2 * (1 - exp(-d^2 / (2c^2))).
    return directions * 0.5 * angle_diffs * \
        np.exp(-np.power(angle_diffs, 2) / (2 * welsch_pm * welsch_pm))


def has_edge(angles, radiuses, x1s, x2s, c=0.5):
    '''Predict, for each index pair (x1s[i], x2s[i]), whether an edge exists.

    Both words are converted from polar (radius, angles) to cartesian
    coordinates, and an edge is predicted when
        cos_sim(m1[i], m2[i]) > 1 - c * (r2 - r1)**2.
    NOTE(review): cos_simc_matrix is defined in computeSimilarity and is
    presumably a row-wise cosine similarity; confirm there.

    angles: np.array (n_words, n_angle_dims); radiuses: np.array (n_words,)
    x1s, x2s: integer index arrays of equal length; c: similarity parameter.
    Returns a boolean np.array with one entry per pair.
    '''
    def to_cartesian(ids):
        # Prepend the radius column to the angles, then convert.
        r = np.reshape(radiuses[ids], (-1, 1))
        return polar_to_cartesian_matrix(np.concatenate([r, angles[ids, :]], 1))

    a_values = cos_simc_matrix(to_cartesian(x1s), to_cartesian(x2s))
    r_values = np.abs(radiuses[x2s] - radiuses[x1s])

    return a_values > (-c * (r_values * r_values) + 1)


def validation(angles, rs, vx1s, vx2s, vy, sim_pm):
    '''Score edge predictions on labelled validation pairs for one threshold.

    vy holds 1 for positive pairs and 0 for negatives. Returns
    (precision, recall, f, accuracy). Undefined ratios (no predicted
    positives, no actual positives, or precision + recall == 0) are reported
    as 0.0 instead of NaN, so no 0/0 RuntimeWarnings are emitted.
    '''
    pred = has_edge(angles, rs, vx1s, vx2s, sim_pm)

    true_pos = np.sum(pred[vy == 1])
    n_pred = np.sum(pred)
    n_actual = np.sum(vy)

    precision = true_pos / n_pred if n_pred > 0 else 0.0
    recall = true_pos / n_actual if n_actual > 0 else 0.0

    denom = precision + recall
    f = 2 * (precision * recall) / denom if denom > 0 else 0.0

    acc = np.sum(vy == pred) / len(vy)
    return precision, recall, f, acc


def validation_with_multi_params(angles, rs, vx1s, vx2s, vy, sim_pms):
    '''Run ``validation`` once per similarity parameter in ``sim_pms``.

    Returns four lists (precisions, recalls, f-values, accuracies), each
    aligned with ``sim_pms``.
    '''
    per_param = [validation(angles, rs, vx1s, vx2s, vy, pm) for pm in sim_pms]
    precisions = [res[0] for res in per_param]
    recalls = [res[1] for res in per_param]
    f_values = [res[2] for res in per_param]
    accuracies = [res[3] for res in per_param]
    return precisions, recalls, f_values, accuracies


def _apply_svgd_dim(angles, d, angles_d_hat, opt, fig, n_svgd):
    '''Write SVGD output for angle dimension ``d`` back into ``angles`` in place.

    The last dimension (index opt.dim - 2) is wrapped with mod 2; every other
    dimension keeps its previous value wherever the SVGD value left [0, 1].
    With opt.svgd_plot, histograms before/after are saved under figure/.
    '''
    if opt.svgd_plot:
        plt.hist(angles[:, d], bins=50)
        fig.savefig(f'figure/angles_before_{n_svgd}_{d}.png')
        plt.clf()

    if d == opt.dim - 2:  # (0, 2pi) dim
        angles[:, d] = angles_d_hat
        angles[:, d] = np.mod(angles[:, -1], 2)
    else:  # (0, pi) dim
        is_overflow = (angles_d_hat < 0) | (angles_d_hat > 1)
        angles[:, d] = np.where(is_overflow, angles[:, d], angles_d_hat)

    if opt.svgd_plot:
        plt.hist(angles[:, d], bins=50)
        fig.savefig(f'figure/angles_after_{n_svgd}_{d}.png')
        plt.clf()


def main(opt):
    '''Train polar word embeddings on the noun (or mammal) closure data.

    Loads the preprocessed vocabulary / frequency / context pickles and the
    radii, initialises the angles from normalised Gaussian samples, then runs
    opt.iter - 1 mini-batch angle updates. With --SVGD, every
    opt.svgd_interval updates each angle dimension is moved toward its fitted
    Gaussian mixture. Every opt.save_interval updates the model is validated
    over several similarity parameters, saved, and the learning rate decayed.

    opt: argparse.Namespace produced by this script's command-line parser.
    '''
    best_f = 0

    if opt.mammal:
        basedir = 'data/mammal'
        trainfile = os.path.join(basedir, 'mammal_closure.tsv')
        validfile_pos = os.path.join(basedir, 'mammal_closure.tsv')
        validfile_neg = validfile_pos + '_neg'
    else:
        basedir = 'data/noun/'
        trainfile = \
            os.path.join(basedir, f'noun_closure.tsv.train_{opt.t}percent')
        validfile_pos = os.path.join(basedir, 'noun_closure.tsv.valid')
        validfile_neg = validfile_pos + '_neg'

    assert os.path.exists(trainfile + '_vocab.pkl'), \
        f'Preprocess has not been done yet for {trainfile}'

    with open(trainfile + '_vocab.pkl', 'rb') as fh:
        vocab = pickle.load(fh)
    with open(trainfile + '_freq.pkl', 'rb') as fh:
        freq = pickle.load(fh)
    w_prob = freq / sum(freq)
    with open(trainfile + '_context.pkl', 'rb') as fh:
        word2context = pickle.load(fh)

    if opt.rfile:
        with open(opt.rfile) as f:
            values = [float(r) for r in f]
        # Radii are indexed by word id alongside the angles, so the file must
        # hold exactly one radius per vocabulary entry (any dataset, not just noun).
        assert len(values) == len(vocab), \
            f'{opt.rfile} has {len(values)} radii, expected {len(vocab)}'
        rs = np.array(values)
    else:
        rs = load_height_file(trainfile + '.height')

    # uniform distribution on sphere
    tmp = np.random.normal(0, 1, (len(vocab), opt.dim))
    tmp = tmp / (
        np.tile(np.sqrt(np.sum(np.power(tmp, 2), axis=1)), (opt.dim, 1))).T
    angles = cartesian_to_polar_matrix(tmp)[:, 1:]

    if opt.SVGD:
        gmm_path = './dataset/angle_GMM_parameters_dim' + str(opt.dim) + '.json'
        with open(gmm_path) as f:
            gmmParam = json.load(f)
        # One SVGD sampler per angle dimension, targeting that dim's GMM.
        svgds = []
        for d in range(opt.dim - 1):
            w = torch.tensor(gmmParam[str(d)]['weight'])
            mu = torch.tensor(gmmParam[str(d)]['mean'])
            sigma = torch.tensor(gmmParam[str(d)]['sigma'])
            g_mixtures = \
                MixtureSameFamily(Categorical(w), Normal(mu, sigma))
            svgds.append(SVGD(g_mixtures, RBF(svgd_size=opt.svgd_batch), len(vocab)))

    # Per-(word, dim) step multipliers; with --dropout a random subset of words
    # gets a reduced step in every dimension except the last one.
    dropoutM = np.ones((angles.shape[0], angles.shape[1]))
    if opt.dropout:
        n_words = angles.shape[0]
        for d in range(opt.dim - 2):
            dropped = random.sample(range(n_words), round(opt.dropout_rate * n_words))
            dropoutM[dropped, d] = opt.dropout_weights

    # load training data: one [target, context] pair per context word
    training_samples = [[wt, c] for wt, context in word2context.items() for c in context]

    # load validation data: [id1, id2, label] with label 1 = positive, 0 = negative
    validation_samples = []
    with open(validfile_pos) as pf, open(validfile_neg) as nf:
        for l in pf:
            validation_samples.append(list(map(int, l.strip().split())) + [1])
        for l in nf:
            validation_samples.append(list(map(int, l.strip().split())) + [0])

    vx1s, vx2s, vy = zip(*validation_samples)
    vx1s = np.array(vx1s)
    vx2s = np.array(vx2s)
    vy = np.array(vy)

    # thresholds
    sim_pms = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

    if opt.debug:
        # Use validation data for training in debug mode
        training_samples = [v[:-1] for v in validation_samples]

    # model name template
    data = 'mammal' if opt.mammal else 'noun{}'.format(opt.t)
    filepath = '{}_d{}_itr{}_batch{}_nr{}'.format(data, opt.dim, opt.iter,
                                                  opt.batch, opt.nratio)
    print('Save model to ', opt.savedir)
    if opt.output_log:
        log_filename = filepath + '_' + '_'.join(['alpha' + str(opt.alpha),
                                                  'alpha_decay' + str(opt.alpha_decay)])
        if opt.SVGD:
            log_filename += '_SVGD'
        if opt.dropout:
            log_filename += '_dropout'
        if opt.l2loss:
            log_filename += '_l2loss'
        if opt.rfile:
            log_filename += '_rfile'
        f_log = open(log_filename, "w", encoding="utf_8")

    for i in tqdm.tqdm(range(1, opt.iter), desc='Number of updates'):

        wt, context = zip(*random.sample(training_samples, opt.batch))
        wt = list(wt)
        context = list(context)
        wt_angle = angles[wt, :]  # (batch, dim)

        cur_positives = context
        cur_negatives = np.random.choice(len(vocab),
                                         size=int(len(cur_positives) * 1),
                                         p=w_prob)
        wps_angles = angles[cur_positives, :]  # (batch, dim)
        wns_angles = angles[cur_negatives, :]  # (batch, dim)

        # compute gradient
        grad_p_wps = compute_grad_batch(wt_angle, wps_angles, 'p')
        grad_n_wns = compute_grad_batch(wt_angle, wns_angles, 'n')

        # Negative gradients are scaled by nratio, except on the (0, 2pi) dim.
        adaptive_nratio = \
            opt.nratio * np.ones((grad_n_wns.shape[0], grad_n_wns.shape[1]))
        adaptive_nratio[:, -1] = 1.0
        grad_pn_wt = (grad_p_wps + adaptive_nratio * grad_n_wns)

        # Update angles
        angles_hat = angles[wt] - (dropoutM[wt] * opt.alpha * grad_pn_wt)  # (batch, dim)

        # (0, 2pi) dim
        angles[wt, -1] = np.mod(angles_hat[:, -1], 2)

        # (0, pi) dims
        # If an angle overflows by update, it keeps an original value
        is_overflow = (angles_hat[:, :-1] < 0) | (angles_hat[:, :-1] > 1)
        angles[wt, :-1] = \
            np.where(is_overflow, angles[wt, :-1], angles_hat[:, :-1])

        if opt.SVGD and i % opt.svgd_interval == 0:
            fig = plt.figure() if opt.svgd_plot else None
            n_svgd = i // opt.svgd_interval
            validation_set = [angles, rs, vx1s, vx2s, vy, sim_pms]
            if opt.svgd_workers == 0:
                # Sequential: dimension d+1 is fitted after dimension d was written back.
                for d in range(opt.dim - 1):
                    angles_d_hat = svgds[d].fit(
                        angles[:, d],
                        d=d,
                        sampling=opt.svgd_sampling,
                        probs=w_prob if opt.svgd_sampling == 'multinomial' else None,
                        n_batch=opt.svgd_batch,
                        n_iter=opt.svgd_epoch,
                        alpha=opt.svgd_alpha,
                        loss_min=opt.svgd_loss_min,
                        validation_interval=opt.svgd_validation_interval,
                        earlyout='validation',
                        validation_set=validation_set,
                        validation_earlyout_threashold=opt.svgd_eo_vth
                    )
                    _apply_svgd_dim(angles, d, angles_d_hat, opt, fig, n_svgd)
            else:
                workers = opt.svgd_workers
                if workers == -1:
                    workers = mp.cpu_count()
                workers = min(opt.dim - 1, workers)

                ds = list(range(opt.dim - 1))
                inputs = [(svgds[d], angles[:, d], d, opt, w_prob, validation_set) for d in ds]

                # Parallel: every dimension is fitted against the same snapshot.
                with mp.Pool(workers) as p:
                    angles_d_hats = p.map(svgd_mp_wrapper, inputs)

                for d, angles_d_hat in zip(ds, angles_d_hats):
                    _apply_svgd_dim(angles, d, angles_d_hat, opt, fig, n_svgd)

            if fig is not None:
                # Release the figure; otherwise one leaks per SVGD round.
                plt.close(fig)

        # Check (messages list the offending values only)
        assert (angles >= 0).all(), f'{angles[angles < 0]}'
        assert (angles[:, :-1] <= 1).all(), f'{angles[:, :-1][angles[:, :-1] > 1]}'
        assert (angles <= 2).all(), f'{angles[angles > 2]}'

        if i % opt.save_interval == 0:  # Validation
            print()
            print(f'*** {i // opt.save_interval}-th validation:')
            p_results, r_results, f_results, a_results = \
                validation_with_multi_params(angles, rs, vx1s, vx2s, vy, sim_pms)

            df_result = None
            if 'p' in opt.metrics:
                df_result = pd.DataFrame(p_results, index=sim_pms)
                print('***Precision:\n', df_result)
            if 'r' in opt.metrics:
                df_result = pd.DataFrame(r_results, index=sim_pms)
                print('***Recall   :\n', df_result)
            if 'f' in opt.metrics:
                df_result = pd.DataFrame(f_results, index=sim_pms)
                print('***F-value  :\n', df_result)
            if 'a' in opt.metrics:
                df_result = pd.DataFrame(a_results, index=sim_pms)
                print('***Accuracy :\n', df_result)
            if opt.output_log:
                f_log.write('iteration:' + str(i) + '\n')
                # The last metric table shown is the one logged.
                if df_result is not None:
                    df_result.to_csv(f_log, sep="\t")

            np.savez(os.path.join(opt.savedir, filepath) + '_epoch{}'.format(i),
                     rs, angles, vocab)

            cur_best_f = np.max(f_results)
            if cur_best_f > best_f:
                print(f'*** Update best model with {cur_best_f}')
                best_f = cur_best_f
                np.savez(os.path.join(opt.savedir, 'best'), rs, angles, vocab)

            # decay learning rate
            new_alpha = opt.alpha * opt.alpha_decay
            opt.alpha = max(new_alpha, opt.min_alpha)
            print('*** Current alpha: ', opt.alpha)
            print()

    print('finish')
    if opt.output_log:
        f_log.close()



if __name__ == '__main__':
    parser = \
        argparse.ArgumentParser('Train polar embeddings with noun dataset',
                                add_help=True)

    parser.add_argument('-t', help='Percentage of the noun closure used for training (selects noun_closure.tsv.train_{t}percent)',
                        choices=[0, 10, 25, 50, 90], default=10, type=int)
    parser.add_argument('-dim', help='Dimension',
                        default=5, type=int)
    parser.add_argument('-iter', help='Number of iteration',
                        default=10000000, type=int)
    parser.add_argument('-batch', help='Number of batch',
                        default=128, type=int)
    parser.add_argument('-nratio', help='Ratio of negative samples',
                        default=1.0, type=float)
    parser.add_argument('-alpha', help='Learning ratio',
                        default=0.5, type=float)
    parser.add_argument('-min_alpha', help='Minimum learning ratio',
                        default=0.05, type=float)
    parser.add_argument('-savedir', help='Model save directory',
                        default='./model')
    parser.add_argument('-debug', help='Debug mode',
                        action='store_true')
    parser.add_argument('-metrics', help='Metrics to show (p, r, f, a)',
                        nargs='+', default=['f'])
    parser.add_argument('-save_interval', help='Interval to save model',
                        default=1000, type=int)
    parser.add_argument('-alpha_decay', help='Decay of learning rate',
                        default=0.95, type=float)
    parser.add_argument('-welsch_pm', help='welsch loss parameter c', default=0.4, type=float)
    parser.add_argument('--mammal', help='Train model with mammal dataset',
                        action='store_true')
    parser.add_argument('--dropout', help='Dropout in each dimension', action='store_true')
    parser.add_argument('-dropout_weights', help='Weights of unactivated words', default=0.2, type=float)
    parser.add_argument('-dropout_rate', help='Unactivated words ratio', default=0.4, type=float)
    parser.add_argument('--SVGD', help='keep uniform distribution',
                        action='store_true')
    parser.add_argument('-svgd_interval', help='Epoch interval to run SVGD',
                        default=20000, type=int)
    parser.add_argument('-svgd_batch', help='Number of batch for SVGD',
                        default=128, type=int)
    parser.add_argument('-svgd_epoch', help='Number of iteration for SVGD',
                        default=5, type=int)
    parser.add_argument('-svgd_alpha', help='Learning rate for SVGD',
                        default=1, type=float)
    # NOTE: parsed but not read; main() always uses 'validation' (the only choice).
    parser.add_argument('-svgd_eo', help='Earlyout strategy for SVGD.',
                        default='validation', choices=['validation'])
    # argparse %-formats help strings, so a literal percent sign must be '%%'.
    parser.add_argument('-svgd_eo_vth', help='Threshold of validation earlyout for SVGD: SVGD stops once the best validation F-value falls below this fraction of its value before SVGD (the default 0.99 allows a 1%% drop)',
                        default=0.99, type=float)
    parser.add_argument('-svgd_loss_min', help='Loss threshold for early stopping for SVGD',
                        default=10, type=float)
    parser.add_argument('-svgd_validation_interval', help='Validation interval for SVGD',
                        default=500, type=int)
    parser.add_argument('-svgd_sampling', help='Sampling method for SVGD',
                        default='bruteforce', choices=['uniform', 'multinomial', 'log_prob', 'bruteforce'])
    parser.add_argument('-svgd_workers', help='Number of workers for SVGD (-1 means use all CPUs. 0 means single process.)',
                        default=-1, type=int)
    parser.add_argument('--svgd_plot', help='Plot angles before and after SVGD',
                        action='store_true')
    parser.add_argument('--output_log', help='Output log', action='store_true')
    parser.add_argument('--l2loss', help='use l2 loss', action='store_true')
    parser.add_argument('--rfile', help='use another r')

    # `opt` must stay a module-level global: compute_grad_batch reads
    # opt.l2loss / opt.welsch_pm from module scope.
    opt = parser.parse_args()

    # makedirs also handles nested paths and an already-existing directory.
    os.makedirs(opt.savedir, exist_ok=True)
    if opt.svgd_plot:
        # SVGD histograms are saved under figure/.
        os.makedirs('figure', exist_ok=True)

    main(opt)
