# -*- coding: UTF-8 -*-

import os
import time
import torch
import spacy
import functools
import torch.nn.functional as F

from torchtext.data.utils import get_tokenizer, ngrams_iterator
from classifier import Classifier, CORPUS_PATH, classification_report

from torch import nn, autograd
from torchtext import data
from torchtext import datasets
from torchtext.vocab import build_vocab_from_iterator

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from transformers import BertTokenizer, BertForSequenceClassification
#from torch.multiprocessing import set_start_method
#try:
#    set_start_method('spawn')
#except RuntimeError:
#    pass

# Hard-coded GPU index: collate_func places every batch tensor on this device
# and main() moves the model here.
# NOTE(review): index 7 presumably requires at least eight visible CUDA
# devices; using it elsewhere will error once tensors are moved there —
# consider making this configurable.
device = torch.device('cuda:7')
# spaCy English pipeline, loaded eagerly at import time. In this file it is
# only used by _spacy_tokenize, which main() currently leaves disabled
# (tokenizer = None).
nlp = spacy.load('en_core_web_sm')

class CNN(nn.Module, Classifier):
    """Convolutional text classifier over learned token embeddings.

    One Conv2d is built per window size. Its kernel spans the full embedding
    width, so it only slides along the token axis. Each feature map is
    ReLU-activated and max-pooled over its whole length. The pooled values from
    all window sizes are concatenated and fed to a single linear layer that
    produces class logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding (also the kernel width).
        num_filters: output channels of every convolution.
        num_class: number of output logits.
        window_sizes: kernel heights, i.e. how many consecutive tokens each
            convolution looks at.
    """

    def __init__(self, vocab_size, embedding_dim, num_filters, num_class, window_sizes=(3, 4, 5)):
        # Initialise each base class explicitly.
        nn.Module.__init__(self)
        Classifier.__init__(self)

        # Submodules are created in a fixed order (embedding, convs, linear);
        # their random initialisation consumes the RNG in that order.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        conv_layers = []
        for height in window_sizes:
            # Zero-pad (height - 1) rows on both ends of the token axis so that
            # even very short sequences yield at least one output position.
            conv_layers.append(
                nn.Conv2d(1, num_filters, [height, embedding_dim], padding=(height - 1, 0))
            )
        self.convs = nn.ModuleList(conv_layers)
        self.fully_connected = nn.Linear(num_filters * len(window_sizes), num_class)

    def forward(self, text, offset):
        """Return unnormalised class logits for a batch of token ids.

        Args:
            text: integer token ids. Assumed to be (batch, seq_len), as built by
                pad_sequence(batch_first=True) in this file — TODO confirm for
                other callers.
            offset: not used by the network; accepted so callers can pass the
                collate output unchanged.

        Returns:
            Logits of shape (text.size(0), num_class).
        """
        # Insert a singleton channel axis so the embeddings become a
        # one-channel "image" for Conv2d.
        grid = self.embedding(text).unsqueeze(1)
        pooled = []
        for conv in self.convs:
            # The kernel covers the whole embedding width, so the last axis has
            # size 1 and can be dropped.
            activation = F.relu(conv(grid)).squeeze(-1)
            # Max over the entire remaining length: one value per filter.
            pooled.append(F.max_pool1d(activation, activation.size(2)))
        features = torch.cat(pooled, 2)
        flat = features.view(features.size(0), -1)
        return self.fully_connected(flat)


def train(model, dataloader, optimizer, criterion, epoch):
    """Run one training epoch over *dataloader*, updating *model* in place.

    Each batch is a ``(label, text, offset)`` triple as produced by
    ``collate_func``:
      - ``text`` holds every sample's token ids concatenated into one 1-D tensor.
      - ``offset`` holds the start index of each sample within ``text``.

    Before the forward pass, the flat tensor is split back into samples and
    right-padded with 0 into a ``(batch, max_len)`` tensor.

    Gradients are clipped to a total norm of 0.1. Running accuracy is printed
    at every positive multiple of 50 batches and then reset. A batch that
    raises is reported (traceback printed) and skipped, so that one bad batch
    does not abort the epoch.

    Args:
        model: module called as ``model(padded_text, offset)`` returning logits.
        dataloader: iterable of ``(label, text, offset)`` batches; ``len()`` is
            used in the progress message.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, label)``.
        epoch: epoch number, used only for logging.
    """
    import traceback

    model.train()
    total_acc = 0
    total_cnt = 0
    log_interval = 50
    for idx, (label, text, offset) in enumerate(dataloader):
        try:
            optimizer.zero_grad()
            # Sample i spans text[bounds[i]:bounds[i + 1]]. The closing bound
            # is the full length of the flat tensor; a -1 sentinel here would
            # silently drop the last token of the last sample.
            bounds = offset.tolist() + [text.size(0)]
            # NOTE(review): pad_sequence pads with 0, which is also a real
            # vocabulary id unless the vocab reserves index 0 for padding —
            # confirm.
            padded = nn.utils.rnn.pad_sequence(
                [text[start:end] for start, end in zip(bounds[:-1], bounds[1:])],
                batch_first=True,
            )
            out = model(padded, offset)
            loss = criterion(out, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), .1)
            optimizer.step()
            total_acc += (out.argmax(1) == label).sum().item()
            total_cnt += label.size(0)
            if idx % log_interval == 0 and idx > 0:
                print('| epoch {:3d} | {:5d}/{:5d} batches | accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc/total_cnt))
                total_acc, total_cnt = 0, 0
        except Exception as exception:
            # Deliberate best-effort: report the failing batch and move on.
            traceback.print_exc()
            print(exception)


def evaluate(model, dataloader, criterion, clz):
    """Score *model* on *dataloader*, print a per-class report, return accuracy.

    Batches use the same flat ``(label, text, offset)`` layout as ``train``.
    They are unpacked and zero-padded the same way, and run under
    ``torch.no_grad()`` in eval mode.

    Args:
        model: module called as ``model(padded_text, offset)`` returning logits.
        dataloader: iterable of ``(label, text, offset)`` batches.
        criterion: accepted for signature compatibility with ``train``. The loss
            it produced was never used, so it is no longer computed.
        clz: sequence mapping class index -> class name, used for the report.

    Returns:
        Fraction of samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    total_acc = 0
    total_cnt = 0
    labels = []
    outputs = []
    with torch.no_grad():
        for label, text, offset in dataloader:
            # The closing bound is the full flat length; a -1 sentinel would
            # drop the last token of the last sample.
            bounds = offset.tolist() + [text.size(0)]
            padded = nn.utils.rnn.pad_sequence(
                [text[start:end] for start, end in zip(bounds[:-1], bounds[1:])],
                batch_first=True,
            )
            predicted = model(padded, offset).argmax(1)
            total_acc += (predicted == label).sum().item()
            total_cnt += label.size(0)
            labels.extend(clz[i] for i in label.tolist())
            outputs.extend(clz[i] for i in predicted.tolist())
    if total_cnt == 0:
        raise ValueError('evaluation dataloader yielded no samples')
    print(classification_report(labels, outputs, zero_division=0, digits=4))
    return total_acc / total_cnt


def predict(model, sentence):
    """Placeholder for single-sentence inference; not implemented yet.

    Args:
        model: trained classifier (currently ignored).
        sentence: raw input text (currently ignored).

    Returns:
        Always ``None``.
    """
    return None


def txt_iterator(data_path, ngrams, tokenizer):
    """Yield an n-gram iterator for each eligible corpus file in *data_path*.

    Only files directly inside *data_path* are considered; subdirectories are
    not descended into. File names are split on '-' into four fields, the
    first being an integer. Files whose number exceeds 10000 are skipped
    without being opened.

    Args:
        data_path: directory holding the preprocessed text files.
        ngrams: ``n`` passed to ``ngrams_iterator``.
        tokenizer: callable applied to the whitespace-normalised file text, or
            None to split the text on whitespace.

    Yields:
        ``ngrams_iterator(tokens, ngrams)`` for each eligible file.

    Raises:
        FileNotFoundError: if *data_path* cannot be listed.
        ValueError: if a file name does not have exactly four '-'-separated
            fields, or its first field is not an integer.
    """
    listing = next(os.walk(data_path), None)
    if listing is None:
        # os.walk ignores listing errors and simply yields nothing. Without
        # this check, next() would raise StopIteration, which a generator turns
        # into an opaque RuntimeError.
        raise FileNotFoundError('cannot list directory: {}'.format(data_path))
    enclosing, _, fnames = listing
    for fname in fnames:
        # Filter on the file name before touching the file.
        num, _clz, _lang, _sha256 = fname.split('-')
        if int(num) > 10000:
            continue
        # Read and close the file before yielding, so it is not held open
        # while the consumer works through the tokens.
        with open(os.path.join(enclosing, fname), encoding='utf-8') as f:
            content = f.read()
        # str.split() already discards surrounding whitespace, so the
        # resulting tokens need no extra strip().
        tokens = tokenizer(' '.join(content.split())) if tokenizer else content.split()
        yield ngrams_iterator(tokens, ngrams)


def _spacy_tokenize(sent):
    """Lemmatise *sent* with the module-level spaCy pipeline, skipping stop words.

    Returns:
        The ``lemma_`` of every token spaCy does not flag as a stop word, in
        document order.
    """
    lemmas = []
    for tok in nlp(sent):
        if tok.is_stop:
            continue
        lemmas.append(tok.lemma_)
    return lemmas


def collate_func(pipelines, batch, target_device=None):
    """Turn a list of ``(label, text)`` samples into flat batch tensors.

    Args:
        pipelines: dict with a ``'text'`` callable (raw text -> sequence of
            token ids) and a ``'label'`` callable (raw label -> class index).
        batch: iterable of ``(label, text)`` pairs as yielded by the DataLoader.
        target_device: device for the returned tensors. Defaults to the
            module-level ``device``. Pass a CPU device to collate on the host,
            e.g. inside DataLoader worker processes.

    Returns:
        ``(label, text, offset)``, all int64 tensors:
          - ``label``: one class index per sample.
          - ``text``: every sample's token ids concatenated into a single 1-D
            tensor.
          - ``offset``: the start position of each sample within ``text``.
    """
    if target_device is None:
        target_device = device
    text_pipeline = pipelines['text']
    label_pipeline = pipelines['label']
    labels, token_ids, lengths = [], [], []
    for label_, text_ in batch:
        ids = text_pipeline(text_)
        token_ids.extend(ids)
        lengths.append(len(ids))
        labels.append(label_pipeline(label_))
    # Build each tensor once from host-side lists. This avoids one small
    # device tensor per sample, and avoids re-wrapping 0-d device tensors with
    # torch.tensor (which forces a host sync per element).
    text = torch.tensor(token_ids, dtype=torch.int64, device=target_device)
    label = torch.tensor(labels, dtype=torch.int64, device=target_device)
    # Start offsets are exclusive prefix sums of the lengths: 0, l0, l0+l1, ...
    starts = [0] + lengths[:-1] if lengths else []
    offset = torch.tensor(starts, dtype=torch.int64, device=target_device).cumsum(dim=0)
    return label, text, offset


def main():
    """Build the vocabulary, then train and evaluate the CNN classifier.

    Steps:
      1. Build a vocabulary from the preprocessed corpus (with ngrams=2).
      2. Load and split the dataset through ``model.load_dataset()`` /
         ``model.split_dataset()``.
      3. Wrap the train/test splits in DataLoaders.
      4. Train for ``N_EPOCHS`` epochs with SGD.

    After every epoch the test split is evaluated. The learning rate is decayed
    by 0.9 whenever that accuracy falls below the best accuracy seen so far.
    """
    # Earlier, larger label set kept for reference:
    # ('Others', 'Gambling', 'Hacking', 'Porn', 'Drugs', 'Violence', 'Arms',
    #  'Financial', 'Crypto', 'Goods', 'Leaks', 'Multiple', 'Electronic')
    classes = ('Others', 'Gambling', 'Hacking', 'Porn', 'Drugs', 'Violence', 'Financial', 'Crypto', 'Electronic', 'Arms')
    # Alternatives: _spacy_tokenize, or get_tokenizer('spacy', 'en_core_web_sm').
    tokenizer = None
    vocab = build_vocab_from_iterator(txt_iterator(os.path.join(CORPUS_PATH, 'txt_preprocessed'), ngrams=2, tokenizer=tokenizer))

    def text_pipeline(x):
        # Map each token to its vocab id and keep at most the first 256 ids.
        # NOTE(review): with tokenizer=None this iterates x directly. If the
        # dataset's X values are raw strings, that yields characters rather
        # than the whitespace tokens the vocab was built from. Bigram entries
        # (ngrams=2) are also never produced here. Confirm the type of X.
        tokens = tokenizer(x) if tokenizer else x
        return [vocab[token] for token in tokens][:256]

    label_pipeline = classes.index

    model = CNN(vocab_size=len(vocab), embedding_dim=128, num_filters=128, num_class=len(classes))
    model.load_dataset()
    model.split_dataset()
    model.to(device)

    print(model)

    train_set = tuple(zip(model.train_set.Y, model.train_set.X))
    test_set = tuple(zip(model.test_set.Y, model.test_set.X))
    collate = functools.partial(collate_func, {'text': text_pipeline, 'label': label_pipeline})
    dataloader = {
        # Reshuffle the training data every epoch, so the model does not see
        # an identical batch order each time.
        'train_set': DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate),
        'test_set': DataLoader(test_set, batch_size=32, collate_fn=collate),
    }

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1.5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

    N_EPOCHS = 50
    total_accu = None  # best test accuracy seen so far

    for epoch in range(N_EPOCHS):
        epoch_start_time = time.time()
        train(model, dataloader['train_set'], optimizer, criterion, epoch)
        # NOTE(review): the test split doubles as the validation set that
        # drives LR decay, so the reported accuracy is not a held-out estimate.
        accu_val = evaluate(model, dataloader['test_set'], criterion, classes)
        if total_accu is not None and total_accu > accu_val:
            # Accuracy dropped below the best so far: decay the learning rate.
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 60)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
        print('-' * 60)


# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()



