# -*- coding: UTF-8 -*-

import os
import re

import json
import spacy
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.metrics import classification_report
from dataclasses import dataclass
from typing import Iterable, Any


# Root directory of the category-annotated corpus. Earlier snapshots were
# '../../res/corpus/txt_preprocessed_v3' and '../../res/corpus/txt_preprocessed_v4'.
# NOTE(review): the path is relative to the current working directory, not to
# this file -- confirm the script is always launched from its own folder.
CORPUS_PATH = '../../res/corpus/category_annotated_corpus_v4.1_20210513'
# Directory holding the per-document JSON metadata files.
JSON_PATH = os.path.join(CORPUS_PATH, 'json2')

@dataclass
class Dataset:
    """Samples paired with their labels, so that ``X[i]`` is labelled ``Y[i]``.

    Both fields default to the empty tuple.
    """

    X: Iterable = ()  # samples (raw texts, or a feature matrix once vectorized)
    Y: Iterable = ()  # labels, aligned index-by-index with X


class Classifier:
    """Base text classifier over the category-annotated corpus.

    Workflow: ``load_dataset`` -> ``split_dataset`` -> ``vectorize`` -> ``fit``
    -> ``predict``. Samples are TF-IDF vectorized; the estimator itself is read
    from ``self.model``, which this class never assigns -- presumably a subclass
    provides it (TODO confirm).
    """

    # spaCy English pipeline, loaded once at class-definition (import) time and
    # shared by every instance.
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, test_size=.3, random_state=42):
        """Create a classifier with empty corpus and datasets.

        Args:
            test_size: fraction of the dataset held out by ``split_dataset``.
            random_state: seed forwarded to ``train_test_split``.
        """
        # Per-instance state. These used to be class attributes, which made
        # every Classifier (and subclass) share -- and mutate -- the same list
        # and Dataset objects.
        self.corpus = []
        self.dataset = Dataset()
        self.test_set = Dataset()
        self.train_set = Dataset()
        # NOTE: ``tokenizer=self.tokenize`` can be passed here to use the spaCy
        # lemmatizing tokenizer instead of the default one.
        self.vectorizer = TfidfVectorizer(stop_words='english')
        self.test_size = test_size
        self.random_state = random_state

    def set_test_XY(self, X, Y):
        """Store the held-out samples ``X`` and labels ``Y``."""
        self.test_set.X = X
        self.test_set.Y = Y

    def set_train_XY(self, X, Y):
        """Store the training samples ``X`` and labels ``Y``."""
        self.train_set.X = X
        self.train_set.Y = Y

    def split_dataset(self):
        """Split ``self.dataset`` into train/test sets, stratified by label."""
        X_train, X_test, y_train, y_test = train_test_split(
            self.dataset.X,
            self.dataset.Y,
            test_size=self.test_size,
            random_state=self.random_state,
            stratify=self.dataset.Y,
        )
        self.set_train_XY(X_train, y_train)
        self.set_test_XY(X_test, y_test)

    def vectorize(self):
        """Fit the TF-IDF vectorizer on the training texts.

        ``train_set.X`` is replaced by the resulting matrix. A second call would
        feed that matrix back into the vectorizer, so call this once per split.
        """
        self.train_set.X = self.vectorizer.fit_transform(self.train_set.X)

    def fit(self):
        """Train ``self.model`` on the vectorized training set."""
        self.model.fit(self.train_set.X, self.train_set.Y)

    def tokenize(self, sent):
        """Yield the lemma of every token in ``sent`` that is neither a stop word nor a digit."""
        for token in self.nlp(sent):
            if not token.is_stop and not token.is_digit:
                yield token.lemma_

    def load_dataset(self, max_num=10000):
        """Read the corpus from disk into ``self.corpus`` and ``self.dataset``.

        Every file in ``CORPUS_PATH/txt_preprocessed`` must be named
        ``<num>-<category>-<lang>-<sha256>.txt`` and have a same-named ``.json``
        file in ``JSON_PATH`` with a ``"title"`` key. Each sample is that title
        followed by the file body with all whitespace runs collapsed to single
        spaces; its label is ``<category>``.

        Args:
            max_num: files whose numeric prefix exceeds this value are skipped.

        On success the previously loaded corpus is replaced (repeated calls no
        longer accumulate duplicates). On a missing directory, unreadable file,
        malformed file name or bad JSON, a message is printed and the current
        state is left untouched.
        """
        corpus = []
        try:
            # os.walk yields nothing for a missing directory -> StopIteration.
            root, _dirs, fnames = next(os.walk(os.path.join(CORPUS_PATH, 'txt_preprocessed')))
            for fname in fnames:
                num, cat, _lang, _sha256 = fname.split('-')
                if int(num) > max_num:
                    continue
                txt_path = os.path.join(root, fname)
                meta_path = os.path.join(JSON_PATH, fname.replace('.txt', '.json'))
                with open(txt_path, encoding='utf-8') as txt_file:
                    # Collapse every whitespace run (newlines, tabs, ...) to one space.
                    body = ' '.join(txt_file.read().split())
                with open(meta_path, encoding='utf-8') as meta_file:
                    title = json.load(meta_file)["title"]
                corpus.append((cat, title + ' ' + body))
        except (StopIteration, OSError, ValueError, KeyError, TypeError) as exception:
            # Best-effort loader: report the cause instead of hiding it.
            print('Can\'t load the dataset!', exception)
            return

        self.corpus = corpus
        self.dataset.Y = tuple(cat for cat, _ in corpus)
        self.dataset.X = tuple(text for _, text in corpus)

    def predict(self):
        """Evaluate ``self.model`` on the held-out test set.

        Returns:
            The text report from ``sklearn.metrics.classification_report``
            comparing the test labels with the model's predictions.
        """
        X = self.vectorizer.transform(self.test_set.X)
        y_pred = self.model.predict(X)
        # NOTE(review): restricting ``labels`` to the predicted classes drops
        # any class the model never predicts from the report and from the
        # averages; np.unique(self.test_set.Y) would include them.
        return classification_report(self.test_set.Y, y_pred, labels=np.unique(y_pred))

