import os
import sys
import copy
import re

from utils import * 
from dataset import *
from model import *

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class DPCrossEntropyCriterionBatch(nn.Module):
    """Batch loss that sums over every subword segmentation of each word.

    A forward dynamic programme accumulates, for every character position,
    the log score of all segmentations ending there; the loss is the negated
    batch mean of the score at the last position.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, model, dataset, device, config):
        """Return the batch-averaged negative log score of the words.

        Steps:
            1. Call the decoder once; its output is indexed as
               ``[batch, position, subword_id]`` and is assumed to hold
               log-probabilities (e.g. a LogSoftmax output) -- TODO confirm
               against the model.
            2. For each ``end_pos``, combine every ``start_pos <= end_pos``:
               the score accumulated at ``start_pos - 1`` plus the score of
               the subword the jump table lists for ``(end_pos, start_pos)``.
            3. Negate and average the scores found at the last position.

        Args:
            input: Batch dict. Keys read here: ``max_len`` (int),
                ``batch_size`` (int), ``embeddings`` (forwarded to the model
                untouched), ``segs_poss`` (jump table indexed as
                ``[batch, end_pos, start_pos]``; 0 marks "no subword") and
                ``words_with_tokens_idx`` (token ids indexed as
                ``[batch, position]``).
            model: Callable module invoked as ``model(token_ids, embeddings)``.
            dataset: Object whose ``vocab`` mapping contains ``'<pad>'``.
            device: Device on which the computation runs.
            config: Stored on ``self.config``; not otherwise used here.

        Returns:
            Scalar tensor: ``-sum(final log scores) / batch_size``.
        """
        self.config = config
        # Kept from the original contract: the criterion itself resets the
        # model gradients and moves the model to `device`.
        model.zero_grad()
        model.to(device)

        max_len = input['max_len']
        bsz = input['batch_size']
        embeddings = input['embeddings']

        # as_tensor avoids the extra copy (and the UserWarning) that
        # torch.tensor() produces when the batch already holds tensors.
        seg_poss = torch.as_tensor(input['segs_poss'], device=device)
        decoder_input = torch.as_tensor(input['words_with_tokens_idx'], device=device)

        # 1. per-position scores over the subword vocabulary
        log_probs_at_pos = model(decoder_input, embeddings)

        # Padding mask for every position, computed once instead of per step.
        is_pad = decoder_input == dataset.vocab['<pad>']

        # 2. forward DP in log space. Column 0 is the empty prefix
        #    (log score ~0, i.e. probability ~1).
        # NOTE: 1e-30 is added in probability space before the log, so every
        # score is floored at log(1e-30) ~= -69.08.
        seg_probs = torch.zeros(bsz, max_len, device=device) + 1e-30
        for end_pos in range(1, max_len):
            # Candidate subwords ending at end_pos, one per start_pos in
            # 1..end_pos; start_pos s is scored at decoder position s - 1.
            subword_ids = seg_poss[:, end_pos, 1:end_pos + 1]  # (bsz, end_pos)
            valid = subword_ids != 0
            subword_log_probs = (
                log_probs_at_pos[:, :end_pos]
                .gather(2, subword_ids.unsqueeze(-1))
                .squeeze(-1)
            )  # (bsz, end_pos)
            path_probs = torch.exp(seg_probs[:, :end_pos] + subword_log_probs) * valid
            seg_probs_current = torch.log(path_probs.sum(dim=1) + 1e-30)
            # Padded positions carry the previous score forward so the last
            # column holds each word's final score regardless of its length.
            seg_probs[:, end_pos] = torch.where(
                is_pad[:, end_pos], seg_probs[:, end_pos - 1], seg_probs_current
            )

        # 3. negative mean of the final log scores
        word_probs = seg_probs[:, max_len - 1]  # (bsz)
        loss = -torch.sum(word_probs) / bsz
        return loss

def train_step(data, decoder, decoder_optimizer, criterion, config):
    """Run one optimisation step on a single batch and return its loss.

    ``dataset`` and ``device`` are not parameters; they are resolved from
    module scope when the function is called.

    Args:
        data: Batch passed straight through to ``criterion``.
        decoder: Model whose parameters are updated.
        decoder_optimizer: Optimizer stepping ``decoder``'s parameters.
        criterion: Callable returning a scalar loss tensor.
        config: Mapping; ``config['gradient_clip']`` > 0 enables gradient
            norm clipping with that value as the maximum norm.

    Returns:
        The loss as a Python number (``loss.item()``).
    """
    decoder_optimizer.zero_grad()

    batch_loss = criterion(data, decoder, dataset, device, config)
    batch_loss.backward()

    # A non-positive setting disables clipping.
    max_norm = config['gradient_clip']
    if max_norm > 0:
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm)

    decoder_optimizer.step()
    return batch_loss.item()

# train config
def gene_train_config():
    """Build and return the training configuration as a plain dict.

    All artefact paths are derived from the training text file by appending
    a suffix. ``write_to_file`` is ``None`` in this configuration.
    """
    data_folder = "./data/alt"
    # Other splits live next to it: f"{data_folder}/dev.en", f"{data_folder}/test.en".
    text_file = f"{data_folder}/train.en"
    vocab_size = 8000

    return {
        # --- data and artefact locations ---
        'train': 1,
        'data_folder': data_folder,
        'text_file': text_file,
        'freq_table_file': text_file + ".freq",
        'embedding_file': text_file + ".embedding",
        'char_dict_file': text_file + ".char_dict",
        'write_to_file': None,
        'model_path': text_file + ".pt",
        'vocab_size': vocab_size,
        'vocab_path': text_file + '.' + str(vocab_size) + ".vocab",
        'volt_flag': 0,
        'gpu': True,

        # --- dataset and model shape ---
        'dataset_size': 0,
        'batch_size': 500,
        'model': 'transformer',
        'no_encoder': 0,
        'hidden_size': 768,
        'num_layers': 3,
        'bidirectional': True,

        # --- optimisation and loss options ---
        'loss_decay': 0,
        'length_award': False,
        'length_award_alpha': 0,
        'gradient_clip': 1.0,  # default 1.0
        'dropout': 0.1,
        'lr': 5e-4,  # 5e-5
        'use_frequency': "one",
        'weighted_loss': False,
        'sort_by_length': False,
        'sort_by_alphabet': False,
    }

# Script body: build the configuration, data pipeline, model and optimizer,
# then train for a fixed number of epochs, printing the loss of every batch.
config = gene_train_config()

print("GPU number", torch.cuda.device_count())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset = S4DatasetBatch(config)
# NOTE(review): this runs at import time with num_workers=8; on platforms
# that start workers by spawning, PyTorch expects such code to sit behind an
# `if __name__ == "__main__":` guard -- confirm the target platform.
dataloader = DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=8,
    collate_fn=dataset.collate_fn,
)

decoder = S4DecoderBatch(
    config['hidden_size'],
    len(dataset.vocab),
    num_layers=config["num_layers"],
    bidirectional=config["bidirectional"],
    padding_idx=dataset.vocab['<pad>'],
    dropout=config["dropout"],
    config=config,
).to(device)
# Referenced through the `torch` namespace: this file never imports `optim`
# itself, so a bare `optim.Adam` only works if a wildcard import leaks it.
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=config['lr'])
decoder.train()

criterion = DPCrossEntropyCriterionBatch()

start_epoch = 0
end_epoch = 15
losses = []  # NOTE(review): nothing in the visible code appends to this list
for epoch in range(start_epoch, end_epoch):
    for batch_i, data in enumerate(dataloader):
        loss = train_step(data, decoder, decoder_optimizer, criterion, config)
        print(f"Epoch {epoch+1}/{end_epoch}, Batch {batch_i}/{len(dataloader)}, Loss {loss}")