import torch
import lightning as L
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import CSVLogger
from typing import List, Tuple
from lightning.pytorch.tuner import Tuner
from sklearn.metrics.pairwise import cosine_similarity
import argparse


class CooccuranceDataset(Dataset):
    """Map-style dataset over ``(c1_idx, c2_idx, count)`` rows.

    Each item comes back as ``(int, int, tensor)``: the two community
    indices are coerced with ``int()`` and the count is wrapped in a
    ``float32`` tensor.
    """

    def __init__(self, data: List[Tuple]):
        # Any indexable, sized sequence of 3-item rows works here
        # (e.g. a 2-D NumPy array as well as a list of tuples).
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        first, second, count = self.data[idx]
        return int(first), int(second), torch.tensor(count).to(torch.float32)


class CooccurrenceDataModule(L.LightningDataModule):
    """Serve a square community co-occurrence matrix as GloVe training rows.

    Every strictly positive cell below the diagonal of ``mat`` becomes one
    ``(c1_idx, c2_idx, count)`` row. ``c1_idx`` is the row's community and
    ``c2_idx`` the column's, both encoded through ``self.c2id`` (label ->
    position in ``mat.columns``). Only the lower triangle is read, so each
    unordered pair appears at most once and the diagonal is skipped. This
    presumes ``mat`` is symmetric (TODO confirm for the matrices fed in).

    The train, validation and test splits all use the same rows.

    Args:
        mat: Square DataFrame whose row and column labels are the same set
            of unique community labels (order may differ).
        bs: Batch size used by every dataloader.
        logy: Currently unused; accepted for backward compatibility.

    Raises:
        ValueError: if ``mat`` is not square, or its row and column labels
            are not the same set of unique labels.
    """

    def __init__(self, mat: pd.DataFrame, bs=128, logy=True):
        super().__init__()
        if mat.shape[0] != mat.shape[1]:
            raise ValueError('Expected parameter "mat" to have the same number of rows and columns.')
        if not mat.columns.is_unique or set(mat.index) != set(mat.columns):
            raise ValueError('Expected parameter "mat" to have unique, matching row and column labels.')
        # The triangle mask below is positional, so put the rows in column
        # order first. Otherwise a differently ordered index would let
        # diagonal cells through and miss or double-count pairs.
        mat = mat.reindex(index=mat.columns)

        self.communities = mat.columns
        self.c2id = {s: i for i, s in enumerate(self.communities)}
        self.id2c = {i: s for s, i in self.c2id.items()}

        below_diag = ~np.triu(np.ones(mat.shape, dtype=bool))
        stacked = mat.where(below_diag).stack()
        # Build the columns from the stacked MultiIndex levels directly
        # (level 0 = row label, level 1 = column label). This works whatever
        # the axis names are.
        data = pd.DataFrame({
            'c1': stacked.index.get_level_values(0).map(self.c2id),
            'c2': stacked.index.get_level_values(1).map(self.c2id),
            'count': stacked.to_numpy(dtype='float32'),
        })
        data = data.astype({'c1': 'uint64', 'c2': 'uint64'})
        # The model regresses on log(count). A zero count would give
        # log(0) = -inf and an inf/NaN loss, so keep only strictly positive
        # counts. NaN (masked/missing cells) also fails `> 0`, whether or not
        # this pandas version's stack() already dropped it.
        data = data.loc[data['count'] > 0].reset_index(drop=True)

        self.data = data
        self.bs = bs

    def setup(self, stage: str):
        """Create the datasets; train, val and test share one read-only dataset."""
        dataset = CooccuranceDataset(self.data.values)
        self.train = self.val = self.test = dataset

    def _loader(self, dataset, shuffle):
        """Wrap *dataset* in a DataLoader using the configured batch size."""
        return DataLoader(dataset, batch_size=self.bs, shuffle=shuffle)

    def train_dataloader(self):
        return self._loader(self.train, shuffle=True)

    def val_dataloader(self):
        # Evaluation needs no randomised order; a fixed order keeps it
        # deterministic.
        return self._loader(self.val, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test, shuffle=False)


class GloveModel(L.LightningModule):
    """GloVe-style model learning one embedding vector and bias per community.

    For a pair ``(i, j)`` with co-occurrence count ``X_ij`` the model predicts
    ``log(X_ij)`` as ``w_i . w_j + b_i + b_j``. A single embedding table
    ``emb`` and bias table ``bias`` are shared by both sides of the pair.
    Training minimises the GloVe weighted least-squares objective
    ``sum f(X_ij) * (w_i . w_j + b_i + b_j - log X_ij) ** 2`` over the batch.

    Args:
        num_communities: Number of rows in the embedding and bias tables.
        emb_dim: Embedding dimension.
        lr: Adam learning rate.
        xmax: Count at which the weighting function saturates at 1.
        alpha: Exponent of the weighting function below ``xmax``.
    """

    def __init__(self, num_communities, emb_dim=100, lr=1e-5, xmax=100, alpha=0.75):
        super().__init__()
        self.num_communities, self.emb_dim = num_communities, emb_dim
        self.emb = torch.nn.Embedding(num_communities, emb_dim)
        self.bias = torch.nn.Embedding(num_communities, 1)
        # Not used by the objective below (see forward_and_calc_loss). It is
        # kept so existing references to the attribute keep working; it has
        # no parameters, so the state_dict is unaffected.
        self.loss_fn = torch.nn.MSELoss()
        self.lr = lr
        self.xmax, self.alpha = xmax, alpha

    def forward(self, batch):
        """Return ``w_c1 . w_c2 + b_c1 + b_c2`` for each pair in ``batch``.

        ``batch`` is ``(c1, c2, y)``; ``y`` is ignored here. ``c1`` and ``c2``
        are expected to be 1-D integer index tensors, since the per-pair dot
        product is taken over ``dim=1`` of the looked-up embeddings.
        """
        c1, c2, _ = batch
        emb_prod = (self.emb(c1) * self.emb(c2)).sum(dim=1)
        return emb_prod + self.bias(c1).flatten() + self.bias(c2).flatten()

    def weight_fn(self, x):
        """GloVe weighting ``f(x) = (x / xmax) ** alpha`` if ``x < xmax`` else 1.

        Clamping the ratio into ``[0, 1]`` gives exactly that piecewise
        function. It also makes non-positive counts weigh 0 instead of
        producing NaN from a fractional power of a negative number.
        """
        return torch.clamp(x / self.xmax, min=0.0, max=1.0) ** self.alpha

    def forward_and_calc_loss(self, batch):
        """Return the summed weighted squared error between predictions and ``log(y)``."""
        _, _, y = batch
        out = self.forward(batch)
        weight = self.weight_fn(y)
        # f(0) = 0, but log(0) = -inf and 0 * inf = NaN. Substituting 1
        # (log 1 = 0) for non-positive counts makes those rows contribute
        # exactly zero loss and gradient instead of poisoning the batch.
        safe_y = torch.where(y > 0, y, torch.ones_like(y))
        loss = (weight * (out - torch.log(safe_y)) ** 2).sum()
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.forward_and_calc_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.forward_and_calc_loss(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    

def calculate_community_similarity(model, dm, normalize=True):
    """Return the pairwise cosine similarity of every community embedding.

    Args:
        model: Trained model exposing an ``emb`` embedding table whose row
            ``i`` belongs to community ``dm.communities[i]``.
        dm: Data module providing the ``communities`` labels.
        normalize: L2-normalise each embedding row before the similarity
            call. Cosine similarity is itself scale-invariant, so this should
            not change the result beyond float rounding.

    Returns:
        Square DataFrame of similarities, with ``dm.communities`` as both the
        row index and the column labels.
    """
    vectors = model.emb.weight.detach().cpu()
    if normalize:
        vectors = torch.nn.functional.normalize(vectors, dim=1)
    labels = dm.communities
    return pd.DataFrame(cosine_similarity(vectors), index=labels, columns=labels)

def calculate_group_similarity(sim, group1, group2, top_n=5, exclude_self=True):
    """Score how well a similarity matrix separates two groups of communities.

    For every community ``sub`` in ``group1 | group2``, its row of ``sim`` is
    ranked from most to least similar. The function counts how many of the
    ``top_n`` most similar communities belong to ``sub``'s own group, and how
    many of the ``top_n`` least similar belong to the opposite group. A
    community listed in both groups is treated as a member of ``group1``.

    Args:
        sim: Square similarity DataFrame indexed and columned by community.
        group1: Iterable of community labels present in ``sim``.
        group2: Iterable of community labels present in ``sim``.
        top_n: Number of nearest / furthest neighbours to inspect; must be >= 1.
        exclude_self: Drop ``sub`` from its own ranking. Under cosine
            similarity a community's similarity to itself is 1, the maximum,
            so otherwise it always fills one of the top slots and inflates
            ``pct_in_group_top``.

    Returns:
        Series holding the mean over all communities of
        ``pct_in_group_total``, ``pct_in_group_top`` and
        ``pct_in_group_bottom``. All values are NaN when both groups are empty.

    Raises:
        ValueError: if ``top_n < 1``.
        KeyError: if a group member is not a row label of ``sim``.
    """
    if top_n < 1:
        raise ValueError("top_n must be at least 1.")
    metrics = ["pct_in_group_total", "pct_in_group_top", "pct_in_group_bottom"]
    group1, group2 = set(group1), set(group2)
    group_subs = group1 | group2
    if not group_subs:
        return pd.Series(float("nan"), index=metrics)

    res = {}
    for sub in group_subs:
        sub_sim = sim.loc[sub]
        if exclude_self:
            sub_sim = sub_sim.drop(labels=sub, errors="ignore")
        ranked = sub_sim.sort_values(ascending=False).index
        # group1 is checked first, so a community in both groups counts as group1.
        if sub in group1:
            sub_group, opp_group = group1, group2
        else:
            sub_group, opp_group = group2, group1
        n_top = len(set(ranked[:top_n]) & sub_group)
        n_bottom = len(set(ranked[-top_n:]) & opp_group)
        res[sub] = {
            "in_group_count_top": n_top,
            "in_group_count_bottom": n_bottom,
            "in_group_count_total": n_top + n_bottom,
            "pct_in_group_total": (n_top + n_bottom) / (2 * top_n),
            "pct_in_group_top": n_top / top_n,
            "pct_in_group_bottom": n_bottom / top_n,
        }
    res = pd.DataFrame.from_dict(res, orient="index")
    return res[metrics].mean()


if __name__ == "__main__":
    # Train GloVe-style community embeddings from a co-occurrence matrix and
    # write the community-by-community cosine-similarity matrix to JSON.
    # With --find_lr, run Lightning's learning-rate finder and print its
    # suggestion instead of training.

    # Everything else here uses the unified `lightning` package (L.Trainer,
    # L.LightningModule, lightning.pytorch.tuner). Take the logger from it
    # too, rather than mixing in the standalone `pytorch_lightning` package.
    from lightning.pytorch.loggers import CSVLogger

    parser = argparse.ArgumentParser()
    parser.add_argument("--cooccur_mat_json", type=str, required=True, help="Path to the user-subreddit cooccurrence matrix.")
    parser.add_argument("--out_path", type=str, required=True, help="Path to the json file to save the output in.")
    parser.add_argument("--batch_size", type=int, default=8192, help="Batch size.")
    parser.add_argument("--max_epochs", type=int, default=20, help="Number of epochs to train the model for.")
    parser.add_argument("--accelerator", type=str, default="gpu", help="Machine to train the model on. gpu or cpu.")
    parser.add_argument("--find_lr", type=int, default=0, help="If not 0, suggests a learning rate to train the model.")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate. Defaults to GloveModel's own default when omitted.")
    args = parser.parse_args()
    # lr=0 would make every Adam update zero, so the model would never train.
    if args.lr is not None and args.lr <= 0:
        parser.error("--lr must be positive.")

    mat = pd.read_json(args.cooccur_mat_json)

    # Trainer.fit (and Tuner.lr_find, which runs a fit) call dm.setup() itself.
    dm = CooccurrenceDataModule(mat, bs=args.batch_size)
    model_kwargs = {} if args.lr is None else {"lr": args.lr}
    model = GloveModel(len(dm.communities), **model_kwargs)
    logger = CSVLogger("logs", name="glove_log")
    trainer = L.Trainer(
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        devices=1,
        logger=logger,
        # Validation uses the same rows as training. With the default
        # --max_epochs (20) this interval means the periodic validation loop
        # never triggers.
        check_val_every_n_epoch=100,
    )

    if args.find_lr:
        tuner = Tuner(trainer)
        lr_finder = tuner.lr_find(model, datamodule=dm)
        print(lr_finder.results)
        fig = lr_finder.plot(suggest=True)
        fig.show()
        print("Suggested learning rate: " + str(lr_finder.suggestion()))
    else:
        trainer.fit(model, dm)
        sim = calculate_community_similarity(model, dm)
        sim.to_json(args.out_path)