#!/usr/bin/env python
# coding: utf-8

# In[1]:


import logging

from pathlib import Path
# Seed the Python, NumPy and torch RNGs so runs are reproducible. The script
# draws from `random` (negative sampling for the classifier) and from torch
# (random_split, RandomSampler, freshly initialised head weights).
import random

import numpy as np
import torch

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)  # also seeds every CUDA device

# Load one formatted PersonaChat split into the parallel lists used to build
# the generation and classification inputs.
def read_dataset_split(split_dir):
    """Parse a formatted dialogue file made of consecutive 4-line records.

    Record layout (one item per line, trailing newline stripped):
        0. your knowledge
        1. partner knowledge
        2. dialogue history (pieces separated by " . ")
        3. response

    A trailing incomplete record only extends the lists its lines reach.

    Args:
        split_dir: path of the formatted text file (a file, despite the name).

    Returns:
        Nine lists, in this order:
        your_knowledge, partner_knowledge,
        dialogue_history (" . " separators rewritten to " <|endoftext|> . "),
        your_context (history + your knowledge),
        both_context (history + your knowledge + partner knowledge),
        response,
        your_inputs (your_context <|gen|> response),
        both_inputs (both_context <|gen|> response),
        persona_inputs (your_context <|gen|> partner knowledge).
    """
    your_knowledge, partner_knowledge, dialogue_history = [], [], []
    your_context, both_context, response = [], [], []
    your_inputs, both_inputs, persona_inputs = [], [], []

    with open(Path(split_dir), 'r', encoding='UTF-8') as handle:
        for line_no, raw_line in enumerate(handle):
            text = raw_line.partition('\n')[0]
            slot = line_no % 4  # position inside the current 4-line record
            if slot == 0:
                your_knowledge.append(text)
            elif slot == 1:
                partner_knowledge.append(text)
            elif slot == 2:
                dialogue_history.append(" <|endoftext|> . ".join(text.split(" . ")))
                your_context.append(f"{text} {your_knowledge[-1]}")
                both_context.append(f"{text} {your_knowledge[-1]} {partner_knowledge[-1]}")
            else:
                your_inputs.append(f"{your_context[-1]} <|gen|> {text}")
                both_inputs.append(f"{both_context[-1]} <|gen|> {text}")
                persona_inputs.append(f"{your_context[-1]} <|gen|> {partner_knowledge[-1]}")
                response.append(text)

    return (your_knowledge, partner_knowledge, dialogue_history, your_context,
            both_context, response, your_inputs, both_inputs, persona_inputs)

# Parse the three PersonaChat splits. Each call returns the same nine parallel
# lists, bound here under a split-specific prefix.
(train_your_knowledge, train_partner_knowledge, train_dialogue_history,
 train_your_context, train_both_context, train_response,
 train_your_inputs, train_both_inputs,
 train_persona_inputs) = read_dataset_split('data/personachat/train_formatted.txt')
(valid_your_knowledge, valid_partner_knowledge, valid_dialogue_history,
 valid_your_context, valid_both_context, valid_response,
 valid_your_inputs, valid_both_inputs,
 valid_persona_inputs) = read_dataset_split('data/personachat/valid_formatted.txt')
(test_your_knowledge, test_partner_knowledge, test_dialogue_history,
 test_your_context, test_both_context, test_response,
 test_your_inputs, test_both_inputs,
 test_persona_inputs) = read_dataset_split('data/personachat/test_formatted.txt')


# In[ ]:





# In[2]:


from tqdm.auto import tqdm
# The following cells train an answerability classifier on (your knowledge,
# response) pairs: 1 for pairs taken from the same record, 0 for mismatched ones.
from transformers import DistilBertTokenizerFast
classification_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')


# In[3]:


import random


def _build_classification_pairs(knowledge, responses, rng=random):
    """Build labelled (knowledge, response) pairs for the answerability classifier.

    Every aligned ``(knowledge[i], responses[i])`` pair is a positive (label 1).
    For each positive, one negative (label 0) is drawn by pairing the knowledge
    of one positive with the response of another positive whose knowledge
    string differs.

    Negatives are sampled from the positives only. Sampling from a list that
    already contains negatives could pick a negative whose response really
    belongs to the chosen knowledge, yielding a true pair labelled 0.

    Args:
        knowledge: per-example knowledge strings.
        responses: per-example responses, aligned with ``knowledge``.
        rng: object with a ``sample`` method (defaults to the ``random`` module).

    Returns:
        ``(inputs, labels)``: all positives followed by all negatives, plus the
        matching list of 1s and 0s.

    Raises:
        ValueError: if the inputs are misaligned, or if there are positives but
            fewer than two distinct knowledge strings (no negative could ever
            be drawn, so the sampling loop would never end).
    """
    if len(knowledge) != len(responses):
        raise ValueError(
            "knowledge and responses must be aligned, got %d and %d items"
            % (len(knowledge), len(responses)))
    positives = list(zip(knowledge, responses))
    if positives and len(set(knowledge)) < 2:
        raise ValueError("need at least two distinct knowledge strings to draw negatives")

    negatives = []
    for _ in tqdm(range(len(positives))):
        # Resample until the two pairs come from different knowledge strings.
        while True:
            (a_k, _a_r), (b_k, b_r) = rng.sample(positives, 2)
            if a_k != b_k:
                break
        negatives.append((a_k, b_r))

    inputs = positives + negatives
    labels = [1] * len(positives) + [0] * len(negatives)
    return inputs, labels


train_inputs_classification, train_labels_classification = _build_classification_pairs(
    train_your_knowledge, train_response)
test_inputs_classification, test_labels_classification = _build_classification_pairs(
    test_your_knowledge, test_response)

# Empty lists kept so any later cell that references these names still runs.
train_augmentation = []
test_augmentation = []


# In[4]:


# Tokenize each (knowledge, response) tuple as a sentence pair; padding=True
# pads to the longest pair in the list.
train_encodings_classification = classification_tokenizer(train_inputs_classification, truncation=True, padding=True)
test_encodings_classification = classification_tokenizer(test_inputs_classification, truncation=True, padding=True)


# In[ ]:





# In[5]:


import torch
# Use GPU index 1 when the machine has at least two GPUs. Otherwise keep
# torch's default device instead of failing (set_device(1) raises without it).
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
class ADUCDataset(torch.utils.data.Dataset):
    """Map-style dataset serving tokenizer encodings together with labels.

    Args:
        encodings: mapping of field name (e.g. ``input_ids``) to a per-example
            sequence; every field is indexed with ``[idx]``.
        labels: one label per example; its length is the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of tensors for example *idx*, with its label under ``'labels'``."""
        sample = {}
        for field_name, column in self.encodings.items():
            sample[field_name] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Wrap the tokenized pairs and their 1/0 labels as map-style datasets.
train_dataset_classification = ADUCDataset(train_encodings_classification, train_labels_classification)
test_dataset_classification = ADUCDataset(test_encodings_classification, test_labels_classification)


# In[6]:


from torch.utils.data import DataLoader
from transformers import DistilBertForSequenceClassification, AdamW

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Pretrained DistilBERT with a sequence-classification head.
cla_model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
cla_model.to(device)
cla_model.train()

# Training batches are shuffled. The test loader keeps dataset order, which
# cla_classifier relies on to turn batch positions into example indices.
train_dataloader_cla = DataLoader(train_dataset_classification, batch_size=2, shuffle=True)
test_dataloader_cla = DataLoader(test_dataset_classification, batch_size=2)

# default is 5e-5; this run uses a 10x smaller learning rate.
cla_optimizer = AdamW(cla_model.parameters(), lr=5e-6)


from transformers import get_scheduler

# The scheduler is stepped once per training batch (see the loop below), with a
# linear schedule and no warmup.
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader_cla)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=cla_optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)


# In[7]:


# NOTE(review): repeats the device selection and cla_model.to(device) from the
# cell above with the same expression; when run top to bottom it has no extra effect.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cla_model.to(device)


# In[8]:


from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

# Fine-tuning loop. Each batch carries 'labels' (added by ADUCDataset), and the
# loss is read from outputs.loss.
cla_model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader_cla:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = cla_model(**batch)
        loss = outputs.loss
        loss.backward()

        # Update the weights, advance the LR schedule, then clear gradients for the next batch.
        cla_optimizer.step()
        lr_scheduler.step()
        cla_optimizer.zero_grad()
        progress_bar.update(1)


# In[9]:


import numpy as np
from datasets import load_metric

def cla_classifier(i=0, analyze=False, dataloader=None):
    """Evaluate the answerability classifier and print its accuracy.

    Args:
        i: suffix of the analysis file ``data/analyzer{i}.txt``.
        analyze: when True, append to that file (one per line) the running
            index of every example whose label is 1 but was predicted 0.
        dataloader: batches to evaluate; defaults to ``test_dataloader_cla``.
            The written indices are positions in this loader's iteration
            order. They match dataset indices only for an unshuffled loader.

    Prints the dict returned by the ``accuracy`` metric's ``compute()``.
    """
    import contextlib

    if dataloader is None:
        dataloader = test_dataloader_cla
    metric = load_metric("accuracy")
    cla_model.eval()

    # Open the analysis file before evaluating so a bad path fails fast.
    sink = open("data/analyzer" + str(i) + ".txt", 'a') if analyze else contextlib.nullcontext()
    base = 0  # running index of the first example in the current batch
    with sink as f:
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = cla_model(**batch)
            predictions = torch.argmax(outputs.logits, dim=-1)
            labels = batch["labels"]
            metric.add_batch(predictions=predictions, references=labels)
            if analyze:
                # Walk the actual batch length rather than assuming 2 examples,
                # so a short final batch (or another batch size) is handled.
                for offset in range(labels.size(0)):
                    if predictions[offset] == 0 and labels[offset] == 1:
                        f.write(str(base + offset))
                        f.write('\n')
            base += labels.size(0)
    print(metric.compute())


# In[10]:


# Evaluate the freshly trained classifier on the test pairs. With the default
# analyze=False only the accuracy is printed and no analysis file is written.
cla_classifier()


# In[11]:




# load general utilities, then initialize the tokenizer
import os
import time
import datetime

import pandas as pd
import seaborn as sns
import numpy as np
import random

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
#torch.manual_seed(42)

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup

# Disable every `logging` call of level CRITICAL and below, i.e. all standard
# logging output from here on.
logging.disable(logging.CRITICAL)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
# DialoGPT tokenizer configured with the special tokens used to build inputs:
# <|startoftext|>/<|endoftext|> wrap each sample, <|gen|> separates context
# from the text to generate, and <|pad|> fills padded positions. Tokens not
# already in the vocabulary enlarge it, which is why the models' embeddings are
# resized later.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small", sep_token = "<|gen|>" ,bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>') 
#print(tokenizer)


# In[12]:


batch_size = 1  # batch size of the GPT-2 DataLoaders


class GPT2Dataset(Dataset):
    """Tokenize texts into fixed-length tensors for causal-LM fine-tuning.

    Each text is wrapped as ``'<|startoftext|> ' + text + ' <|endoftext|>'`` and
    tokenized with truncation and padding to exactly ``max_length`` tokens.
    Items are ``(input_ids, attention_mask)`` tensor pairs.

    ``gpt2_type`` is accepted for compatibility but not used.
    """

    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=100):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        for text in txt_list:
            wrapped = '<|startoftext|> ' + text + ' <|endoftext|>'
            enc = tokenizer(wrapped, truncation=True, max_length=max_length, padding="max_length")
            self.input_ids.append(torch.tensor(enc['input_ids']))
            self.attn_masks.append(torch.tensor(enc['attention_mask']))

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]

    def __len__(self):
        return len(self.input_ids)

def _split_into_loaders(dataset, train_fraction=0.999):
    """Randomly split *dataset* into train/validation parts and wrap both in DataLoaders.

    Prints both split sizes. The training loader samples randomly; the
    validation loader reads sequentially, since order does not matter for
    evaluation. Both loaders use the module-level ``batch_size``.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader)``.
    """
    n_train = int(train_fraction * len(dataset))
    n_val = len(dataset) - n_train
    train_ds, val_ds = random_split(dataset, [n_train, n_val])

    print('{:>5,} training samples'.format(n_train))
    print('{:>5,} validation samples'.format(n_val))

    train_loader = DataLoader(train_ds, sampler=RandomSampler(train_ds), batch_size=batch_size)
    val_loader = DataLoader(val_ds, sampler=SequentialSampler(val_ds), batch_size=batch_size)
    return train_ds, val_ds, train_loader, val_loader


dataset_persona = GPT2Dataset(train_persona_inputs, tokenizer, max_length=768)
dataset_both = GPT2Dataset(train_both_inputs, tokenizer, max_length=768)
# The "your" dataset uses your context only (no partner knowledge); it must
# not reuse train_both_inputs.
dataset_your = GPT2Dataset(train_your_inputs, tokenizer, max_length=768)

(train_dataset_persona, val_dataset_persona,
 train_dataloader_persona, validation_dataloader_persona) = _split_into_loaders(dataset_persona)

(train_dataset_both, val_dataset_both,
 train_dataloader_both, validation_dataloader_both) = _split_into_loaders(dataset_both)

(train_dataset_your, val_dataset_your,
 train_dataloader_your, validation_dataloader_your) = _split_into_loaders(dataset_your)

# Module-level sizes as left by the last split.
train_size, val_size = len(train_dataset_your), len(val_dataset_your)


# In[13]:


# instantiate the model
# Four independent DialoGPT-small instances, each loaded from the checkpoint.
both_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
persona_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
mtl_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
your_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

# The tokenizer was created with extra special tokens (bos/sep/pad), so each
# model's embedding matrix must be resized to the tokenizer's vocabulary size.
both_model.resize_token_embeddings(len(tokenizer))
persona_model.resize_token_embeddings(len(tokenizer))
mtl_model.resize_token_embeddings(len(tokenizer))
your_model.resize_token_embeddings(len(tokenizer))

# Run on the GPU when CUDA is available, otherwise on the CPU, using the same
# rule as the classifier cell so the script also runs without CUDA.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
mtl_model.to(device)
both_model.to(device)
persona_model.to(device)
your_model.to(device)


# In[14]:


# this block trains the persona model
# Hyperparameters for fine-tuning the persona model.
epochs = 2
learning_rate = 5e-4
warmup_steps = 100  # integer step count, as get_linear_schedule_with_warmup expects
epsilon = 1e-8

# AdamW here is the Hugging Face transformers implementation imported earlier.
persona_optimizer = AdamW(persona_model.parameters(), lr=learning_rate, eps=epsilon)

# The scheduler is stepped once per batch, so it needs batches-per-epoch x epochs
# steps (not the number of samples).
total_steps = epochs * len(train_dataloader_persona)

# Linear warmup over `warmup_steps`, then linear decay to 0 at `total_steps`.
persona_scheduler = get_linear_schedule_with_warmup(
    persona_optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

def format_time(elapsed):
    """Render *elapsed* seconds, rounded to whole seconds, as ``str(timedelta)`` (h:mm:ss, with a day prefix past 24 h)."""
    whole_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=whole_seconds))

total_t0 = time.time()

# Per-epoch statistics of this training run.
training_stats = []

persona_model = persona_model.to(device)

from tqdm.auto import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

# Loss scaler paired with autocast() for mixed-precision training.
scaler = GradScaler()

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    total_train_loss = 0

    persona_model.train()

    progress_bar = tqdm(range(len(train_dataloader_persona)))

    for step, batch in enumerate(train_dataloader_persona):

        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        # Give padding positions (attention mask 0) the label -100 so the LM loss
        # ignores them. Samples are padded to max_length with <|pad|>; using the
        # raw input_ids as labels would also train the model to predict that filler.
        b_labels = b_input_ids.masked_fill(b_masks == 0, -100)

        persona_model.zero_grad()
        with autocast():
            outputs = persona_model(b_input_ids,
                                    labels=b_labels,
                                    attention_mask=b_masks,
                                    token_type_ids=None)
            loss = outputs[0]

        batch_loss = loss.item()
        total_train_loss += batch_loss

        # Backward on the scaled loss; the scaler unscales gradients before stepping.
        scaler.scale(loss).backward()
        scaler.step(persona_optimizer)
        scaler.update()
        persona_scheduler.step()
        progress_bar.update(1)

    progress_bar.close()

    # Average loss over all batches; max(..., 1) guards an empty loader.
    avg_train_loss = total_train_loss / max(len(train_dataloader_persona), 1)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    persona_model.eval()

    total_eval_loss = 0
    nb_eval_steps = 0

    for batch in validation_dataloader_persona:

        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        # Same padding mask as in training, so the two losses are comparable.
        b_labels = b_input_ids.masked_fill(b_masks == 0, -100)

        with torch.no_grad():
            outputs = persona_model(b_input_ids,
                                    attention_mask=b_masks,
                                    labels=b_labels)
            loss = outputs[0]

        batch_loss = loss.item()
        total_eval_loss += batch_loss

    avg_val_loss = total_eval_loss / max(len(validation_dataloader_persona), 1)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))


# In[15]:


# this block trains the forward model
# Hyperparameters for fine-tuning the forward ("both") model.
epochs = 2
learning_rate = 5e-4
warmup_steps = 100  # integer step count, as get_linear_schedule_with_warmup expects
epsilon = 1e-8

# AdamW here is the Hugging Face transformers implementation imported earlier.
both_optimizer = AdamW(both_model.parameters(), lr=learning_rate, eps=epsilon)

# The scheduler is stepped once per batch, so it needs batches-per-epoch x epochs
# steps (not the number of samples).
total_steps = epochs * len(train_dataloader_both)

# Linear warmup over `warmup_steps`, then linear decay to 0 at `total_steps`.
both_scheduler = get_linear_schedule_with_warmup(
    both_optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

def format_time(elapsed):
    """Render *elapsed* seconds, rounded to whole seconds, as ``str(timedelta)`` (h:mm:ss, with a day prefix past 24 h)."""
    whole_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=whole_seconds))

total_t0 = time.time()

# Per-epoch statistics of this training run (previous runs' stats are discarded).
training_stats = []

both_model = both_model.to(device)

from tqdm.auto import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

# Loss scaler paired with autocast() for mixed-precision training.
scaler = GradScaler()

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    total_train_loss = 0

    both_model.train()

    progress_bar = tqdm(range(len(train_dataloader_both)))

    for step, batch in enumerate(train_dataloader_both):

        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        # Give padding positions (attention mask 0) the label -100 so the LM loss
        # ignores them. Samples are padded to max_length with <|pad|>; using the
        # raw input_ids as labels would also train the model to predict that filler.
        b_labels = b_input_ids.masked_fill(b_masks == 0, -100)

        both_model.zero_grad()
        with autocast():
            outputs = both_model(b_input_ids,
                                 labels=b_labels,
                                 attention_mask=b_masks,
                                 token_type_ids=None)
            loss = outputs[0]

        batch_loss = loss.item()
        total_train_loss += batch_loss

        # Backward on the scaled loss; the scaler unscales gradients before stepping.
        scaler.scale(loss).backward()
        scaler.step(both_optimizer)
        scaler.update()
        both_scheduler.step()
        progress_bar.update(1)

    progress_bar.close()

    # Average loss over all batches; max(..., 1) guards an empty loader.
    avg_train_loss = total_train_loss / max(len(train_dataloader_both), 1)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    both_model.eval()

    total_eval_loss = 0
    nb_eval_steps = 0

    for batch in validation_dataloader_both:

        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        # Same padding mask as in training, so the two losses are comparable.
        b_labels = b_input_ids.masked_fill(b_masks == 0, -100)

        with torch.no_grad():
            outputs = both_model(b_input_ids,
                                 attention_mask=b_masks,
                                 labels=b_labels)
            loss = outputs[0]

        batch_loss = loss.item()
        total_eval_loss += batch_loss

    avg_val_loss = total_eval_loss / max(len(validation_dataloader_both), 1)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))


# In[16]:


import re
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.nist_score import sentence_nist
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer

# Smoke test of the metric libraries: NIST up to bigrams on an identical
# candidate/reference pair.
reference = [['this', 'is', 'a', 'test']]
candidate = ['this', 'is', 'a', 'test']
score = sentence_nist(reference, candidate, n=2)
print(score)

# ...and the ROUGE-L F-measure (with stemming) on a sample sentence pair.
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
scores = scorer.score('The quick brown fox jumps over the lazy dog',
                      'The quick brown dog jumps on the log.')
print(scores["rougeL"].fmeasure)

def process_s(s):
    """Normalise text for scoring and end it with a single end-of-text token.

    Each of ``. , ! ? ( )`` is surrounded by spaces, runs of two or more
    whitespace characters are collapsed to one space, every ``<|endoftext|>``
    is removed, and `` <|endoftext|>`` is appended once.
    """
    # Raw strings: '\s' in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    s = re.sub(r'([.,!?()])', r' \1 ', s)
    s = re.sub(r'\s{2,}', ' ', s).replace('<|endoftext|>', '') + ' <|endoftext|>'
    return s

def process_GPT_output(s):
    """Extract the generated continuation from a decoded model output.

    Surrounds ``. , ! ? ( )`` with spaces and collapses whitespace runs, keeps
    only the text after the first ``<|gen|>`` marker, and removes every
    ``<|endoftext|>`` token.

    Raises:
        ValueError: if *s* contains no ``<|gen|>`` marker.
    """
    marker = "<|gen|>"
    # Raw strings avoid the invalid '\s' escape in a normal string literal.
    s = re.sub(r'([.,!?()])', r' \1 ', s)
    s = re.sub(r'\s{2,}', ' ', s)
    s = s[s.index(marker) + len(marker):]
    return s.replace('<|endoftext|>', '')


# In[17]:


from itertools import chain


def pad_sequence(sequence, n, pad_left=False, pad_right=False,
                 left_pad_symbol=None, right_pad_symbol=None):
    """
    Return *sequence* as an iterator, optionally padded for n-gram extraction.

    With ``pad_left``, n - 1 copies of ``left_pad_symbol`` are prepended; with
    ``pad_right``, n - 1 copies of ``right_pad_symbol`` are appended.

        >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
        ['<s>', 1, 2, 3, 4, 5, '</s>']
        >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
        ['<s>', 1, 2, 3, 4, 5]
        >>> list(pad_sequence([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
        [1, 2, 3, 4, 5, '</s>']

    :param sequence: the source data to be padded
    :type sequence: sequence or iter
    :param n: the degree of the ngrams
    :type n: int
    :param pad_left: whether to left-pad
    :type pad_left: bool
    :param pad_right: whether to right-pad
    :type pad_right: bool
    :param left_pad_symbol: symbol used for left padding (default None)
    :type left_pad_symbol: any
    :param right_pad_symbol: symbol used for right padding (default None)
    :type right_pad_symbol: any
    :rtype: iter
    """
    items = iter(sequence)
    if not (pad_left or pad_right):
        return items
    head = (left_pad_symbol,) * (n - 1) if pad_left else ()
    tail = (right_pad_symbol,) * (n - 1) if pad_right else ()
    return chain(head, items, tail)


def ngrams(sequence, n, pad_left=False, pad_right=False,
           left_pad_symbol=None, right_pad_symbol=None):
    """
    Return the ngrams generated from a sequence of items, as an iterator.
    A (padded) sequence with fewer than n items yields no ngrams.
    For example:
        >>> from nltk.util import ngrams
        >>> list(ngrams([1,2,3,4,5], 3))
        [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
    Wrap with list for a list version of this function.  Set pad_left
    or pad_right to true in order to get additional ngrams:
        >>> list(ngrams([1,2,3,4,5], 2, pad_right=True))
        [(1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
        >>> list(ngrams([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
        [(1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
        >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
        [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5)]
        >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
        [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
        >>> list(ngrams([1], 3))
        []
    :param sequence: the source data to be converted into ngrams
    :type sequence: sequence or iter
    :param n: the degree of the ngrams
    :type n: int
    :param pad_left: whether the ngrams should be left-padded
    :type pad_left: bool
    :param pad_right: whether the ngrams should be right-padded
    :type pad_right: bool
    :param left_pad_symbol: the symbol to use for left padding (default is None)
    :type left_pad_symbol: any
    :param right_pad_symbol: the symbol to use for right padding (default is None)
    :type right_pad_symbol: any
    :rtype: sequence or iter
    """
    sequence = pad_sequence(sequence, n, pad_left, pad_right,
                            left_pad_symbol, right_pad_symbol)

    history = []
    while n > 1:
        # A bare next() here would let StopIteration escape this generator,
        # which PEP 479 turns into RuntimeError for inputs shorter than n - 1.
        try:
            next_item = next(sequence)
        except StopIteration:
            return
        history.append(next_item)
        n -= 1
    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]
        
def distinct_n_sentence_level(sentence, n):
    """
    Compute a distinct-N score for one generated output.

    :param sentence: str made of one or more segments joined by "<sep>". Each
        segment is tokenised with ``split(" ")``, so leading, trailing or
        repeated spaces produce empty-string tokens that are counted too.
    :param n: int, n-gram order.
    :return: float, the number of distinct n-grams (union over all segments)
        divided by the total token count; 0.0 for an empty string.

    NOTE(review): n-grams are extracted with ``pad_right=True``, so for n >= 2
    the None-padded tail n-grams are counted as distinct n-grams as well. That
    differs from the usual distinct-N definition; confirm it is intended.
    """
    unique_ngrams = set()
    token_total = 0
    for segment in sentence.split("<sep>"):
        tokens = segment.split(" ")
        token_total += len(tokens)
        unique_ngrams.update(ngrams(tokens, n, pad_right=True))

    if not sentence:
        return 0.0  # empty input scores 0 instead of counting its lone empty token

    return len(unique_ngrams) / token_total


# In[18]:


import warnings
import sys
# Silence every Python `warnings` warning for the rest of the run.
warnings.filterwarnings("ignore")

def test_pipeline():
    """Evaluate the two-stage persona -> response pipeline on the test split.

    For every test example:
      1. ``persona_model`` greedily generates a partner persona from
         ``test_your_context[i]``;
      2. ``both_model`` greedily generates a reply from that context plus the
         generated persona;
      3. the reply is scored against ``test_response[i]``.
    An example is skipped (contributes to no metric) when either decoded
    generation lacks the " <|gen|> " marker.

    Prints the mean sentence BLEU (default weights and individual 1- to
    4-gram weights), NIST (n=1), ROUGE-L F-measure and METEOR, then
    distinct-1 and distinct-2 over all generated replies. Returns None.

    Reads module-level globals (the two models, ``tokenizer``, ``device``,
    ``scorer``, the test split and the post-processing helpers) and leaves
    both models in eval mode.
    """
    both_model.eval()
    persona_model.eval()

    def greedy_text(model, prompt):
        # Encode (truncated to 768 tokens), run generate() and decode the whole
        # sequence with special tokens kept so the marker check below works.
        ids = torch.tensor(tokenizer.encode(prompt, truncation=True, max_length=768))
        ids = ids.unsqueeze(0).to(device)
        out = model.generate(ids,
                             do_sample=False,
                             top_k=30,
                             max_length=768,
                             min_length=2,
                             top_p=0.95,
                             num_return_sequences=1)
        return tokenizer.decode(out[0], skip_special_tokens=False)

    def clean(text):
        # The file's two post-processing helpers, then drop leading spaces.
        return process_GPT_output(process_s(text)).lstrip()

    one_hot_weights = ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1))
    bleu_scores = []
    bleu_n_scores = ([], [], [], [])  # individual 1-, 2-, 3-, 4-gram BLEU
    nist_scores = []
    rouge_scores = []
    meteor_scores = []
    replies = []

    for idx in tqdm(range(len(test_your_context))):
        context = test_your_context[idx]
        gold = test_response[idx]

        # Stage 1: the partner's personality.
        persona_text = greedy_text(persona_model, "<|startoftext|> " + context + " <|gen|>")
        if " <|gen|> " not in persona_text:
            continue
        persona = clean(persona_text)

        # Stage 2: the reply, conditioned on context + generated persona.
        reply_text = greedy_text(both_model,
                                 "<|startoftext|> " + context + " " + persona + " <|gen|>")
        if " <|gen|> " not in reply_text:
            continue
        reply = clean(reply_text)
        replies.append(reply)

        reference = [gold.split(" ")]
        candidate = reply.split(" ")
        bleu_scores.append(sentence_bleu(reference, candidate))
        for bucket, weights in zip(bleu_n_scores, one_hot_weights):
            bucket.append(sentence_bleu(reference, candidate, weights=weights))
        nist_scores.append(sentence_nist(reference, candidate, n=1))
        rouge_scores.append(scorer.score(gold, reply)["rougeL"].fmeasure)
        meteor_scores.append(meteor_score(reference, candidate))

    bleu1_scores, bleu2_scores, bleu3_scores, bleu4_scores = bleu_n_scores
    report = (
        ("Pipeline Test: Running avg for BLEU pos is {}", bleu_scores),
        ("Pipeline Test: Running avg for BLEU1 pos is {}", bleu1_scores),
        ("Pipeline Test: Running avg for BLEU2 pos is {}", bleu2_scores),
        ("Pipeline Test: Running avg for BLEU3 pos is {}", bleu3_scores),
        ("Pipeline Test: Running avg for BLEU4 pos is {}", bleu4_scores),
        ("Pipeline Test: Running avg for NIST is {}", nist_scores),
        ("Pipeline Test: AVG for ROUGE is {}", rouge_scores),
        ("Pipeline Test: AVG for Meteor is {}", meteor_scores),
    )
    for template, values in report:
        print(template.format(np.mean(values)))

    # Same layout as appending " <sep> " + reply per example.
    distinct_output = "".join(" <sep> " + reply for reply in replies)
    for template, order in (("Pipeline Test: AVG for distinct 1 is {}", 1),
                            ("Pipeline Test: AVG for distinct 2 is {}", 2)):
        print(template.format(distinct_n_sentence_level(distinct_output, order)))
                


# In[19]:





# In[20]:


# Optimizer and learning-rate schedule for the persona-generation model.
learning_rate = 5e-6
epsilon = 1e-8
warmup_steps = 1e1  # 10 warm-up scheduler steps (kept as a float, as before)
# Fixed schedule length in scheduler steps; it is NOT derived from the number
# of batches or epochs of the training loop.
total_steps = 10000

# AdamW is whichever class the file imported under that name (the original
# note says it is the Hugging Face implementation, not torch.optim's).
persona_optimizer = AdamW(persona_model.parameters(), lr=learning_rate, eps=epsilon)

# Linear warm-up followed by linear decay over `total_steps` scheduler steps.
persona_scheduler = get_linear_schedule_with_warmup(
    persona_optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)


# In[21]:


# Optimizer and learning-rate schedule for the response model (both_model),
# using the same hyper-parameters as the persona model.
learning_rate = 5e-6
epsilon = 1e-8
warmup_steps = 1e1  # 10 warm-up scheduler steps (kept as a float, as before)
# Fixed schedule length in scheduler steps; it is NOT derived from the number
# of batches or epochs of the training loop.
total_steps = 10000

# AdamW is whichever class the file imported under that name (the original
# note says it is the Hugging Face implementation, not torch.optim's).
both_optimizer = AdamW(both_model.parameters(), lr=learning_rate, eps=epsilon)

# Linear warm-up followed by linear decay over `total_steps` scheduler steps.
both_scheduler = get_linear_schedule_with_warmup(
    both_optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)


# In[22]:





# In[23]:





# In[24]:


# Evaluate (and print metrics for) the pipeline before the training loop in a
# later cell updates the models.
test_pipeline()


# In[ ]:





# In[ ]:


def _generate_text(model, prompt):
    """Greedily continue ``prompt`` with ``model`` and return the decoded text.

    The prompt is encoded with ``tokenizer`` (truncated to 768 tokens). The
    full output sequence is decoded with special tokens kept, so callers can
    check for the " <|gen|> " marker.
    """
    input_ids = torch.tensor(
        tokenizer.encode(prompt, truncation=True, max_length=768)
    ).unsqueeze(0).to(device)
    output_ids = model.generate(input_ids,
                                do_sample=False,
                                top_k=30,
                                max_length=768,
                                min_length=2,
                                top_p=0.95,
                                num_return_sequences=1)
    return tokenizer.decode(output_ids[0], skip_special_tokens=False)


def _masked_lm_loss(model, prompt, target):
    """Return ``model``'s LM loss on ``target`` conditioned on ``prompt``.

    ``prompt + " " + target`` is encoded (truncated to 768 tokens). Every label
    position covered by the prompt is set to -100, so only target tokens
    contribute to the loss. The prompt length is measured in TOKENS by
    encoding the prompt on its own. This assumes the prompt tokenizes
    identically as a prefix of the full string, which holds when the prompt
    ends with the <|gen|> marker and the target follows a space. TODO confirm
    for the tokenizer in use.

    Returns None when truncation leaves no target tokens to train on.
    """
    input_ids = torch.tensor(
        tokenizer.encode(prompt + " " + target, truncation=True, max_length=768)
    ).unsqueeze(0).to(device)
    prompt_len = len(tokenizer.encode(prompt, truncation=True, max_length=768))
    if prompt_len >= input_ids.size(1):
        return None
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100
    return model(input_ids, labels=labels)[0]


rouge_scores = []   # not used below in this cell
meteor_scores = []  # not used below in this cell
cla_model.eval()
test_len = len(train_your_context)  # number of *training* examples iterated
btsz = 20          # examples accumulated per optimizer step
counter = 0        # examples accumulated since the last optimizer step
counter_2 = 0      # examples processed since the last evaluation
eval_btsz = 1000   # run test_pipeline() every `eval_btsz` processed examples
softmax = torch.nn.Softmax(dim=0)
persona_model.train()
both_model.train()
# Start from clean gradients; they are zeroed again after every optimizer step
# so updates never re-apply gradients from earlier batches.
persona_optimizer.zero_grad()
both_optimizer.zero_grad()

for i in tqdm(range(test_len)):
    single_response = train_response[i]

    # 1) Produce the partner's personality first.
    # NOTE(review): both models are in train() mode here, so dropout is active
    # during generate(); confirm that is intended.
    persona_prompt = "<|startoftext|> " + train_your_context[i] + " <|gen|>"
    persona_full = _generate_text(persona_model, persona_prompt)
    if " <|gen|> " not in persona_full:
        continue  # same malformed-output filter as test_pipeline()
    partner_personality = process_GPT_output(process_s(persona_full)).lstrip()

    # 2) Generate a reply conditioned on the context plus that personality.
    single_context = ("<|startoftext|> " + train_your_context[i] + " "
                      + partner_personality + " <|gen|>")
    final_output_full = _generate_text(both_model, single_context)
    if " <|gen|> " not in final_output_full:
        continue
    final_output = process_GPT_output(process_s(final_output_full)).lstrip()

    # 3) Reward: -1 when the classifier's top class for
    #    "<persona> [SEP] <reply>" is class 0, otherwise +1.
    cla_text = partner_personality + " [SEP] " + final_output
    if len(cla_text.split(" ")) > 500:
        continue  # rough whitespace-word guard on classifier input length
    with torch.no_grad():  # the reward is a constant; no autograd graph needed
        cla_generated = torch.tensor(
            classification_tokenizer.encode(cla_text)
        ).unsqueeze(0).to(device)
        smx = softmax(cla_model(cla_generated).logits[0])
    reward = -1 if torch.argmax(smx) == 0 else 1

    # 4) Reward-weighted LM losses: the persona model on its own generated
    #    personality, the response model on the gold response.
    persona_loss = _masked_lm_loss(persona_model, persona_prompt, partner_personality)
    response_loss = _masked_lm_loss(both_model, single_context, single_response)
    if persona_loss is None or response_loss is None:
        continue  # prompt filled the 768-token window; no target to train on

    # Mean token NLL of the gold response (a loss value, not exponentiated).
    # NOTE(review): `ppl_scores` is not created in this cell; it must already
    # exist from an earlier one.
    ppl_scores.append(response_loss.detach().cpu().numpy())

    # Gradient accumulation: back-propagate every example, dividing by btsz so
    # a step applies the batch mean, and step once per `btsz` examples.
    loss = (persona_loss + response_loss) * reward / btsz
    scaler.scale(loss).backward()
    counter = counter + 1
    if counter == btsz:
        scaler.step(persona_optimizer)
        scaler.step(both_optimizer)
        scaler.update()
        persona_scheduler.step()
        both_scheduler.step()
        persona_optimizer.zero_grad()
        both_optimizer.zero_grad()
        counter = 0

    if counter_2 == eval_btsz:
        test_pipeline()
        counter_2 = 0
        # test_pipeline() leaves the models in eval mode.
        persona_model.train()
        both_model.train()

    counter_2 = counter_2 + 1


# In[ ]:


# In[ ]:




