import os
import sys
import json
import pandas as pd
from typing import List, Dict
from collections import OrderedDict
from tqdm import tqdm
from datetime import datetime
import torch
import torch.nn as nn
from torch import Tensor
import wandb
from data_fn_disamb import prep_wikipedia_data_disamb_hn_for_pt_bienc
from prep_disamb_training_data import prep_newspaper_disamb_ft_data
import pandas as pd
import os
import sys
from tqdm import tqdm

from transformers import AutoTokenizer
import os
import pickle

import wandb
from sentence_transformers import SentenceTransformer
# Put the two directories above this script's own directory on the import path.
# The `nlp_utils` imports that follow are presumably resolved through one of
# them -- confirm against the repository layout.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
grandparentdir = os.path.dirname(parentdir)
sys.path.extend([parentdir, grandparentdir])


from nlp_utils.pytorch_bienc.train_bienc_token import train_biencoder_custom
from nlp_utils.data_fns import find_sep_token, featurize_text

# Import other necessary modules...

# Directory holding the datasets, cached features and model checkpoints used below.
_ENTITY_TRAINING_DIR = "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/"


def resolve_sweep_parameters(sweep_config):
    """Collapse a sweep-style config dict into an object of concrete values.

    Each entry of ``sweep_config["parameters"]`` is resolved as follows:

    * a dict with a ``"values"`` key -> that entry, and if it is a list, its
      first element;
    * a dict with only a ``"value"`` key -> that value, unchanged;
    * a bare list -> its first element;
    * anything else -> used as is.

    Args:
        sweep_config: Mapping with a ``"parameters"`` mapping of parameter
            names to specs in the shapes listed above. It is not modified.

    Returns:
        A ``types.SimpleNamespace`` exposing one attribute per parameter.

    Raises:
        KeyError: If ``"parameters"`` is missing, or a dict spec has neither
            ``"values"`` nor ``"value"`` (for example a distribution, which
            cannot be resolved to a single value here).
        IndexError: If a candidate list is empty.
    """
    # Local import keeps this block self-contained (no new file-level import).
    from types import SimpleNamespace

    resolved = {}
    for name, spec in sweep_config["parameters"].items():
        if isinstance(spec, dict):
            if "values" in spec:
                spec = spec["values"]
            elif "value" in spec:
                # A constant is taken verbatim, even when it is itself a list.
                resolved[name] = spec["value"]
                continue
            else:
                raise KeyError(
                    f"sweep parameter {name!r} has neither 'values' nor 'value'; "
                    "it cannot be resolved to a single value"
                )
        if isinstance(spec, list):
            # Outside a real sweep only one run happens: take the first candidate.
            spec = spec[0]
        resolved[name] = spec
    return SimpleNamespace(**resolved)


def _init_run_config(sweep_config=None):
    """Start a wandb run and return the config object shared by both entry points.

    Args:
        sweep_config: ``None`` to read hyperparameters from ``wandb.config``,
            otherwise a sweep-style dict resolved by
            ``resolve_sweep_parameters``.

    Returns:
        The config object with ``loss_fn`` set to ``'contrastive_batchhard'``.
    """
    wandb.init()
    if sweep_config is None:
        config = wandb.config
    else:
        config = resolve_sweep_parameters(sweep_config)
    config.loss_fn = 'contrastive_batchhard'
    return config


def main(sweep_config=None):
    """Train the bi-encoder on the data built by
    ``prep_wikipedia_data_disamb_hn_for_pt_bienc``.

    Args:
        sweep_config: ``None`` to take hyperparameters from ``wandb.config``,
            or a sweep-style dict whose ``"parameters"`` supply ``model``,
            ``warm_up_perc``, ``batch_size``, ``epochs``, ``margin``,
            ``max_seq_length``, ``special_tokens``, ``lr`` and
            ``model_save_path_name``.
    """
    config = _init_run_config(sweep_config)
    config.dataset_path = _ENTITY_TRAINING_DIR

    # NOTE(review): the tokenizer model, special tokens, featurisation and
    # batch size below are fixed literals and do not follow `config`, while
    # training further down reads them from `config`. Confirm this is intended.
    train_data, dev_data, _ = prep_wikipedia_data_disamb_hn_for_pt_bienc(
        dataset_path=_ENTITY_TRAINING_DIR,
        model="sentence-transformers/all-mpnet-base-v2",
        special_tokens={'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"},
        featurisation="ent_mark",
        date_featurisation="none",
        batch_type="contrastive_batchhard",
        samples_per_label=4,
        batch_size=256,
        small=False,
        save_path=os.path.join(_ENTITY_TRAINING_DIR, "outputs_within_hn_v3.pkl"),
    )

    train_biencoder_custom(
        train_data,
        dev_data,
        wandb_project="entity_coref_disamb",
        wandb_run_name="entity_coref_disamb",
        model_name=config.model,
        wandb_log=True,
        warmup_perc=config.warm_up_perc,  # earlier note: 0.18 0.3 was best
        batch_size=config.batch_size,
        num_epochs=config.epochs,
        margin=config.margin,
        max_seq_len=config.max_seq_length,
        inter_eval_steps=176,
        special_tokens=config.special_tokens,
        learning_rate=config.lr,
        model_save_path=config.model_save_path_name,
        sample_eval_loader_steps=200,
        shuffle_train=False,
        start_epoch_step=0,
    )


def newspaper_finetune(sweep_config=None):
    """Train the bi-encoder on the data built by ``prep_newspaper_disamb_ft_data``.

    Args:
        sweep_config: ``None`` to take hyperparameters from ``wandb.config``,
            or a sweep-style dict whose ``"parameters"`` supply ``model``,
            ``warm_up_perc``, ``batch_size``, ``epochs``, ``margin``,
            ``max_seq_length``, ``special_tokens``, ``lr`` and
            ``model_save_path_name``.
    """
    config = _init_run_config(sweep_config)
    config.triple_path = os.path.join(_ENTITY_TRAINING_DIR, "newspaper_disamb_data.pkl")

    wiki_firstpara_path = os.path.join(
        _ENTITY_TRAINING_DIR,
        "formatted_first_para_data_qid_template_people_with_qrank_3occupations.json",
    )
    disamb_dict_path = os.path.join(_ENTITY_TRAINING_DIR, "negatives_family_disamb.pkl")
    reindexed_news_data_path = os.path.join(
        _ENTITY_TRAINING_DIR, "newspaper_data_reindexed.json"
    )

    # Load the data
    train_data, dev_data, test_data = prep_newspaper_disamb_ft_data(
        wiki_firstpara_path,
        disamb_dict_path,
        reindexed_news_data_path,
        featurisation="ent_mark",
        date_featurisation="none",
        special_tokens=config.special_tokens,
        model="sentence-transformers/all-mpnet-base-v2",
        max_seq_length=config.max_seq_length,
        output_path=os.path.join(
            _ENTITY_TRAINING_DIR,
            "newspaper_disamb_data_nowikihn_only_people_lesscontx.pkl",
        ),
    )

    # Train
    # NOTE(review): `test_data` is passed as a third positional argument here,
    # whereas `main` passes only two; confirm against the signature of
    # `train_biencoder_custom`.
    train_biencoder_custom(
        train_data,
        dev_data,
        test_data,
        wandb_project="entity_coref_disamb",
        wandb_run_name="entity_coref_disamb",
        model_name=config.model,
        wandb_log=True,
        warmup_perc=config.warm_up_perc,  # earlier note: 0.18 0.3 was best
        batch_size=config.batch_size,
        num_epochs=config.epochs,
        margin=config.margin,
        max_seq_len=config.max_seq_length,
        inter_eval_steps=32,
        special_tokens=config.special_tokens,
        learning_rate=config.lr,
        model_save_path=config.model_save_path_name,
        sample_eval_loader_steps=200,
        shuffle_train=True,
        start_epoch_step=0,
    )
        
    
    
if __name__ == "__main__":
    # Hyperparameters laid out like a W&B sweep definition. Every parameter
    # lists exactly one candidate value.
    sweep_config = {
        "method": "random",
        "metric": {"name": "inter_eval/accuracy", "goal": "maximize"},
        "early_terminate": {"type": "hyperband", "s": 2, "eta": 4, "max_iter": 27},
        "parameters": {
            # Alternative starting point: 'sentence-transformers/all-mpnet-base-v2'
            "model": {
                "values": [
                    "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_epoch_1"
                ]
            },
            "featurisation": {"values": ["ent_mark"]},
            "date_featurisation": {"values": ["none"]},
            "warm_up_perc": {"values": [0.2]},  # Best 0.182
            "batch_size": {"values": [256]},
            "epochs": {"values": [3]},
            "margin": {"values": [0.4]},  # best 0.4
            "max_seq_length": {"values": [256]},
            "special_tokens": {
                "values": [{"men_start": "[M]", "men_end": "[/M]", "men_sep": "[MEN]"}]
            },
            "lr": {"values": [2e-6]},
            "model_save_path_name": {
                "values": [
                    "/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/cgis_model_ent_mark_incontext_disamb_tuned_nodate_shuffled_2e6"
                ]
            },
        },
    }

    # The dict is passed straight to the fine-tuning entry point; no sweep is
    # registered or launched from here.
    newspaper_finetune(sweep_config=sweep_config)
    
    

    


