import os
import sys
import json
import pandas as pd
from typing import List, Dict
from collections import OrderedDict
from tqdm import tqdm
from datetime import datetime
import torch
import torch.nn as nn
from torch import Tensor
import wandb
from data_fns import prep_wikipedia_data_for_pt_bienc, prep_newspaper_data_pytorch, prep_newspaper_wiki_data
from data_fn_in_context import prep_wikipedia_data_coref_hn_for_pt_bienc
import pandas as pd
import os
import sys
from tqdm import tqdm

from transformers import AutoTokenizer
import os
import pickle

import wandb
from sentence_transformers import SentenceTransformer
# Put the parent and grandparent directories of this script on sys.path
# (presumably so the nlp_utils package imported further down can be found).
currentdir = os.path.split(os.path.realpath(__file__))[0]
parentdir = os.path.split(currentdir)[0]
grandparentdir = os.path.split(parentdir)[0]
sys.path.extend([parentdir, grandparentdir])


from nlp_utils.pytorch_bienc.train_bienc_token import train_biencoder_custom
from nlp_utils.data_fns import find_sep_token, featurize_text

# Import other necessary modules...

def main(sweep_config=None, start_chunk=5):
    """Train the Wikipedia hard-negative coreference bi-encoder chunk by chunk.

    The training set is consumed as pickled chunks
    (``train_data_chunk_<i>.pkl``) that must already exist in the
    entity-training directory. After every trained chunk its index is written
    to ``chunk_index.txt`` in the same directory.

    Args:
        sweep_config: Optional wandb-style sweep dictionary. When ``None`` the
            hyperparameters are read from ``wandb.config``. Otherwise every
            entry of ``sweep_config["parameters"]`` is collapsed to a single
            value: the first element of its ``'values'`` list, or its
            ``'value'`` constant.
        start_chunk: Index of the first chunk to train; lower-numbered chunks
            are skipped. The default of 5 reproduces the previously hard-coded
            skip of chunks 0-4.

    Raises:
        KeyError: If a parameter dictionary in ``sweep_config`` defines
            neither ``'values'`` nor ``'value'`` (e.g. a distribution), since
            it cannot be collapsed to one value without a sweep controller.
        ValueError: If the prepared training set has fewer than 10 examples
            and therefore cannot be divided into tenths.
    """
    # Function-scope import keeps this block self-contained.
    from types import SimpleNamespace

    data_dir = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/'

    # Set up wandb and resolve the hyperparameter configuration.
    wandb.init()
    if sweep_config is None:
        config = wandb.config
    else:
        # Use the sweep dictionary directly: reduce every parameter spec to
        # one concrete value and expose the result as attributes.
        resolved = {}
        for key, spec in sweep_config["parameters"].items():
            if isinstance(spec, dict):
                if 'values' in spec:
                    spec = spec['values']
                elif 'value' in spec:
                    # A constant is used as-is, even when it is a list.
                    resolved[key] = spec['value']
                    continue
                else:
                    raise KeyError(
                        f"sweep parameter {key!r} has neither 'values' nor "
                        "'value'; distributions cannot be used without a "
                        "wandb sweep agent"
                    )
            if isinstance(spec, list):
                # Several candidates: take the first one.
                spec = spec[0]
            resolved[key] = spec
        config = SimpleNamespace(**resolved)

    config.loss_fn = 'contrastive_batchhard'
    config.dataset_path = data_dir

    # Extract and featurise the data. An earlier variant of this script called
    # prep_wikipedia_data_for_pt_bienc(..., disamb_or_coref='coref',
    # small=True) with config.model / config.special_tokens instead.
    train_data, dev_data, _ = prep_wikipedia_data_coref_hn_for_pt_bienc(
        config.dataset_path,
        "sentence-transformers/all-mpnet-base-v2",
        {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        featurisation=config.featurisation,
        batch_type=config.loss_fn,
        samples_per_label=4,
        batch_size=256,
        small=False,
        save_path=os.path.join(data_dir, "outputs_within_hn.pkl"),
    )

    # Split the training data into tenths ("chunk epochs").
    len_train_data = len(train_data["sentence_1"])
    chunk_size = len_train_data // 10
    if chunk_size == 0:
        raise ValueError(
            f"need at least 10 training examples to build chunks, got {len_train_data}"
        )
    # NOTE(review): range() stops before len_train_data, so the examples after
    # the last boundary (a full tenth when the length is an exact multiple of
    # 10) belong to no chunk. Left unchanged because the chunk files on disk
    # were presumably cut with these same boundaries - confirm before fixing.
    chunk_indices = list(range(0, len_train_data, chunk_size))
    num_chunks = len(chunk_indices) - 1

    # The chunk files were written with the snippet below; it has to run
    # before train_data is deleted if the chunks need to be regenerated.
    # for i in tqdm(range(num_chunks)):
    #     with open(os.path.join(data_dir, f'train_data_chunk_{i}.pkl'), 'wb') as f:
    #         train_data_chunk = {
    #             key: train_data[key][chunk_indices[i]:chunk_indices[i + 1]]
    #             for key in train_data.keys()
    #         }
    #         pickle.dump(train_data_chunk, f)
    #     del train_data_chunk

    # Only the length was needed; release the full training set before the
    # per-chunk loads below.
    del train_data

    # Train each chunk in turn.
    for i in tqdm(range(num_chunks)):
        if i < start_chunk:
            continue
        print("Loading chunk: ", i)

        # NOTE: pickle can execute arbitrary code while loading; only load
        # chunk files produced by this pipeline.
        with open(os.path.join(data_dir, f'train_data_chunk_{i}.pkl'), 'rb') as f:
            train_data_chunk = pickle.load(f)

        print("Training chunk: ", i)
        train_biencoder_custom(
            train_data_chunk,
            dev_data,
            wandb_project="biencoder_pt_hn",
            wandb_run_name="biencoder_pt_1_hn",
            # The first chunk starts from the base model; later chunks load
            # the "<model_save_path_name>_epoch_1" checkpoint.
            model_name=config.model if i == 0 else config.model_save_path_name + "_epoch_1",
            wandb_log=True,
            warmup_perc=config.warm_up_perc,
            batch_size=config.batch_size,
            num_epochs=config.epochs,
            margin=config.margin,
            max_seq_len=config.max_seq_length,
            inter_eval_steps=500,
            special_tokens=config.special_tokens,
            learning_rate=config.lr,
            model_save_path=config.model_save_path_name,
            sample_eval_loader_steps=30,
            shuffle_train=False,
            start_epoch_step=0,
        )

        # Record the index of the chunk that has just been trained.
        with open(os.path.join(data_dir, 'chunk_index.txt'), 'w') as f:
            f.write(str(i))
            
    
    


    
    
     


    # train_biencoder_custom(
    #     train_data,
    #     dev_data,
    #     wandb_project="biencoder_pt",
    #     wandb_run_name="biencoder_pt_1",
    #     model_name= "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model", #"sentence-transformers/all-mpnet-base-v2",
    #     wandb_log=False,
    #     warmup_perc=0.3, # 0.18 0.3 was best
    #     batch_size=704,
    #     num_epochs=10,
    #     margin=0.4,
    #     max_seq_len=128,
    #     inter_eval_steps=176,
    #     special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
    #     learning_rate=2e-6,
    #     model_save_path='/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model',
    #     sample_eval_loader_steps=30,
    #     shuffle_train=False,
    #     start_epoch_step=481

    # )


def newspaper_main():
    """Fine-tune the ``cgis_model`` bi-encoder on the labelled newspaper data.

    Hyperparameters are read from ``wandb.config``. The prepared training
    split is pickled to ``train_data.pkl`` in the entity-training directory
    before ``train_biencoder_custom`` is called with the train, dev and test
    splits.
    """
    wandb.init()
    config = wandb.config
    config.loss_fn = 'contrastive_batchhard'
    config.dataset_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/labeled_datasets_full_extended.json'

    # Descriptive run name assembled from the main hyperparameters (printed only).
    name_fields = (
        config.featurisation,
        config.max_seq_length,
        config.batch_size,
        config.epochs,
        config.warm_up_perc,
        config.date_featurisation,
    )
    name = 'full_' + '_'.join(f'{field}' for field in name_fields)
    print("Training the model with name: ", name)

    # The model is only needed while the data is being prepared.
    model = SentenceTransformer("/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model")
    splits = prep_newspaper_data_pytorch(
        config.dataset_path,
        model,
        special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        featurisation=config.featurisation,
        date_featurisation=config.date_featurisation,
        disamb_or_coref='coref',
    )
    train_data, dev_data, test_data = splits
    del model

    for split in (train_data, dev_data):
        print(split.keys())

    # Persist the training split (the dev and test splits are not saved).
    with open('/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/train_data.pkl', 'wb') as out_file:
        pickle.dump(train_data, out_file)

    training_options = {
        'wandb_project': "biencoder_pt",
        'wandb_run_name': "biencoder_pt_1",
        # Alternative starting point: "sentence-transformers/all-mpnet-base-v2"
        'model_name': "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model",
        'wandb_log': True,
        'warmup_perc': config.warm_up_perc,  # 0.18 / 0.3 was best
        'batch_size': config.batch_size,
        'num_epochs': config.epochs,
        'margin': config.margin,
        'max_seq_len': config.max_seq_length,
        'inter_eval_steps': 10,
        'special_tokens': {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        'learning_rate': config.lr,
        'model_save_path': config.model_save_path_name,
        'sample_eval_loader_steps': 1000000,
        'shuffle_train': False,
    }
    train_biencoder_custom(train_data, dev_data, test_data, **training_options)
    
def newspaper_wiki_main(sweep_config=None):
    """Fine-tune the in-context entity-marking bi-encoder on newspaper + wiki data.

    The prepared training split is pickled to ``train_data.pkl`` in the
    entity-training directory before ``train_biencoder_custom`` is called with
    the train, dev and test splits.

    Args:
        sweep_config: Optional wandb-style sweep dictionary. When ``None`` the
            hyperparameters are read from ``wandb.config``. Otherwise every
            entry of ``sweep_config["parameters"]`` is collapsed to a single
            value: the first element of its ``'values'`` list, or its
            ``'value'`` constant.

    Raises:
        KeyError: If a parameter dictionary in ``sweep_config`` defines
            neither ``'values'`` nor ``'value'`` (e.g. a distribution), since
            it cannot be collapsed to one value without a sweep controller.
    """
    # Function-scope import keeps this block self-contained.
    from types import SimpleNamespace

    # Set up wandb and resolve the hyperparameter configuration.
    wandb.init()
    if sweep_config is None:
        config = wandb.config
    else:
        # Use the sweep dictionary directly: reduce every parameter spec to
        # one concrete value and expose the result as attributes.
        resolved = {}
        for key, spec in sweep_config["parameters"].items():
            if isinstance(spec, dict):
                if 'values' in spec:
                    spec = spec['values']
                elif 'value' in spec:
                    # A constant is used as-is, even when it is a list.
                    resolved[key] = spec['value']
                    continue
                else:
                    raise KeyError(
                        f"sweep parameter {key!r} has neither 'values' nor "
                        "'value'; distributions cannot be used without a "
                        "wandb sweep agent"
                    )
            if isinstance(spec, list):
                # Several candidates: take the first one.
                spec = spec[0]
            resolved[key] = spec
        config = SimpleNamespace(**resolved)

    config.loss_fn = 'contrastive_batchhard'
    config.dataset_path = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/sample_fine_tuning_data.json'

    # Descriptive run name assembled from the main hyperparameters (printed only).
    name = f'full_newspaper_wiki_{config.featurisation}_{config.max_seq_length}_{config.batch_size}_{config.epochs}_{config.warm_up_perc}_{config.date_featurisation}'
    print("Training the model with name: ", name)

    # The model is only needed while the data is being prepared.
    model_path = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext"
    model = SentenceTransformer(model_path)
    train_data, dev_data, test_data = prep_newspaper_wiki_data(
        config.dataset_path,
        model,
        special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        featurisation=config.featurisation,
        date_featurisation=config.date_featurisation,
    )
    del model

    # Quick sanity check of what the preparation step produced.
    print(train_data.keys())
    print(dev_data.keys())
    print(train_data['sentence_1'][0])
    print(dev_data['sentence_2'][0])
    print(test_data['labels'][0])

    # Persist the training split (the dev and test splits are not saved).
    with open('/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/train_data.pkl', 'wb') as f:
        pickle.dump(train_data, f)

    train_biencoder_custom(
        train_data,
        dev_data,
        test_data,
        wandb_project="newspaper_wiki_coref_disamb",
        wandb_run_name="newspaper_wiki_coref_disamb",
        model_name=model_path,  # alternative: "sentence-transformers/all-mpnet-base-v2"
        wandb_log=True,
        warmup_perc=config.warm_up_perc,  # 0.18 / 0.3 was best
        batch_size=config.batch_size,
        num_epochs=config.epochs,
        margin=config.margin,
        max_seq_len=config.max_seq_length,
        inter_eval_steps=25,
        special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        learning_rate=config.lr,
        model_save_path=config.model_save_path_name,
        sample_eval_loader_steps=1000000,
        shuffle_train=False,
    )
        

    
if __name__ == '__main__':

    # Active configuration. Every parameter has exactly one candidate, so
    # newspaper_wiki_main (which takes the first element of each 'values'
    # list) runs with precisely these settings.
    fixed_parameters = {
        # 'model' is intentionally not set here; it used to be
        # 'sentence-transformers/all-mpnet-base-v2'.
        'featurisation': 'ent_mark',
        'date_featurisation': 'prepend_1',
        # Best 0.182. Increase to 1 in intermediate checkpoints.
        'warm_up_perc': 0.182,
        'batch_size': 256,
        'epochs': 10,
        'margin': 0.4,  # best 0.4
        'max_seq_length': 256,
        'special_tokens': {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        'lr': 1e-5,  # Best 2e-5, using 2e-6 after chunk
        'model_save_path_name': '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/entity_split_newspaper_wiki_coref_disamb_more_incontext',
    }

    # 'method', 'metric' and 'early_terminate' only matter when the dictionary
    # is handed to wandb.sweep; the direct call at the bottom reads
    # 'parameters' only.
    sweep_config = dict(
        method='random',
        metric=dict(name='inter_eval/accuracy', goal='maximize'),
        early_terminate=dict(type='hyperband', s=2, eta=4, max_iter=27),
        parameters={
            param_name: {'values': [param_value]}
            for param_name, param_value in fixed_parameters.items()
        },
    )

    # ------------------------------------------------------------------
    # Earlier sweep configurations, kept for reference.
    # ------------------------------------------------------------------

    # Broad random search starting from cgis_model with 'prepend' featurisation.
    # sweep_config = {
    #     'method': 'random',
    #     'metric': {
    #         'name': 'inter_eval/best_accuracy_eval',
    #         'goal': 'maximize'
    #     },
    #     'early_terminate': {
    #         'type': 'hyperband',
    #         's': 2,
    #         'eta': 4,
    #         'max_iter': 27,
    #     },
    #     'parameters': {
    #         'model': {
    #             'values': ['/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model'] #'sentence-transformers/all-mpnet-base-v2',
    #         },
    #         'featurisation': {
    #             'values': ['prepend']
    #         },
    #         'date_featurisation': {
    #             'values': ['prepend_1','none']
    #         },
    #         'warm_up_perc': {
    #             'distribution': 'uniform',
    #             'min': 0.0,
    #             'max': 1.0
    #         }, #Best 0.182
    #         'batch_size': {
    #             'values': [512,704,128,256]
    #         },
    #         'epochs': {
    #             'values': [10,15,20,25,30]
    #         },
    #         'margin': {
    #             'distribution': 'uniform',
    #             'min': 0.1,
    #             'max': 0.9
    #         }, #best 0.4
    #         'max_seq_length': {
    #             'values': [128]
    #         },
    #         'special_tokens': {
    #             'values': [{'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}]
    #         },
    #         'lr': {
    #             'values': [2e-6,2e-5,2e-4,2e-7,1e-6] #Best 2e-5
    #         }
    #     }
    # }

    # 'ent_mark' sweep over warm-up, margin and learning rate.
    # sweep_config = {
    #     'method': 'random',
    #     'metric': {
    #         'name': 'inter_eval/accuracy',
    #         'goal': 'maximize'
    #     },
    #     'early_terminate': {
    #         'type': 'hyperband',
    #         's': 2,
    #         'eta': 4,
    #         'max_iter': 27,
    #     },
    #     'parameters': {
    #         'model': {
    #             'values': ['/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model'] #'sentence-transformers/all-mpnet-base-v2',
    #         },
    #         'featurisation': {
    #             'values': ['ent_mark']   #
    #         },
    #         'date_featurisation': {
    #             'values': ['prepend_1']
    #         },
    #         'warm_up_perc': {
    #             'values':[0.182,0.3,0.4,0.5,0.6,0.7]
    #         }, #Best 0.182
    #         'batch_size': {
    #             'values': [256]
    #         },
    #         'epochs': {
    #             'values': [10]
    #         },
    #         'margin': {
    #             'values': [0.3,0.4,0.5,0.6,0.7]
    #         }, #best 0.4
    #         'max_seq_length': {
    #             'values': [128]
    #         },
    #         'special_tokens': {
    #             'values': [{'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}]
    #         },
    #         'lr': {
    #             'values': [0.00002,2e-4,2e-3] #Best 2e-5
    #         },
    #         'model_save_path_name':{
    #             'values': ['/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/ent_mark_sweep']
    #         }
    #     }
    # }

    # Single-setting 'ent_mark' run.
    # sweep_config = {
    #     'method': 'random',
    #     'metric': {
    #         'name': 'inter_eval/accuracy',
    #         'goal': 'maximize'
    #     },
    #     'early_terminate': {
    #         'type': 'hyperband',
    #         's': 2,
    #         'eta': 4,
    #         'max_iter': 27,
    #     },
    #     'parameters': {
    #         'model': {
    #             'values': ['/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model'] #'sentence-transformers/all-mpnet-base-v2',
    #         },
    #         'featurisation': {
    #             'values': ['ent_mark']   #
    #         },
    #         'date_featurisation': {
    #             'values': ['prepend_1']
    #         },
    #         'warm_up_perc': {
    #             'values':[0.5]
    #         }, #Best 0.182
    #         'batch_size': {
    #             'values': [256]
    #         },
    #         'epochs': {
    #             'values': [10]
    #         },
    #         'margin': {
    #             'values': [0.4]
    #         }, #best 0.4
    #         'max_seq_length': {
    #             'values': [128]
    #         },
    #         'special_tokens': {
    #             'values': [{'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}]
    #         },
    #         'lr': {
    #             'values': [0.002] #Best 2e-5
    #         },
    #         'model_save_path_name':{
    #             'values': ['/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/ent_mark']
    #         }
    #     }
    # }

    # To launch a real wandb sweep over main() instead of the direct call below:
    # # Initialize wandb sweep
    # sweep_id = wandb.sweep(sweep_config, project='entity_coref_wiki')

    # # Run the sweep
    # wandb.agent(sweep_id, function=main, count=1)

    # Run the newspaper + wiki fine-tuning once with the active configuration.
    newspaper_wiki_main(sweep_config=sweep_config)

    


