import argparse
import json
import math

import numpy as np
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer, BertConfig

# Default encoder used when --model is not given.
DEFAULT_MODEL = "ai4bharat/indic-bert"


def _parse_args():
    """Parse and validate the command-line options of this script.

    Returns:
        argparse.Namespace with ``i``, ``o``, ``b``, ``maxtoks``, ``savefreq``
        and ``model`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Encode the 'Sentence' column of a CSV file and pickle the "
        "pooled model output together with the 'Tag' column, in chunks."
    )
    parser.add_argument("-i", default=None, type=str, required=True, help="Dataset csv file")
    # The output path is only a prefix for the chunk files. Without it the old
    # script crashed after encoding the first chunk, so it is now mandatory.
    parser.add_argument(
        "-o",
        type=str,
        required=True,
        help="Output path prefix; chunks are written as <prefix>_0, <prefix>_1, ... "
        "(pickled DataFrames with 'Feature' and 'Tag' columns)",
    )
    parser.add_argument("-b", type=int, default=50, help="Batchsize (50)")
    parser.add_argument(
        "--maxtoks", type=int, default=200,
        help="Maximum number of tokens per sentence to use (200)",
    )
    parser.add_argument(
        "--savefreq", type=int, default=1024,
        help="Number of sentences per output chunk (1024)",
    )
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL,
        help=f"Hugging Face model name or local path ({DEFAULT_MODEL})",
    )
    args = parser.parse_args()
    # A non-positive value would make range() raise or make the loops meaningless.
    for flag, value in (("-b", args.b), ("--maxtoks", args.maxtoks), ("--savefreq", args.savefreq)):
        if value < 1:
            parser.error(f"{flag} must be a positive integer, got {value}")
    return args


def _embed_sentences(sentences, tokenizer, model, batchsize, maxtoks):
    """Encode ``sentences`` and return the model's ``pooler_output`` for each one.

    Args:
        sentences: non-empty list of strings to encode.
        tokenizer: tokenizer matching ``model``.
        model: encoder whose output exposes ``"pooler_output"``.
        batchsize: number of sentences sent to the model per forward pass.
        maxtoks: maximum tokens per sentence, special tokens included; longer
            sentences are truncated.

    Returns:
        numpy array with one row per sentence, in input order.
    """
    pooled = []
    for start in range(0, len(sentences), batchsize):
        batch = sentences[start:start + batchsize]
        # The tokenizer adds [CLS]/[SEP] itself. Adding them to the raw text
        # as well (as this script used to) doubles the special tokens.
        encoded = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=maxtoks,
            return_tensors="pt",
        )
        # Inference only: without no_grad every batch's autograd graph stays
        # alive for as long as its output tensor is referenced.
        with torch.no_grad():
            # Passing the attention mask keeps padding from affecting results.
            output = model(**encoded)
        pooled.append(output["pooler_output"].cpu().numpy())
    return np.concatenate(pooled, axis=0)


def _save_chunk(features, tags, path):
    """Pickle one chunk as a DataFrame with 'Feature' (one 1-D array per row) and 'Tag'."""
    rows = [[feature, tag] for feature, tag in zip(features, tags)]
    pd.DataFrame(rows, columns=["Feature", "Tag"]).to_pickle(path)


def main():
    """Read the CSV given by -i, encode its sentences and write the chunk files.

    Chunk ``k`` holds rows ``[k * savefreq, (k + 1) * savefreq)`` of the CSV
    and is written to ``f"{o}_{k}"``. The last chunk may be shorter.
    """
    args = _parse_args()
    df = pd.read_csv(args.i)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.eval()  # make sure dropout is disabled during feature extraction

    sentences = list(df["Sentence"])
    tags = list(df["Tag"])
    for chunk, start in enumerate(range(0, len(sentences), args.savefreq)):
        stop = start + args.savefreq
        chunk_sentences = sentences[start:stop]
        features = _embed_sentences(chunk_sentences, tokenizer, model, args.b, args.maxtoks)
        _save_chunk(features, tags[start:stop], f"{args.o}_{chunk}")
        print("Saved", len(chunk_sentences))


if __name__ == "__main__":
    main()