import torch
import re
import numpy as np
import random
import json
import os
from transformers import RobertaTokenizer, RobertaModel
from transformers import BertTokenizer, BertModel
from torch.nn import CrossEntropyLoss
from transformers import GPT2Tokenizer, GPT2Model
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import make_pipeline
from sklearn.metrics import accuracy_score
import logging

# Prefix that the driver script at the bottom of this file prepends, by plain
# string concatenation, to the dataset paths, SAVE_DIR and LOG_DIR.
ROOT_PATH = ''
# JSON-lines file used as the training set (becomes ExpConfig.training_set_path).
WINOWHY_PATH = 'datasets/winowhy.jsonl'
# JSON-lines file used as the test set (becomes ExpConfig.test_set_path).
WINOLOGIC_PATH = 'datasets/winologic_variable.jsonl'
# Directory assigned to config.save_dir by the driver script.
SAVE_DIR = 'models'
# Assigned to ExpConfig.model_path by the driver script; while it is empty,
# weights are loaded by ExpConfig.model_name instead.
CACHE_DIR = ''
# Directory in which the driver script creates its 'test.log' file handler.
LOG_DIR = 'logs'

# helper function: read and dump data
def dump_jsonl(data, output_path, append=False):
    """
    Write an iterable of JSON-serialisable objects to a JSON lines file.

    data: any iterable (list, tuple, generator, ...); each item becomes one line.
    output_path: path of the file to write.
    append: when True the records are appended, otherwise the file is overwritten.
    """
    mode = 'a+' if append else 'w'
    # Count while writing instead of calling len(data) afterwards, so that
    # generators and other unsized iterables are accepted as well.
    n_written = 0
    with open(output_path, mode, encoding='utf-8') as f:
        for n_written, record in enumerate(data, start=1):
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
    logging.critical('Wrote {} records to {}'.format(n_written, output_path))

def load_jsonl(input_path) -> list:
    """
    Read a list of objects from a JSON lines file.

    input_path: path of a UTF-8 encoded file holding one JSON document per line.
    Lines that contain only whitespace (for example a trailing empty line) are
    skipped instead of being handed to the JSON parser.
    Returns the decoded objects in file order.
    """
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            # strip() removes the line terminator ('\n' or '\r\n') and padding.
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    logging.critical('Loaded {} records from {}'.format(len(data), input_path))
    return data

from dataclasses import dataclass

@dataclass
class WinoWhySentence:
    """
    One WinoWhy record.

    load_winowhy_from_path fills every field from the JSON key of the same name.
    """
    sentence: str = None
    context: str = None
    # Used as the first model input when the training examples are built.
    wsc_sentence: str = None
    # Used as the second model input when the training examples are built.
    answer_reason: str = None
    reason: str = None
    # Classification target handed to the loss function.
    label: int = 0
    wsc_id: int = 0
    fold_num: int = 1
        
        
@dataclass
class WinoLogicSentence:
    """
    One WinoLogic record.

    load_winologic_from_path fills the fields from the JSON keys of the same
    name; fold_num is stored as the JSON value minus one.
    """
    sentence: str = None
    context: str = None
    wsc_marked_sentence: str = None
    # Used as the first model input when the test examples are built.
    wsc_sentence: str = None
    # Used as the second model input when the test examples are built.
    answer_knowledge: str = None
    knowledge: str = None
    # Gold label compared against the model prediction.
    label: int = 0
    wsc_id: int = 0
    fold_num: int = 1
    valid: int = 1
    
def load_winowhy_from_path(filepath: str):
    """
    Load a WinoWhy JSON lines file as a list of WinoWhySentence objects.

    Every record must provide the keys read below; a missing key raises KeyError.
    """
    return [
        WinoWhySentence(
            sentence=record['sentence'],
            context=record['context'],
            wsc_sentence=record['wsc_sentence'],
            answer_reason=record['answer_reason'],
            reason=record['reason'],
            label=record['label'],
            wsc_id=record['wsc_id'],
            fold_num=record['fold_num'],
        )
        for record in load_jsonl(filepath)
    ]
        
def load_winologic_from_path(filepath: str):
    """
    Load a WinoLogic JSON lines file as a list of WinoLogicSentence objects.

    Every record must provide the keys read below; a missing key raises KeyError.
    """
    return [
        WinoLogicSentence(
            sentence=record['sentence'],
            context=record['context'],
            wsc_marked_sentence=record['wsc_marked_sentence'],
            wsc_sentence=record['wsc_sentence'],
            answer_knowledge=record['answer_knowledge'],
            knowledge=record['knowledge'],
            label=record['label'],
            wsc_id=record['wsc_id'],
            # Stored one lower than the value found in the file.
            fold_num=record['fold_num'] - 1,
            valid=record['valid'],
        )
        for record in load_jsonl(filepath)
    ]

@dataclass
class ExpConfig(object):
    """
    Settings for one train-then-test experiment, plus helpers for seeding the
    random number generators and choosing the torch device.
    """
    # WinoWhy dataset path
    training_set_path: str = ""
    # WinoLogic dataset path
    test_set_path: str = ""
    # Task description
    task_name: str = ""
    # Only using single GPU
    gpu_id: int = 0
    # Seed for random
    seed: int = 42
    # 'cpu', 'cuda'
    device: str = 'cpu'
    # "gpt2", "gpt2-large", "bert-base-uncased", "bert-large-uncased", "roberta-base", "roberta-large"
    model_name: str = ""
    # If model_path is not None or not empty, load model from model_path instead of transformers' pretrained ones
    model_path: str = ""
    # For training the classifier layer
    learning_rate: float = 1e-3
    # Number of total epochs
    num_training_epochs: int = 15
    # Max sequence length
    max_seq_len: int = 128
    # Directory for saved models. Declared here because the driver script
    # assigns config.save_dir; kept last so positional construction is unchanged.
    save_dir: str = ""

    def set_seed(self, new_seed = None):
        """
        Seed Python's random module, NumPy and torch (including every CUDA
        device when CUDA is available).

        new_seed: seed to use; when None, self.seed is used. self.seed itself
        is not modified.
        """
        seed = self.seed if new_seed is None else new_seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def set_gpu_if_possible(self, gpu_id = None):
        """
        Set self.device to a CUDA device when CUDA is available, else to 'cpu'.

        gpu_id: optional device index; when given, the device becomes
        'cuda:<gpu_id>', otherwise plain 'cuda'. self.gpu_id is not consulted.
        """
        if torch.cuda.is_available():
            self.device = 'cuda'
            if gpu_id is not None:
                self.device = 'cuda:{}'.format(gpu_id)
        else:
            self.device = 'cpu'

class GPT2ForWinoLogic(GPT2Model):
    """
    GPT-2 encoder followed by a two-layer linear head that classifies a pair of
    sentences into two classes.
    """

    def __init__(self, config):
        # super(GPT2Model, self) resolves to the class after GPT2Model in the
        # MRO, so GPT2Model.__init__ itself is skipped; the encoder is created
        # explicitly below instead.
        super(GPT2Model, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.hidden_dim = 200
        # NOTE(review): this dropout module is created but never applied in forward().
        self.dropout = torch.nn.Dropout(0.5)
        # Input width is doubled because two pooled sentence vectors are concatenated.
        self.second_last_layer = torch.nn.Linear(config.n_embd*2, self.hidden_dim)
        self.last_layer = torch.nn.Linear(self.hidden_dim, 2)

    def forward(self, first, second):
        """
        first: input ids for the WSC sentence.
        second: input ids for the Answer Knowledge (Answer Reason for WinoWhy)

        Both are expected as (batch, sequence) id tensors; the callers in this
        file pass a batch of one. No attention mask is supplied, so every
        position of the input contributes to the mean.
        Returns the class scores, one row of two values per batch element.
        """
        # Get the hidden states of the last layer
        sent1_hidden_states = self.transformer(first)[0]
        sent2_hidden_states = self.transformer(second)[0]

        # mean pooling over the sequence axis, then concatenation.
        # Averaging over dim=1 (rather than squeeze() followed by dim=0) keeps
        # the batch axis intact, so single-token inputs no longer collapse to
        # a 1-D tensor that torch.cat(..., dim=1) rejects.
        overall_representation = torch.cat(
            [sent1_hidden_states.mean(dim=1),
             sent2_hidden_states.mean(dim=1)], dim=1)

        # pass two linear layers (no non-linearity in between)
        prediction = self.last_layer(self.second_last_layer(overall_representation))
        return prediction
    
class BertForWinoLogic(BertModel):
    """
    BERT encoder followed by a two-layer linear head that classifies a pair of
    sentences into two classes.
    """

    def __init__(self, config):
        # super(BertModel, self) resolves to the class after BertModel in the
        # MRO, so BertModel.__init__ itself is skipped; the encoder is created
        # explicitly below instead.
        super(BertModel, self).__init__(config)
        self.bert = BertModel(config)
        self.hidden_dim = 200
        # NOTE(review): this dropout module is created but never applied in forward().
        self.dropout = torch.nn.Dropout(0.5)
        # Input width is doubled because two pooled sentence vectors are concatenated.
        self.second_last_layer = torch.nn.Linear(config.hidden_size*2, self.hidden_dim)
        self.last_layer = torch.nn.Linear(self.hidden_dim, 2)

    def forward(self, first, second):
        """
        first: input ids for the WSC sentence.
        second: input ids for the Answer Knowledge (Answer Reason for WinoWhy)

        Both are expected as (batch, sequence) id tensors; the callers in this
        file pass a batch of one. No attention mask is supplied, so every
        position of the input contributes to the mean.
        Returns the class scores, one row of two values per batch element.
        """
        # Get the hidden states of the last layer
        sent1_hidden_states = self.bert(first)[0]
        sent2_hidden_states = self.bert(second)[0]

        # mean pooling over the sequence axis, then concatenation.
        # Averaging over dim=1 (rather than squeeze() followed by dim=0) keeps
        # the batch axis intact, so a length-one sequence no longer collapses
        # to a 1-D tensor that torch.cat(..., dim=1) rejects.
        overall_representation = torch.cat(
            [sent1_hidden_states.mean(dim=1),
             sent2_hidden_states.mean(dim=1)], dim=1)

        # pass two linear layers (no non-linearity in between)
        prediction = self.last_layer(self.second_last_layer(overall_representation))
        return prediction
    
class RobertaForWinoLogic(RobertaModel):
    """
    RoBERTa encoder followed by a two-layer linear head that classifies a pair
    of sentences into two classes.
    """

    def __init__(self, config):
        # super(RobertaModel, self) resolves to the class after RobertaModel in
        # the MRO, so RobertaModel.__init__ itself is skipped; the encoder is
        # created explicitly below instead.
        super(RobertaModel, self).__init__(config)
        self.roberta = RobertaModel(config)
        self.hidden_dim = 200
        # NOTE(review): this dropout module is created but never applied in forward().
        self.dropout = torch.nn.Dropout(0.5)
        # Input width is doubled because two pooled sentence vectors are concatenated.
        self.second_last_layer = torch.nn.Linear(config.hidden_size*2, self.hidden_dim)
        self.last_layer = torch.nn.Linear(self.hidden_dim, 2)

    def forward(self, first, second):
        """
        first: input ids for the WSC sentence.
        second: input ids for the Answer Knowledge (Answer Reason for WinoWhy)

        Both are expected as (batch, sequence) id tensors; the callers in this
        file pass a batch of one. No attention mask is supplied, so every
        position of the input contributes to the mean.
        Returns the class scores, one row of two values per batch element.
        """
        # Get the hidden states of the last layer
        sent1_hidden_states = self.roberta(first)[0]
        sent2_hidden_states = self.roberta(second)[0]

        # mean pooling over the sequence axis, then concatenation.
        # Averaging over dim=1 (rather than squeeze() followed by dim=0) keeps
        # the batch axis intact, so a length-one sequence no longer collapses
        # to a 1-D tensor that torch.cat(..., dim=1) rejects.
        overall_representation = torch.cat(
            [sent1_hidden_states.mean(dim=1),
             sent2_hidden_states.mean(dim=1)], dim=1)

        # pass two linear layers (no non-linearity in between)
        prediction = self.last_layer(self.second_last_layer(overall_representation))
        return prediction

def train(model, tokenizer, training_set, config, loss_fct, optimizer):
    """
    Run one epoch over training_set, taking an optimizer step after every example.

    training_set is a list of dicts with the keys 'first', 'second' and 'label';
    it is shuffled in place. config supplies model_name (selects how the text is
    encoded) and device. Returns the mean per-example loss of the epoch.
    Raises ValueError when config.model_name starts with none of 'gpt2', 'bert'
    or 'roberta'.
    """
    def to_input_ids(text):
        # GPT-2 and BERT share one encode call; RoBERTa asks for a prefix space instead.
        if config.model_name.startswith(('gpt2', 'bert')):
            token_ids = tokenizer.encode(text, add_special_tokens=True)
        elif config.model_name.startswith('roberta'):
            token_ids = tokenizer.encode(text, add_prefix_space=True)
        else:
            raise ValueError('Model name is not recognized: {}'.format(config.model_name))
        # Add a leading batch axis of size one.
        return torch.tensor(token_ids).unsqueeze(0)

    model.train()
    random.shuffle(training_set)

    loss_total = 0
    for step, item in enumerate(training_set):
        if step % 500 == 0:
            logging.critical('Processing {}'.format(step))

        wsc_text, reason_text, gold = item['first'], item['second'], item['label']

        wsc_ids = to_input_ids(wsc_text)
        reason_ids = to_input_ids(reason_text)
        wsc_ids = wsc_ids.to(config.device)
        reason_ids = reason_ids.to(config.device)

        # forward pass and loss against the single gold label
        scores = model(first=wsc_ids, second=reason_ids)
        target = torch.tensor([gold]).to(config.device)
        loss = loss_fct(scores, target)

        # one parameter update per example
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_total += loss.item()

    return loss_total / len(training_set)

def test(model, tokenizer, test_set, config):
    """
    Test the model on the test_set, and return the accuracies
    """
    # eval mode
    model.eval()
    
    n_correct = 0
    for (i, example) in enumerate(test_set):
        # prepare inputs
        first = example['first'] # WSC sentence
        second = example['second'] # Answer Reason
        label = example['label']
        
        if config.model_name.startswith('gpt2'):
            first_input_ids = torch.tensor(tokenizer.encode(first, add_special_tokens=True)).unsqueeze(0)
            second_input_ids = torch.tensor(tokenizer.encode(second, add_special_tokens=True)).unsqueeze(0)
        elif config.model_name.startswith('bert'):
            first_input_ids = torch.tensor(tokenizer.encode(first, add_special_tokens=True)).unsqueeze(0)
            second_input_ids = torch.tensor(tokenizer.encode(second, add_special_tokens=True)).unsqueeze(0)
        elif config.model_name.startswith('roberta'):
            first_input_ids = torch.tensor(tokenizer.encode(first, add_prefix_space=True)).unsqueeze(0)
            second_input_ids = torch.tensor(tokenizer.encode(second, add_prefix_space=True)).unsqueeze(0)
        else:
            raise ValueError('Model name is not recognized: {}'.format(config.model_name))

        first_input_ids = first_input_ids.to(config.device)
        second_input_ids = second_input_ids.to(config.device)
        
        # get prediction
        with torch.no_grad():
            output = model(first=first_input_ids, second=second_input_ids)
        pred_label = output.argmax()
        if pred_label == label:
            n_correct += 1
        
    acc = n_correct / len(test_set)
    return acc

def train_and_test(config: ExpConfig):
    """
    Train a classifier on the WinoWhy file and evaluate it on the WinoLogic file.

    The model family is chosen from the prefix of config.model_name; weights and
    tokenizer are loaded from config.model_path when it is set, otherwise from
    config.model_name.
    Returns (acc1, acc2): accuracy on the WinoWhy examples that were used for
    training, and accuracy on all WinoLogic examples.
    Raises ValueError for an unrecognised model name.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    logging.critical('Experiment: {} using {} with device {}'.format(config.task_name, config.model_name, config.device))

    # Training examples: WSC sentence paired with the answer reason.
    training_set = [
        {'first': ws.wsc_sentence, 'second': ws.answer_reason, 'label': ws.label}
        for ws in load_winowhy_from_path(config.training_set_path)
    ]

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # A non-empty model_path overrides the pretrained model name as the load source.
    source = config.model_name
    if config.model_path is not None and config.model_path != "":
        source = config.model_path

    # The family is always decided by model_name, even when loading from model_path.
    families = (
        ('gpt2', GPT2ForWinoLogic, GPT2Tokenizer),
        ('bert', BertForWinoLogic, BertTokenizer),
        ('roberta', RobertaForWinoLogic, RobertaTokenizer),
    )
    for prefix, model_cls, tokenizer_cls in families:
        if config.model_name.startswith(prefix):
            model = model_cls.from_pretrained(source)
            tokenizer = tokenizer_cls.from_pretrained(source)
            break
    else:
        raise ValueError('Model name is not recognized: {}'.format(config.model_name))

    model.to(config.device)
    model.train()

    loss_fct = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)

    # Training loop: one call to train() per epoch.
    logging.critical('Training on WinoWhy')
    epoch_losses = []
    logging.critical('train size: {}'.format(len(training_set)))
    for epoch in range(config.num_training_epochs):
        logging.critical('Epoch={}'.format(epoch))
        epoch_loss = train(model, tokenizer, training_set, config, loss_fct, optimizer)
        logging.critical('Average loss = {}'.format(epoch_loss))
        epoch_losses.append(epoch_loss)
    logging.critical('All losses: {}'.format(epoch_losses))

    # Accuracy on WinoWhy, measured on the same examples used for training.
    logging.critical('Evaluating on WinoWhy')
    acc1 = test(model, tokenizer, training_set, config)
    logging.critical('Accuracy of {} on WinoWhy: {}'.format(config.model_name, acc1))

    # Test examples: WSC sentence paired with the answer knowledge.
    test_all_set = [
        {'first': ws.wsc_sentence, 'second': ws.answer_knowledge, 'label': ws.label}
        for ws in load_winologic_from_path(config.test_set_path)
    ]

    logging.critical('Testing on all examples of WinoLogic')
    logging.critical('test all set size: {}'.format(len(test_all_set)))
    acc2 = test(model, tokenizer, test_all_set, config)
    logging.critical('Accuracy of {} on all examples of WinoLogic: {}'.format(config.model_name, acc2))

    # The trained weights are not persisted; saving is left disabled:
    #torch.save(model.state_dict(), config.save_dir + '/pytorch_model.bin')

    return acc1, acc2
    

def train_and_test_five_times(config: "ExpConfig", n_times: int = 1):
    """
    Repeat the train-and-test experiment and collect the accuracies of each run.

    config: experiment settings handed unchanged to train_and_test for every run.
    n_times: number of repetitions. Despite the function name the default is 1,
        which is the number of runs this function previously hard-coded.
    Returns (winowhy_accuracies, winologic_all_accuracies), two lists with one
    entry per run. The random seed is not reset between runs.
    """
    logging.critical('learning rate = {}, number of training epochs = {}'.format(config.learning_rate, config.num_training_epochs))
    winowhy_accuracies = list()
    winologic_all_accuracies = list()

    for i in range(n_times):
        logging.critical('Experiment {}: {}'.format(i+1, config.task_name))
        acc1, acc2 = train_and_test(config)
        logging.critical('Testing {}: Acc on Winowhy = {}; Acc on all examples of WinoLogic = {}'.format(i+1, acc1, acc2))
        winowhy_accuracies.append(acc1)
        winologic_all_accuracies.append(acc2)

    return winowhy_accuracies, winologic_all_accuracies

# ---- experiment settings for this run ----
model_name = 'gpt2'
seed = 42 # 42, 9334, 1718, 3149, 8747
gpu_id = 1
learning_rate = 1e-3
num_training_epochs = 30

# ---- logging: echo to the console and to <ROOT_PATH><LOG_DIR>/test.log ----
# The root logger is set to ERROR, which is why the rest of this file logs its
# progress messages through logging.critical.
logger = logging.getLogger()
logger.setLevel(logging.ERROR)
# FileHandler does not create missing directories, so make sure the log
# directory exists before opening the log file.
os.makedirs(ROOT_PATH + LOG_DIR, exist_ok=True)
fh = logging.FileHandler(ROOT_PATH + LOG_DIR + '/test.log')
ch = logging.StreamHandler()
logger.addHandler(ch)
logger.addHandler(fh)

logging.critical('Seed: {}'.format(seed))

# ---- build the configuration and run the experiment ----
config = ExpConfig()
config.training_set_path = ROOT_PATH + WINOWHY_PATH
config.test_set_path = ROOT_PATH + WINOLOGIC_PATH
config.save_dir = ROOT_PATH + SAVE_DIR
config.task_name = "Train on WinoWhy, then test on WinoLogic"
config.set_seed(seed)
# Falls back to 'cpu' when CUDA is unavailable; otherwise selects 'cuda:<gpu_id>'.
config.set_gpu_if_possible(gpu_id)
config.learning_rate = learning_rate
config.num_training_epochs = num_training_epochs
config.model_name = model_name
# CACHE_DIR is empty by default, in which case weights are loaded by model_name.
config.model_path = CACHE_DIR

train_and_test_five_times(config)

