import torch
import json
import time
import os
import traceback

from tqdm import tqdm
from random import shuffle, randint
from transformers import BertTokenizer, BertModel

from source.model import noname
from source.parser import parse
from source.count import counter
from source.sampler import sampler, sampler_multidataset
from source.function import *
from test import test, test_zero_shot

def train(config, model, data, types, epoch_num : int, lr : float = 1e-5):
    """Run one supervised (n-way k-shot) training epoch and print its stats.

    Args:
        config: parsed run configuration; this function reads
            ``config.batch_size`` and ``config.dev`` (device string).
        model: the (possibly ``DataParallel``-wrapped) model, called as
            ``model(model_input, pos, mask_tensor=...)`` and returning a
            2-tuple whose first element is the prediction tensor.
        data: training samples handed to ``sampler``.
        types: type inventory handed to ``sampler`` and ``get_input``.
        epoch_num: epoch index; only used for the banner printed at the start.
        lr: Adam learning rate.

    A batch that raises is reported with a traceback, counted and skipped,
    so one malformed batch does not abort the epoch. After the epoch the
    average per-batch loss, the number of failed batches and the counter's
    metrics are printed.

    Relies on the module-level ``tokenizer`` that ``main`` assigns.
    """
    model.train()
    print("\n===============Epoch.{}===============\n".format(epoch_num))

    # NOTE(review): a fresh Adam optimizer is built on every call, so its
    # moment estimates are reset at the start of each epoch.
    optim = torch.optim.Adam(params=model.parameters(), lr = lr)
    my_sampler = sampler(data, types, config.batch_size)

    cnter = counter()
    item_cnt = 0
    avg_loss = 0
    except_cnt = 0

    for samp in my_sampler.n_way_k_shot():
        optim.zero_grad()

        try:
            model_input, mask_tensor, ans, pos = get_input(samp, tokenizer, types)
            model_input = model_input.to(config.dev)
            mask_tensor = mask_tensor.to(config.dev)
            # NOTE(review): the similarity target is not used by this
            # supervised loss; the call is kept so that batches on which
            # get_sim fails are still skipped exactly as before.
            get_sim(samp).to(config.dev)
            ans = ans.to(config.dev)

            model_output, _ = model(model_input, pos, mask_tensor = mask_tensor)
            loss = loss_function(model_output, ans)

            cnter.count(model_output.view(-1).tolist(), ans.view(-1).tolist())

            # Out-of-place division: in-place ops on autograd outputs can
            # invalidate tensors that were saved for the backward pass.
            loss = loss / len(samp)
            avg_loss += loss.item()
            loss.backward()
            optim.step()

            item_cnt += 1
        except Exception:
            traceback.print_exc()
            except_cnt += 1

    # item_cnt is 0 when every batch failed or the sampler yielded nothing;
    # guard the division instead of crashing at the end of the epoch.
    if item_cnt:
        print('avg_loss = ', avg_loss / item_cnt)
    else:
        print('avg_loss = ', 'n/a (no successful batch)')
    print('except = ', except_cnt)
    cnter.output()
    cnter.clear()

def contrastive_train_batch(config, model, data, types, epoch_num : int , lr : float = 1e-5, report_every : int = 100):
    """Run ``epoch_num`` epochs of contrastive training on the model's embeddings.

    Each batch from ``sampler_multidataset(...).get_batches()`` is encoded with
    ``get_input(..., True)``. The second element of the model's output is
    trained against the targets from ``get_sim(samp)`` via
    ``sim_loss_function``.

    Args:
        config: run configuration; reads ``config.batch_size`` and ``config.dev``.
        model: model called as ``model(model_input, pos, mask_tensor=...)``,
            returning a 2-tuple whose second element is fed to the loss.
        data: datasets handed to ``sampler_multidataset`` (``main`` passes a
            list of line lists).
        types: type inventory handed to the sampler and ``get_input``.
        epoch_num: number of epochs to run.
        lr: Adam learning rate.
        report_every: after this many successful batches, print the running
            average loss and the ``succeeded : failed`` batch tally.

    Raises:
        ValueError: if ``report_every`` is not positive.

    Relies on the module-level ``tokenizer`` that ``main`` assigns.
    """
    if report_every <= 0:
        raise ValueError('report_every must be positive, got {}'.format(report_every))

    model.train()

    for i in range(epoch_num):
        print("\n===============Epoch.{}===============\n".format(i))

        # NOTE(review): a fresh Adam optimizer per epoch resets its moments.
        optim = torch.optim.Adam(params=model.parameters(), lr = lr)
        # Statistics for the current reporting window; reset after each report.
        window_loss = 0
        succeed_cnt = 0
        failed_cnt = 0

        my_sampler = sampler_multidataset(data, types, batch_size = config.batch_size)
        print('ok')

        for samp in my_sampler.get_batches():
            try:
                optim.zero_grad()

                model_input, mask_tensor, _, pos = get_input(samp, tokenizer, types, True)
                model_input = model_input.to(config.dev)
                mask_tensor = mask_tensor.to(config.dev)
                sim = get_sim(samp).to(config.dev)

                _, embeddings = model(model_input, pos, mask_tensor = mask_tensor)
                loss = sim_loss_function(embeddings, sim)

                window_loss += loss.item()
                loss.backward()
                optim.step()

                succeed_cnt += 1
            except Exception:
                # Skip the bad batch, but keep the traceback so the failure
                # can actually be diagnosed.
                print('error')
                traceback.print_exc()
                failed_cnt += 1

            if succeed_cnt == report_every:
                print('avg_loss = ', window_loss / succeed_cnt)
                print(succeed_cnt, ':', failed_cnt)
                window_loss = 0
                succeed_cnt = 0
                failed_cnt = 0

        # Report the trailing partial window. It can contain zero successful
        # batches (e.g. when the success count was an exact multiple of
        # report_every), so the division must be guarded.
        if succeed_cnt:
            print('avg_loss = ', window_loss / succeed_cnt)
        if succeed_cnt or failed_cnt:
            print(succeed_cnt, ':', failed_cnt)

def main():
    """Entry point: build tokenizer and model from the parsed config, then run one mode.

    Modes, checked in this order:
      * ``config.con``: one epoch of contrastive training on
        ``./data/distant_zh_short.json`` (lr=1e-6) with BERT fine-tuning
        enabled. The model is then saved to ``./model/<config.save>.pth``.
      * ``config.zero_shot``: zero-shot evaluation via ``test_zero_shot``. The
        result is printed and optionally appended to ``./result/<config.log>``.
      * Otherwise: one warm-up epoch (lr=1e-2) with BERT fine-tuning disabled,
        then ``config.epoch`` epochs (lr=1e-5) with it enabled.
        - ``test`` runs after the last epoch, or after every epoch when
          ``config.showall`` is set.
        - The best result is printed and optionally appended to
          ``./result/<config.log>``.
        - The model is saved to ``./model/<config.save>.pth``.
    """
    # train() and contrastive_train_batch() read ``tokenizer`` as a module global.
    global tokenizer

    config = parse()
    config.dev = 'cuda:' + str(config.cuda[0])
    datapath = './data'

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    tokenizer.add_special_tokens({'additional_special_tokens': ["<ent>", "<blank>"]})

    types = loads_from_file(os.path.join(datapath, config.dataset, 'types.json'))
    # JSON text is UTF-8; `with` closes the handle deterministically.
    with open(os.path.join(datapath, config.dataset, config.data + '.json'), 'r', encoding='utf-8') as data_file:
        data = data_file.readlines()

    model = noname(len(tokenizer), len(types)).to(config.dev)
    if config.load:
        # Only the ``bert`` submodule is taken from the checkpoint, which is a
        # whole pickled module (this script saves checkpoints via torch.save).
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted files. Newer PyTorch releases default to weights_only=True,
        # which rejects whole-module checkpoints like these.
        checkpoint_path = os.path.join('./model', config.load + '.pth')
        model_tmp = torch.load(checkpoint_path, map_location=lambda storage, loc: storage).to(config.dev)
        model.bert = model_tmp.bert

    model = torch.nn.DataParallel(model, device_ids=config.cuda)

    if config.con:
        c_data = []
        with open(os.path.join(datapath, 'distant_zh_short.json'), 'r', encoding='utf-8') as c_file:
            c_data.append(c_file.readlines())
        # Other distant-supervision corpora in ./data ('distant.json',
        # 'distant_cross.json') can be appended to c_data the same way.
        print('loaded')

        model.module.activate_bert_fine_tuning(True)
        contrastive_train_batch(config, model, c_data, types, 1, lr=1e-6)

        torch.save(model.module, os.path.join('./model', config.save + '.pth'))
    elif config.zero_shot:
        model.module.activate_bert_fine_tuning(False)
        maxf1 = test_zero_shot(config, datapath, model, tokenizer, data, types)
        print('maxf1 = ', maxf1)
        if config.log:
            with open(os.path.join('./result/', config.log), 'a') as log_file:
                print(maxf1, file=log_file)
    else:
        # Warm-up epoch (index 0) at a high learning rate with BERT
        # fine-tuning switched off.
        model.module.activate_bert_fine_tuning(False)
        train(config, model, data, types, 0, lr=1e-2)
        model.module.activate_bert_fine_tuning(True)

        # Best score seen so far. This assumes test() returns a tuple that
        # compares lexicographically with (0, 0, 0) -- TODO confirm.
        maxf1 = (0, 0, 0)

        for i in range(config.epoch):
            train(config, model, data, types, i, lr=1e-5)

            if i == config.epoch - 1 or config.showall:
                # Evaluate with fine-tuning switched off, then restore it.
                model.module.activate_bert_fine_tuning(False)
                maxf1 = max(maxf1, test(model, datapath, tokenizer, types, config))
                model.module.activate_bert_fine_tuning(True)
                print('===========')

        print('maxf1 = ', maxf1)
        if config.log:
            with open(os.path.join('./result/', config.log), 'a') as log_file:
                print(maxf1, file=log_file)

        torch.save(model.module, os.path.join('./model/', config.save + '.pth'))

if __name__ == '__main__':
    # Run main() only when this file is executed as a script, not on import.
    main()
