import torch
from transformers import GPT2Tokenizer, GPT2Model
from DataSequence import DataSequence
import numpy as np


# Pretrained GPT-2 tokenizer and base transformer (no LM head), loaded once at
# import time and shared by every function below.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained(
    "gpt2",
    # Only the final hidden states are consumed; don't request attentions or
    # the per-layer hidden states.
    output_attentions=False,
    output_hidden_states=False,
)

# Run on the first CUDA device when one exists, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model.to(device)

def get_hidden_states(model, story_ids, max_len, *, target_device=None):
    """Run ``model`` over ``story_ids`` in windows of at most ``max_len`` tokens.

    Sequences that fit in one window are passed through in a single call.
    Longer sequences are split into consecutive, non-overlapping windows of
    ``max_len`` tokens (the last one may be shorter). Each window is encoded
    independently, so a window sees no context from earlier windows. The
    per-window outputs are then concatenated along the sequence axis.

    Args:
        model: Callable accepting ``input_ids=`` and ``return_dict=False`` and
            returning a tuple whose first element is indexed by token along
            dim 1 (for ``GPT2Model`` this is the last hidden state).
        story_ids: Token-id tensor; dim 1 is the token axis (1 x T' here).
        max_len: Maximum number of tokens passed to ``model`` per call.
        target_device: Device to move ``story_ids`` to before encoding.
            Defaults to the module-level ``device``.

    Returns:
        The model's first output for the whole sequence, with dim 1 of
        length T'.
    """
    if target_device is None:
        target_device = device
    num_tokens = story_ids.shape[1]
    story_ids = story_ids.to(target_device)

    if num_tokens <= max_len:
        return model(input_ids=story_ids, return_dict=False)[0]

    # Iterating window starts with range() never yields an empty trailing
    # slice. The previous version did so whenever num_tokens was an exact
    # multiple of max_len, and GPT2Model fails on zero-length input.
    chunks = [
        model(input_ids=story_ids[:, start:start + max_len], return_dict=False)[0]
        for start in range(0, num_tokens, max_len)
    ]
    # A single concatenation avoids re-copying the growing result per window.
    return torch.cat(chunks, dim=1)
    

def get_vec_per_word(story_vecs, story_tokens):
    """Pool sub-word token vectors into one vector per word by averaging.

    A token starting with ``'Ġ'`` (GPT-2's marker for a preceding space)
    starts a new word. Any other token is a continuation of the current word,
    and its vector is folded into that word's running mean. The first token
    always starts a word. A lone ``'Ġ'`` token that directly follows another
    lone ``'Ġ'`` is skipped entirely (edge case noted for story_06).

    Args:
        story_vecs: Token vectors, 1 x T' x n (tensor or array-like). The
            input is not modified.
        story_tokens: The T' token strings aligned with dim 1 of
            ``story_vecs``.

    Returns:
        A detached t x n tensor with one row per detected word.
    """
    # as_tensor avoids the copy (and warning) of torch.tensor(tensor). The
    # loop below never writes in place, so the caller's data stays untouched.
    # detach() keeps the original's guarantee that the result carries no grad.
    story_vecs = torch.as_tensor(story_vecs).detach()
    word_vecs = [story_vecs[:, 0, :]]  # each entry is 1 x n

    num_subs = 1  # sub-word tokens averaged into the current word so far
    for i in range(1, len(story_tokens)):
        token = story_tokens[i]
        if token == 'Ġ' and story_tokens[i - 1] == 'Ġ':  # Edge case for story_06
            continue

        if token.startswith('Ġ'):
            num_subs = 1
            word_vecs.append(story_vecs[:, i, :])
        else:
            # Running average of all sub-word vectors of the current word.
            word_vecs[-1] = (word_vecs[-1] * num_subs + story_vecs[:, i, :]) / (num_subs + 1)
            num_subs += 1

    # One concatenation at the end instead of one per word (was quadratic).
    return torch.cat(word_vecs, dim=0)  # t x n
    

def make_semantic_model_gpt(wordseq):
    """Build a per-word GPT-2 feature DataSequence for ``wordseq``.

    The words in ``wordseq.data`` are joined with single spaces and tokenized
    with the module-level GPT-2 tokenizer. The tokens are encoded by the
    module-level model in windows of ``tokenizer.model_max_length`` tokens.
    Sub-word vectors are then averaged back to one vector per word. Shapes are
    printed and checked with asserts along the way.

    Returns:
        DataSequence holding the per-word vectors as a numpy array, plus the
        ``split_inds``, ``data_times`` and ``tr_times`` of ``wordseq``.
    """
    words = wordseq.data  # T words
    text = ' '.join(words)

    tokens = tokenizer.tokenize(text)  # T' sub-word tokens
    encoded = tokenizer.batch_encode_plus(
        [text],
        add_special_tokens=False,
        return_attention_mask=False,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"]  # 1 x T'

    with torch.no_grad():
        token_vecs = get_hidden_states(model, token_ids, tokenizer.model_max_length)  # 1 x T' x n
        print(token_ids.shape, token_vecs.shape)
        assert token_ids.shape[1] == token_vecs.shape[1], f"{token_ids.shape[1]}, {token_vecs.shape[1]}"

        word_vecs = get_vec_per_word(token_vecs, tokens)  # T x n
        print(len(words), word_vecs.shape)
        assert len(words) == word_vecs.shape[0], f"{len(words)}, {word_vecs.shape[0]}"

    features = np.array(word_vecs.to("cpu"))
    return DataSequence(features, wordseq.split_inds, wordseq.data_times, wordseq.tr_times)

if __name__ == "__main__":
    # Report the window size that make_semantic_model_gpt passes to
    # get_hidden_states when chunking long stories.
    context_window = tokenizer.model_max_length
    print(context_window)

