from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, BertTokenizer
from tqdm import tqdm
import pickle

# DDP must tolerate parameters that receive no gradient in a step: RetrieverModel
# reads only the first output of each encoder, so any weights that feed only the
# other outputs (e.g. a pooler head, if the checkpoint has one) stay unused.
kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True,
)
accelerator = Accelerator(
    kwargs_handlers=[kwargs],
)
# The device assigned to this process; batches are moved here before forward().
device = accelerator.device


class RetrieverModel(nn.Module):
    """Dual-encoder retriever with separate encoders for dialogues and memories.

    Every input is represented by the final hidden state of its first token
    (the ``[CLS]`` position for BERT-style tokenizers).
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Build both encoders.

        Args:
            model_name: Hugging Face model id or local path. Both encoders start
                from this checkpoint but are separate modules with their own weights.
        """
        super().__init__()
        self.dialogue_encoder = AutoModel.from_pretrained(model_name)
        self.memory_encoder = AutoModel.from_pretrained(model_name)

    @staticmethod
    def _first_token_embedding(encoder, inputs):
        """Run *encoder* on an ``(input_ids, attention_mask)`` pair; return the first-token hidden state."""
        outputs = encoder(input_ids=inputs[0], attention_mask=inputs[1])
        # outputs[0] is the per-token hidden state; index 0 along the sequence axis
        # picks the first token of every row.
        return outputs[0][:, 0, :]

    def forward(self, dialogue, positive_memory, negative_memory):
        """Encode a batch of (dialogue, positive memory, negative memory) triplets.

        Args:
            dialogue, positive_memory, negative_memory: each a pair
                ``(input_ids, attention_mask)`` for the respective encoder.

        Returns:
            A tuple ``(dialogue_encoding, positive_memory_encoding,
            negative_memory_encoding)``. Both memories go through the same
            ``memory_encoder``.
        """
        dialogue_encoding = self._first_token_embedding(self.dialogue_encoder, dialogue)
        positive_memory_encoding = self._first_token_embedding(self.memory_encoder, positive_memory)
        negative_memory_encoding = self._first_token_embedding(self.memory_encoder, negative_memory)
        return dialogue_encoding, positive_memory_encoding, negative_memory_encoding


class TripletDataset(Dataset):
    """Map-style dataset of (dialogue, positive memory, negative memory) triplets.

    Items are tokenized on access and returned as a dict of 1-D tensors, padded
    or truncated to fixed lengths so the default DataLoader collate can stack them.
    """

    def __init__(self, dialogues, positive_memories, negative_memories, tokenizer,
                 dialogue_max_length=512, memory_max_length=32):
        """Store the raw triplets and tokenizer.

        Args:
            dialogues: Indexable collection of anchor texts (presumably ``str``).
            positive_memories: Matching memory text for each dialogue.
            negative_memories: Non-matching memory text for each dialogue.
            tokenizer: Object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
            dialogue_max_length: Pad/truncate length for dialogue tokens.
            memory_max_length: Pad/truncate length for positive and negative memories.

        Raises:
            ValueError: If the three collections differ in length. Otherwise
                indexing would fail mid-epoch or silently ignore extra entries.
        """
        if not len(dialogues) == len(positive_memories) == len(negative_memories):
            raise ValueError(
                'dialogues, positive_memories and negative_memories must have the same length, '
                f'got {len(dialogues)}, {len(positive_memories)} and {len(negative_memories)}'
            )
        self.dialogues = dialogues
        self.positive_memories = positive_memories
        self.negative_memories = negative_memories
        self.tokenizer = tokenizer
        self.dialogue_max_length = dialogue_max_length
        self.memory_max_length = memory_max_length

    def __len__(self):
        return len(self.dialogues)

    def _encode(self, text, max_length):
        """Tokenize *text* to exactly *max_length* tokens; return ``(input_ids, attention_mask)``."""
        encoding = self.tokenizer.encode_plus(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of size 1. Drop only that
        # dimension so a max_length of 1 still yields a 1-D tensor, not a 0-d scalar.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        anchor_ids, anchor_mask = self._encode(self.dialogues[idx], self.dialogue_max_length)
        positive_ids, positive_mask = self._encode(self.positive_memories[idx], self.memory_max_length)
        negative_ids, negative_mask = self._encode(self.negative_memories[idx], self.memory_max_length)

        return {
            'anchor_ids': anchor_ids,
            'anchor_attention_mask': anchor_mask,
            'positive_ids': positive_ids,
            'positive_attention_mask': positive_mask,
            'negative_ids': negative_ids,
            'negative_attention_mask': negative_mask,
        }


class TripletLoss(nn.Module):
    """Margin-based triplet loss on cosine similarity.

    Per row: ``max(0, cos(anchor, negative) - cos(anchor, positive) + margin)``.
    The loss is zero once the positive is at least ``margin`` more similar to the
    anchor than the negative is.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        """Return the mean triplet loss over rows.

        Args:
            anchor, positive, negative: Embeddings compared row-wise; cosine
                similarity is taken along dim 1 (the ``cosine_similarity`` default).
        """
        similarity_positive = nn.functional.cosine_similarity(anchor, positive)
        similarity_negative = nn.functional.cosine_similarity(anchor, negative)
        # Cosine *similarity* grows as vectors get closer. Penalise a negative that
        # scores within `margin` of (or above) the positive, so minimising the loss
        # pulls positives toward the anchor and pushes negatives away.
        losses = nn.functional.relu(similarity_negative - similarity_positive + self.margin)

        return losses.mean()


def _load_triplets(path):
    """Load ``(dialogues, positive_memories, negative_memories)`` from a pickle at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use trusted,
    locally produced files here.
    """
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data[0], data[1], data[2]


def _batch_to_inputs(batch):
    """Split a collated TripletDataset batch into ``(input_ids, attention_mask)`` pairs on ``device``."""
    anchor = (batch['anchor_ids'].to(device), batch['anchor_attention_mask'].to(device))
    positive = (batch['positive_ids'].to(device), batch['positive_attention_mask'].to(device))
    negative = (batch['negative_ids'].to(device), batch['negative_attention_mask'].to(device))
    return anchor, positive, negative


def _mean_across_processes(loss_sum, num_batches):
    """Return the mean per-batch loss over the batches seen by *all* processes.

    This is a collective operation: every process must call it, and every process
    gets the same value back.
    """
    totals = torch.tensor([loss_sum, float(num_batches)], dtype=torch.float64, device=device)
    totals = accelerator.reduce(totals, reduction='sum')
    # clamp avoids dividing by zero when a loader yields no batches.
    return (totals[0] / totals[1].clamp(min=1)).item()


def main():
    """Train the dual-encoder retriever with a triplet loss and early stopping.

    Loads pickled train/validation triplets and fine-tunes both encoders.
    Saves the model to 'SAVE_PATH' whenever the validation loss does not exceed
    the best seen so far. Stops after ``early_stopping_epochs`` consecutive epochs
    above that best.
    """
    train_dialogues, train_positive_memories, train_negative_memories = _load_triplets('TRAIN_DATA_PATH')
    valid_dialogues, valid_positive_memories, valid_negative_memories = _load_triplets('VALID_DATA_PATH')

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = RetrieverModel().to(device)

    train_dataset = TripletDataset(train_dialogues, train_positive_memories, train_negative_memories, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=90, shuffle=True)

    valid_dataset = TripletDataset(valid_dialogues, valid_positive_memories, valid_negative_memories, tokenizer)
    valid_loader = DataLoader(valid_dataset, batch_size=90, shuffle=False)

    criterion = TripletLoss()
    # One optimizer per encoder; both are stepped on every batch.
    dialogue_optimizer = torch.optim.Adam(model.dialogue_encoder.parameters(), lr=1e-4)
    memory_optimizer = torch.optim.Adam(model.memory_encoder.parameters(), lr=1e-4)

    model, dialogue_optimizer, memory_optimizer, train_loader, valid_loader = accelerator.prepare(
        model, dialogue_optimizer, memory_optimizer, train_loader, valid_loader)

    num_epochs = 20
    best_loss = float('inf')
    early_stop_counter = 0
    early_stopping_epochs = 5
    # Draw progress bars once per machine rather than once per process.
    hide_progress = not accelerator.is_local_main_process
    for _ in tqdm(range(num_epochs), disable=hide_progress):
        model.train()
        train_loss_sum = 0.0
        for batch in tqdm(train_loader, total=len(train_loader), disable=hide_progress):
            with accelerator.accumulate(model):
                anchor, positive, negative = _batch_to_inputs(batch)
                anchor_outputs, positive_outputs, negative_outputs = model(anchor, positive, negative)

                loss = criterion(anchor_outputs, positive_outputs, negative_outputs)
                train_loss_sum += loss.item()
                accelerator.backward(loss)
                dialogue_optimizer.step()
                memory_optimizer.step()
                dialogue_optimizer.zero_grad()
                memory_optimizer.zero_grad()

        train_loss = _mean_across_processes(train_loss_sum, len(train_loader))
        if accelerator.is_main_process:
            print('Train Loss: {:.4f}'.format(train_loss))

        model.eval()
        valid_loss_sum = 0.0
        # No accelerator.accumulate() here: it exists for gradient accumulation and
        # nothing is backpropagated during validation.
        with torch.no_grad():
            for batch in tqdm(valid_loader, total=len(valid_loader), disable=hide_progress):
                anchor, positive, negative = _batch_to_inputs(batch)
                anchor_outputs, positive_outputs, negative_outputs = model(anchor, positive, negative)

                loss = criterion(anchor_outputs, positive_outputs, negative_outputs)
                valid_loss_sum += loss.item()

        valid_loss = _mean_across_processes(valid_loss_sum, len(valid_loader))
        if accelerator.is_main_process:
            print('Valid Loss: {:.4f}'.format(valid_loss))

        # valid_loss is identical on every process, so all ranks take the same branch
        # and leave the loop together. A main-only `break` would leave the other
        # ranks waiting forever in DDP collectives.
        if valid_loss > best_loss:
            early_stop_counter += 1
        else:
            best_loss = valid_loss
            early_stop_counter = 0
            accelerator.wait_for_everyone()
            accelerator.save_model(model, 'SAVE_PATH')

        if early_stop_counter >= early_stopping_epochs:
            if accelerator.is_main_process:
                print("Early Stopping!")
            break


if __name__ == '__main__':
    main()
