import os
import sys
import copy
import re

from utils import * 
from dataset import *
from model import *
from train import *

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def _decode_segments(tokens, starts):
    """Join ``tokens`` into subword strings, one per segment start in ``starts``.

    Segment ``i`` covers ``tokens[starts[i]:starts[i + 1]]``; the last segment
    runs to the end of ``tokens``.  An empty ``starts`` (a path with no
    boundaries) yields an empty list instead of raising ``IndexError``.
    """
    ends = starts[1:] + [None]
    return ["".join(tokens[begin:end]) for begin, end in zip(starts, ends)]


def evaluate_nbest(input, model, dataset, device, print_flag=0, config=None):
    """Return the n-best subword segmentations of the first word in a batch.

    The model is run once to get a score for every (position, subword id)
    pair.  A beam-limited dynamic program then fills ``paths[e]`` with the
    ``config['n_best']`` highest-scoring ways to cover token positions
    ``1..e``, where each path is the list of segment start positions.  Only
    batch element 0 is decoded.

    Args:
        input: batch dict; reads 'max_len', 'embeddings', 'words',
            'segs_poss', 'words_with_tokens' and 'words_with_tokens_idx'.
            ``segs_poss[0][end][start]`` is the subword id of the token span
            ``start..end`` (0 when that span is not a valid subword).
        model: callable as ``model(token_idx, embeddings)``; its output ``[0]``
            is indexed as ``[position][subword_id]`` (presumably
            log-probabilities, since scores are summed -- TODO confirm).
        dataset: unused; kept for interface compatibility.
        device: torch device the token indices are moved to.
        print_flag: unused; kept for interface compatibility.
        config: dict providing ``'n_best'`` (beam width per position).

    Returns:
        ``(0, word, segmentations)``: ``segmentations`` is a list of
        ``(subwords, score)`` pairs sorted by descending score, where
        ``score`` is the summed model output over the chosen subwords
        (a 0-d tensor for any non-empty path).
    """
    n_best = config['n_best']
    max_len = input['max_len']
    embeddings = input['embeddings']
    word = input['words'][0]
    tokens = input['words_with_tokens'][0]

    with torch.no_grad():
        model.eval()
        # as_tensor avoids the extra copy (and PyTorch's copy-construct warning)
        # that torch.tensor() incurs when the batch already holds tensors.
        token_idx = torch.as_tensor(input['words_with_tokens_idx'], device=device)
        decoder_output = model(token_idx, embeddings)
        subword_scores = decoder_output[0]  # (max_len, vocab_size)

        # Read batch element 0's span table once on the host instead of
        # forcing a device sync for every (end, start) cell below.
        seg_table = torch.as_tensor(input['segs_poss'])[0].tolist()

        paths = [[] for _ in range(max_len)]
        paths[0].append(([], -1e-30))
        for end_pos in range(1, max_len):
            candidates = paths[end_pos]
            for start_pos in range(1, end_pos + 1):
                subword_id = seg_table[end_pos][start_pos]
                if subword_id == 0:  # span start..end is not a valid subword
                    continue
                # The subword starting at start_pos is scored from position start_pos - 1.
                step_score = subword_scores[start_pos - 1][subword_id]
                for starts, score in paths[start_pos - 1]:
                    candidates.append((starts + [start_pos], score + step_score))
            candidates.sort(key=lambda cand: cand[1], reverse=True)
            paths[end_pos] = candidates[:n_best]

    segmentations = [
        (_decode_segments(tokens, starts), score)
        for starts, score in paths[max_len - 1]
    ]
    return 0, word, segmentations

def gene_test_config(text_file, vocab_path, model_path, output_file=None, train_config=None):
    """Derive an inference/decoding config from a training config.

    ``train_config`` is shallow-copied (it is not mutated); the copy is then
    switched to evaluation mode, pointed at ``text_file`` and its sidecar
    files (``.freq``, ``.embedding``, ``.char_dict``), and given the decoding
    settings (10-best, batch size 1, length award disabled).
    """
    test_cfg = train_config.copy()
    test_cfg.update({
        'train': 0,
        'text_file': text_file,
        'freq_table_file': text_file + ".freq",
        'embedding_file': text_file + ".embedding",
        'char_dict_file': text_file + ".char_dict",
        'write_to_file': output_file,
        'model_path': model_path,
        'vocab_path': vocab_path,  # only need vocab to load subwords
        'dataset_size': 0,
        'use_frequency': "one",
        'batch_size': 1,
        # decoding setting
        'n_best': 10,
        'length_award': False,
        'length_award_alpha': 3,
    })
    return test_cfg

# Decode every word in ``text_file`` with a trained S4 decoder and print its
# n-best segmentations, one per line, as "<subwords joined by spaces>\t<score>".
tot_segs = {}  # word -> list of (subwords, score) pairs from evaluate_nbest

# Paths must be filled in before running.
text_file = ""
output_file = ""
vocab_path = ""
model_path = ""

train_config = gene_train_config()
test_config = gene_test_config(text_file, vocab_path, model_path, output_file, train_config)

# ``device`` is needed by the model and evaluate_nbest on both branches, so it
# must be defined even when GPU use is disabled in the config.
if test_config['gpu']:
    print("GPU number", torch.cuda.device_count())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")

test_dataset = S4DatasetBatch(test_config)
test_dataloader = DataLoader(test_dataset, batch_size=test_config["batch_size"], shuffle=False, num_workers=8, collate_fn=test_dataset.collate_fn)

decoder = S4DecoderBatch(test_config["hidden_size"], len(test_dataset.vocab), num_layers = test_config["num_layers"], bidirectional=test_config["bidirectional"], padding_idx=test_dataset.vocab['<pad>'], dropout=test_config["dropout"], config=test_config).to(device)
# map_location remaps stored tensors onto ``device`` so a checkpoint saved on
# a GPU can still be loaded on a CPU-only host.
decoder.load_state_dict(torch.load(test_config["model_path"], map_location=device))

for batch_i, data in enumerate(test_dataloader):
    loss, word, decoded_subwords_list = evaluate_nbest(data, decoder, test_dataset, device, print_flag=0, config=test_config)
    tot_segs[word] = decoded_subwords_list
    for decoded_subwords, v in decoded_subwords_list:
        # Scores are usually 0-d tensors; float() makes round()/str() yield a
        # plain number instead of failing or printing "tensor(...)".
        v = str(round(float(v), 2))
        line = " ".join(decoded_subwords) + "\t" + v + "\n"
        print(line.strip())
