import os
import sys
import json
import pandas as pd
from typing import List, Dict
from collections import OrderedDict
from tqdm import tqdm
from datetime import datetime
import torch
import torch.nn as nn
from torch import Tensor
import wandb
from data_fns import prep_wikipedia_data_for_pt_bienc

import pandas as pd
import os
import sys
from tqdm import tqdm

from transformers import AutoTokenizer
import os

import wandb
# Locate this script's directory and walk two levels up, then append the
# parent and grandparent directories (in that order) to the module search path.
# NOTE(review): presumably this is what lets the ``nlp_utils`` package imported
# right after this be found when the file is run directly — confirm against the
# repository layout.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir, grandparentdir = (
    os.path.dirname(currentdir),
    os.path.dirname(os.path.dirname(currentdir)),
)
sys.path.extend((parentdir, grandparentdir))


from nlp_utils.pytorch_bienc.train_bienc_token import train_biencoder_custom
from nlp_utils.data_fns import find_sep_token, featurize_text

# Import other necessary modules...

# Root directory for the entity-training data, cached featurised outputs and
# model checkpoints; every hard-coded path in main() is derived from it.
ENTITY_TRAINING_DIR = '/mnt/122a7683-fa4b-45dd-9f13-b18cc4f4a187/entity_training/'

# Mention-marker tokens. The same mapping is handed to data featurisation and
# to the trainer, so it is defined once to keep both call sites in agreement.
SPECIAL_TOKENS = {'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}


def main():
    """Featurise the Wikipedia coreference data and train the bi-encoder.

    Starts a wandb run and records ``loss_fn`` and ``dataset_path`` in its
    config. Every other hyperparameter is fixed in this function; nothing else
    is read from ``wandb.config``, so a wandb sweep would not change them.
    """
    wandb.init()
    config = wandb.config
    config.loss_fn = 'contrastive_batchhard'
    config.dataset_path = ENTITY_TRAINING_DIR

    # NOTE(review): features are batched with 640 examples but training runs
    # with 704. Both are multiples of samples_per_label (4), yet confirm the
    # mismatch is intentional for batch-hard sampling.
    prep_batch_size = 640
    train_batch_size = 704

    # Each callee gets its own copy of SPECIAL_TOKENS so neither can mutate
    # the shared module-level constant.
    train_data, dev_data, _ = prep_wikipedia_data_for_pt_bienc(
        config.dataset_path,
        "sentence-transformers/all-mpnet-base-v2",
        dict(SPECIAL_TOKENS),
        featurisation='prepend',
        disamb_or_coref='coref',
        batch_type=config.loss_fn,
        samples_per_label=4,
        batch_size=prep_batch_size,
        small=False,
        save_path=os.path.join(ENTITY_TRAINING_DIR, "outputs.pkl"),
    )

    print("Training...")
    # The model is loaded from and saved to the same directory; together with
    # start_epoch_step and shuffle_train=False this looks like a resumed run.
    # NOTE(review): saves may overwrite the starting checkpoint — confirm.
    checkpoint_dir = os.path.join(ENTITY_TRAINING_DIR, "cgis_model")
    train_biencoder_custom(
        train_data,
        dev_data,
        wandb_project="biencoder_pt",
        wandb_run_name="biencoder_pt_1",
        model_name=checkpoint_dir,  # base alternative: "sentence-transformers/all-mpnet-base-v2"
        wandb_log=False,
        warmup_perc=0.3,  # tried 0.18 and 0.3; 0.3 was best
        batch_size=train_batch_size,
        num_epochs=10,
        margin=0.4,
        max_seq_len=128,
        inter_eval_steps=train_batch_size // 4,  # evaluate every 176 steps
        special_tokens=dict(SPECIAL_TOKENS),
        learning_rate=2e-6,
        model_save_path=checkpoint_dir,
        sample_eval_loader_steps=30,
        shuffle_train=False,
        start_epoch_step=481,
    )

if __name__ == '__main__':
    # Script entry point: runs a single training job via main().
    #
    # The two commented-out dicts below are wandb sweep configurations kept for
    # reference:
    #   1. a random search over base models, warm-up fraction, margin and
    #      learning rate, with hyperband early termination;
    #   2. a single-point config pinned to the values annotated as "Best".
    # Both are disabled, together with the wandb.sweep / wandb.agent calls at
    # the end of this block.
    # NOTE(review): main() currently hard-codes its hyperparameters instead of
    # reading them from wandb.config, so re-enabling a sweep would not vary
    # anything until main() reads these keys from the config again.

    # sweep_config = {
    #     'method': 'random',
    #     'metric': {
    #         'name': 'inter_eval/best_accuracy_eval',
    #         'goal': 'maximize'
    #     },
    #     'early_terminate': {
    #         'type': 'hyperband',
    #         's': 2,
    #         'eta': 4,
    #         'max_iter': 27,
    #     },
    #     'parameters': {
    #         'model': {
    #             'values': ['sentence-transformers/all-MiniLM-L6-v2', 'sentence-transformers/all-MiniLM-L12-v2', 'sentence-transformers/multi-qa-distilbert-cos-v1','sentence-transformers/multi-qa-MiniLM-L6-cos-v1'] #'sentence-transformers/all-mpnet-base-v2',
    #         },
    #         'featurisation': {
    #             'values': ['prepend']
    #         },
    #         'warm_up_perc': {
    #             'distribution': 'uniform',
    #             'min': 0.0,
    #             'max': 0.5
    #         }, #Best 0.182
    #         'batch_size': {
    #             'values': [512]
    #         },
    #         'epochs': {
    #             'values': [2]
    #         },
    #         'margin': {
    #             'distribution': 'uniform',
    #             'min': 0.3,
    #             'max': 0.6
    #         }, #best 0.4
    #         'max_seq_length': {
    #             'values': [128]
    #         },
    #         'special_tokens': {
    #             'values': [{'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}]
    #         },
    #         'lr': {
    #             'values': [2e-6,2e-5] #Best 2e-5
    #         }
    #     }
    # }

    # sweep_config = {
    #     'method': 'random',
    #     'metric': {
    #         'name': 'inter_eval/best_accuracy_eval',
    #         'goal': 'maximize'
    #     },
    #     'parameters': {
    #         'model': {
    #             'values': ['sentence-transformers/all-mpnet-base-v2'] #'',
    #         },
    #         'featurisation': {
    #             'values': ['prepend']
    #         },
    #         'warm_up_perc': {
    #             'values': [0.182] #Best 0.182
    #         }, #Best 0.182
    #         'batch_size': {
    #             'values': [512]
    #         },
    #         'epochs': {
    #             'values': [2]
    #         },
    #         'margin': {
    #             'values': [0.4] #best 0.4
    #         },
    #         'max_seq_length': {
    #             'values': [128]
    #         },
    #         'special_tokens': {
    #             'values': [{'men_start': "[M]", 'men_end': "[/M]", 'men_sep': "[MEN]"}]
    #         },
    #         'lr': {
    #             'values': [2e-5] #Best 2e-5
    #         }
    #     }
    # }


    # # Initialize wandb sweep
    # sweep_id = wandb.sweep(sweep_config, project='entity_coref')

    # # Run the sweep
    # wandb.agent(sweep_id, function=main, count=1)

    # Direct (non-sweep) run.
    main()


