import torch

import argparse
from dataset import Dataset
from model import DeepPunctuation, DeepPunctuationCRF
from config import *
import json

def _flag_is_true(raw):
    """Return True only when *raw* spells "true" (any letter case); everything else is False."""
    return str(raw).lower() == 'true'


# Command-line interface. All options are required except --use-crf and
# --lstm-dim, which fall back to False and -1 respectively.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--classes-file',
    type=str,
    required=True,
    help='Path to classes file',
)
parser.add_argument(
    '--dataset_path',
    type=str,
    required=True,
    help='Path to dataset to tag',
)
parser.add_argument(
    '--base_model',
    type=str,
    required=True,
    help='Name of base model',
)
parser.add_argument(
    '--weights_path',
    type=str,
    required=True,
    help='Path to weights file',
)
parser.add_argument(
    '--output_file',
    type=str,
    required=True,
    help='Path to output file',
)
parser.add_argument(
    '--use-crf',
    type=_flag_is_true,
    default=False,
    help='whether to use CRF layer or not',
)
parser.add_argument(
    '--lstm-dim',
    type=int,
    default=-1,
    help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model',
)
args = parser.parse_args()

# Load the punctuation classes from the JSON classes file and map every class
# label to its position in that file. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's default locale.
with open(args.classes_file, 'r', encoding='utf-8') as f:
    classes_list = json.load(f)

punctuation_dict = {label: index for index, label in enumerate(classes_list)}


# Collect the tokens to tag: the first tab-separated field of every line in
# the dataset file. They are paired with the predictions when the output is
# written, so their order must be kept.
with open(args.dataset_path, 'r', encoding='utf-8') as f:
    tokens = [line.split("\t")[0] for line in f.read().splitlines()]

# MODELS (from config) is indexed by the base model name: element [1] exposes
# from_pretrained() for building the tokenizer and element [3] is forwarded to
# the dataset as its token_style.
model_entry = MODELS[args.base_model]
tokenizer = model_entry[1].from_pretrained(args.base_model)
token_style = model_entry[3]

test_set = Dataset(
    args.dataset_path,
    tokenizer=tokenizer,
    sequence_len=256,
    token_style=token_style,
    is_train=False,
    punctuation_dict=punctuation_dict,
)

# Batches are read in dataset order (no shuffling) because the predictions are
# later matched to the input tokens purely by position.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)

# Model
# Run on the first GPU when one is available and fall back to the CPU
# otherwise, so the script also works on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Both model variants take the same constructor arguments; --use-crf selects
# the one with the CRF layer.
model_class = DeepPunctuationCRF if args.use_crf else DeepPunctuation
deep_punctuation = model_class(
    args.base_model,
    freeze_bert=False,
    lstm_dim=args.lstm_dim,
    punctuation_dict=punctuation_dict,
)
deep_punctuation = deep_punctuation.to(device)

# map_location remaps tensors that were saved from another device (e.g. a GPU
# checkpoint loaded on a CPU-only host) onto the device selected above.
# NOTE: strict=False means missing or unexpected checkpoint keys are ignored
# silently rather than raising.
state_dict = torch.load(args.weights_path, map_location=device)
deep_punctuation.load_state_dict(state_dict, strict=False)

# This script only performs inference: switch train-only layers such as
# dropout to their evaluation behaviour.
deep_punctuation.eval()

# Reverse lookup: class index -> class label.
punctuation_dict_inv = {index: label for label, index in punctuation_dict.items()}

# Tag every batch and collect one predicted class label per kept position, in
# dataset order.
predictions = []
with torch.no_grad():
    for x, y, att, y_mask in test_loader:
        x, att = x.to(device), att.to(device)

        if args.use_crf:
            # The CRF model also receives the label tensor and its output is
            # used directly as class indices.
            y_predict = deep_punctuation(x, att, y.to(device))
        else:
            # The plain model returns per-class scores on the last axis; the
            # highest-scoring class is the prediction.
            y_predict = torch.argmax(deep_punctuation(x, att), dim=2)

        # Move the results to the CPU once per batch instead of pulling every
        # single element across with .item().
        y_predict = y_predict.reshape(x.shape).cpu()
        keep = y_mask.reshape(x.shape).cpu() != 0

        for row_predictions, row_keep in zip(y_predict, keep):
            # Positions whose mask value is 0 carry no prediction of their own.
            labels = [punctuation_dict_inv[index] for index in row_predictions[row_keep].tolist()]
            # The first and last kept positions of each sequence are dropped --
            # presumably the tokenizer's start/end markers (TODO confirm
            # against Dataset).
            predictions += labels[1:-1]

# Every input token must have received exactly one prediction; otherwise the
# positional pairing below would silently mislabel or drop tokens. This is an
# explicit check (not an assert) so it still runs under `python -O`.
if len(predictions) != len(tokens):
    raise ValueError(
        f"Number of predictions ({len(predictions)}) does not match "
        f"number of tokens ({len(tokens)})"
    )

# Write one "token<TAB>label" line per input token, encoded as UTF-8
# regardless of the platform default.
with open(args.output_file, 'w', encoding='utf-8') as f:
    f.writelines(
        f"{token}\t{prediction}\n"
        for token, prediction in zip(tokens, predictions)
    )