import os
import json
import argparse
import yaml
from collections import OrderedDict
from typing import List

from tqdm import tqdm
import pandas as pd
import torch


class FileModule:
    """Helpers for reading and writing plain-text, JSON/JSONL and CSV/TSV files.

    All text files are opened with UTF-8 encoding.
    """

    @staticmethod
    def read_file(filename: str):
        """Return the lines of ``filename`` with their trailing newline removed.

        Only the line terminator is dropped, so a final line that is not
        newline-terminated is returned intact instead of losing its last
        character.
        """
        with open(filename, 'r', encoding='utf-8') as f:
            return [line[:-1] if line.endswith('\n') else line for line in f]

    @staticmethod
    def read_json(filename: str):
        """Load JSON content from ``filename``.

        A ``.jsonl`` file is parsed line by line (with a ``tqdm`` progress
        display) and returned as a list holding one decoded value per
        non-blank line. Any other file is decoded as a single JSON document.
        """
        if filename.endswith('.jsonl'):
            records = []
            with open(filename, 'r', encoding='utf-8') as f:
                for raw_line in tqdm(f):
                    text = raw_line.strip()
                    if not text:
                        # Blank lines (e.g. a trailing empty line) hold no
                        # record; json.loads('') would raise on them.
                        continue
                    records.append(json.loads(text))
            return records

        with open(filename, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def read_csv(filename: str):
        """Read a delimited file into a pandas DataFrame.

        ``.tsv`` files are parsed with a tab separator; everything else uses
        the pandas default separator. The python parsing engine is used in
        both cases.
        """
        if filename.endswith('.tsv'):
            return pd.read_csv(filename, sep='\t', engine='python')
        return pd.read_csv(filename, engine='python')

    @staticmethod
    def write_lines(obj: List[str], filename):
        """Write every item of ``obj`` to ``filename``, one per line.

        Each item is converted with ``str`` and stripped of surrounding
        whitespace before the newline is appended.
        """
        # Build the lines before opening so a failing str() does not leave a
        # truncated file behind.
        lines = [str(item).strip() + '\n' for item in obj]
        with open(filename, 'w', encoding='utf-8') as f:
            f.writelines(lines)

    @staticmethod
    def write_json(obj: dict, filename):
        """Serialize ``obj`` as JSON into ``filename`` (non-ASCII kept as-is).

        For a ``.jsonl`` target a newline is appended after the document.
        NOTE: ``obj`` is always written as ONE JSON document on a single
        line; a list is not expanded into one record per line.
        """
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(obj, f, ensure_ascii=False)
            if filename.endswith('.jsonl'):
                f.write('\n')

    def reads(self, filename):
        """Read ``filename`` with the reader matching its extension.

        ``.json``/``.jsonl`` go to :meth:`read_json`, ``.tsv``/``.csv`` go to
        :meth:`read_csv`, anything else is read as plain text lines.
        """
        if filename.endswith(('.json', '.jsonl')):
            return self.read_json(filename)
        if filename.endswith(('.tsv', '.csv')):
            return self.read_csv(filename)
        return self.read_file(filename)


class TorchFileModule(FileModule):
    """File helpers extended with PyTorch state-dict and checkpoint utilities."""

    @staticmethod
    def find_best(dirname):
        """Return the name of a ``best_*`` entry inside ``dirname`` or ``None``.

        If several entries start with ``best_``, the last one yielded by
        ``os.listdir`` is returned; that order is not guaranteed.
        """
        best = None
        for entry in os.listdir(dirname):
            if entry.startswith('best_'):
                best = entry
        return best

    @staticmethod
    def import_weight(old_state: str):
        """Load a checkpoint and re-key its ``'model'`` state dict.

        Every key is prefixed with ``'model.'``. In addition,
        ``'model.shared.weight'`` is taken from
        ``'encoder.embed_tokens.weight'`` and ``'lm_head.weight'`` from
        ``'decoder.output_projection.weight'``.

        Args:
            old_state: Path of the checkpoint file passed to ``torch.load``.

        Returns:
            An ``OrderedDict`` with the re-keyed entries.

        Raises:
            KeyError: If ``'model'`` or one of the two source keys is missing.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        translation_state_dict = torch.load(old_state)['model']

        new_state_dict = OrderedDict()
        new_state_dict['model.shared.weight'] = translation_state_dict['encoder.embed_tokens.weight']
        for name in translation_state_dict:
            new_state_dict['model.' + name] = translation_state_dict[name]
        new_state_dict['lm_head.weight'] = translation_state_dict['decoder.output_projection.weight']
        return new_state_dict

    @staticmethod
    def select_parameters(state_dict, keys: List[str] = None):
        """Keep only the entries whose name contains one of ``keys``.

        Args:
            state_dict: Mapping from parameter name to value.
            keys: Substrings to look for. ``None`` returns ``state_dict``
                itself, unfiltered.

        Returns:
            ``state_dict`` when ``keys`` is ``None``, otherwise a new
            ``OrderedDict`` with the matching entries in their original order.
        """
        if keys is None:
            return state_dict

        selected = OrderedDict()
        for name in state_dict:
            if any(key in name for key in keys):
                selected[name] = state_dict[name]
        return selected

    def save_one(self, plself, loss, bleu, filename):
        """Write a single checkpoint file with ``torch.save``.

        The payload holds the epoch, global step, model and optimizer state
        dicts, ``loss``, ``bleu`` (under the ``'sacrebleu'`` key) and
        ``plself.hparam_args``.

        Args:
            plself: Object exposing ``current_epoch``, ``global_step``,
                ``model``, ``optimizers()`` and ``hparam_args``
                (presumably a Lightning module -- confirm with callers).
            loss: Loss value stored in the checkpoint.
            bleu: Score value stored in the checkpoint.
            filename: Destination path.
        """
        torch.save({'epoch': plself.current_epoch,
                    'step': plself.global_step,
                    'model_state_dict': plself.model.state_dict(),
                    'optimizer_state_dict': plself.optimizers().state_dict(),
                    'loss': loss,
                    'sacrebleu': bleu,
                    'args': plself.hparam_args},
                   filename)

    @staticmethod
    def _iteration_score(filename, score_name):
        """Parse the score from ``<tag>_<score_name>=<score>.pt``; ``None`` if it does not fit."""
        stem = filename.split('.pt')[0]
        _, marker, tail = stem.partition(f'_{score_name}=')
        if not marker:
            return None
        try:
            return float(tail)
        except ValueError:
            return None

    @staticmethod
    def _best_score(filename, score_name):
        """Parse the score from ``best_<score_name>=<score>_<tag>.pt``; ``None`` if it does not fit."""
        prefix = f'best_{score_name}='
        if not filename.startswith(prefix):
            return None
        # Anchoring on the full prefix keeps parsing correct even when
        # score_name itself contains underscores (e.g. 'val_loss').
        value = filename[len(prefix):].split('_', 1)[0]
        try:
            return float(value)
        except ValueError:
            return None

    def ckpt_save(self, plself, loss, score, maximize=True, score_name=None, step=False, last=False):
        """Maintain per-iteration checkpoints and one ``best_*`` checkpoint.

        Files are written into ``plself.ckpt_dir`` as
        ``epoch=<EEE>_<score_name>=<score>.pt`` (or ``step=<SSSSSSS>_...`` when
        ``step`` is true) and ``best_<score_name>=<score>_<epoch|step>=....pt``.
        Scores in file names are formatted with three decimals, and
        comparisons against existing checkpoints use the value parsed back
        from those names.

        Every directory entry whose name contains both ``'.pt'`` and
        ``score_name`` (the ``best_*`` file included) counts towards
        ``plself.ckpt_save_num``. While that budget is not reached the new
        iteration checkpoint is always saved; afterwards it replaces the
        worst existing iteration checkpoint only if ``score`` beats it.

        Args:
            plself: Object exposing ``ckpt_dir``, ``ckpt_save_num``,
                ``current_epoch``, ``global_step`` and the attributes read by
                :meth:`save_one`.
            loss: Loss value stored inside the saved checkpoints.
            score: Metric value used for naming and ranking.
            maximize: ``True`` if a higher score is better, ``False`` if a
                lower score is better. Applies to both the rolling
                checkpoints and the ``best_*`` checkpoint.
            score_name: Label used in file names; defaults to ``'score'``.
            step: Name files by global step instead of epoch.
            last: Additionally write ``model_last.pt``.
        """
        if score_name is None:
            score_name = 'score'

        score_text = format(score, '.3f')
        if step:
            tag = 'step=' + format(plself.global_step, '07')
        else:
            tag = 'epoch=' + format(plself.current_epoch, '03')

        iteration_name = os.path.join(plself.ckpt_dir, f'{tag}_{score_name}={score_text}.pt')
        best_name = os.path.join(plself.ckpt_dir, f'best_{score_name}={score_text}_{tag}.pt')

        def beats(new_value, old_value):
            # Direction-aware "strictly better" comparison.
            return new_value > old_value if maximize else new_value < old_value

        save_list = [i for i in os.listdir(plself.ckpt_dir) if ('.pt' in i) and (score_name in i)]

        # Split the existing files into ranked iteration checkpoints and the
        # current best checkpoint; names that cannot be parsed are ignored
        # instead of raising.
        iteration_scores = {}
        best_file, best_value = None, None
        for entry in save_list:
            if entry.startswith('best_'):
                if best_file is None:
                    parsed = self._best_score(entry, score_name)
                    if parsed is not None:
                        best_file, best_value = entry, parsed
                continue
            parsed = self._iteration_score(entry, score_name)
            if parsed is not None:
                iteration_scores[entry] = parsed

        if len(save_list) < plself.ckpt_save_num:
            self.save_one(plself, loss, score, filename=iteration_name)
        elif iteration_scores:
            # Budget reached: evict the worst iteration checkpoint, but only
            # when the new score is strictly better than it.
            pick_worst = min if maximize else max
            worst_file = pick_worst(iteration_scores, key=iteration_scores.get)
            if beats(score, iteration_scores[worst_file]):
                worst_path = os.path.join(plself.ckpt_dir, worst_file)
                if os.path.exists(worst_path):
                    os.remove(worst_path)
                self.save_one(plself, loss, score, filename=iteration_name)
        # else: budget reached and nothing to evict -> no iteration checkpoint.

        if best_file is None:
            self.save_one(plself, loss, score, filename=best_name)
        elif beats(score, best_value):
            old_best_path = os.path.join(plself.ckpt_dir, best_file)
            if os.path.exists(old_best_path):
                os.remove(old_best_path)
            self.save_one(plself, loss, score, filename=best_name)

        if last:
            self.save_one(plself, loss, score, filename=os.path.join(plself.ckpt_dir, 'model_last.pt'))


if __name__ == '__main__':
    # Script entry point: parse the optional --src string argument.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--src',
        type=str,
        default=None,
    )
    args = arg_parser.parse_args()
