#!/usr/bin/env python3

import time
import gzip 
import codecs
import pickle
import os, os.path
import sys

from random import choice, shuffle
from pathlib import Path
from collections import Counter

import numpy as np
import numpy.linalg as nl
import numpy.random as nr
import scipy.sparse as sp

from sklearn.svm import LinearSVC
from tqdm import tqdm
from cvxopt import spmatrix
from torch.utils.data import Dataset, DataLoader

class DataTorchSVMClass(object):
    """Four-step pipeline: stage a knowledge graph, build SVM problems, fit, score.

    All artifacts live below ``outputdir``:

    * ``<outputdir>/<script>.<datadir>/``              staged data (independent of N, C)
    * ``<outputdir>/<script>.<datadir>/N_<N>/``        built positive/negative matrices
    * ``<outputdir>/<script>.<datadir>/N_<N>/C_<C>/``  logs and fitted coefficients

    Parameters
    ----------
    N : int
        Number of corrupted triples generated per side (head / tail) for
        every known triple.
    datadir : str
        Directory holding ``train.txt``, ``valid.txt``, ``test.txt``,
        ``relations.dict`` and ``entities.dict``.
    C : float
        Regularization parameter handed to ``LinearSVC``.
    outputdir : str or Path
        Root directory for everything this object writes.
    """

    def __init__(self, N, datadir, C, outputdir='./output'):
        self.N = N
        self.C = C
        self.datadir = datadir
        outputdir = Path(outputdir)
        # Keep only the script's base name.  sys.argv[0] is frequently an
        # absolute path, and joining an absolute path onto ``outputdir`` would
        # silently discard ``outputdir`` altogether.
        argv0 = os.path.basename(os.path.splitext(sys.argv[0])[0])
        metadir = outputdir / f"{argv0}.{datadir}"
        self.metadir = metadir
        fname = self.metadir / f"N_{N}" / f"C_{C}"
        self.fname = fname
        print(f"init: Data will be stored in directory {fname.absolute()}")
        try:
            os.makedirs(fname)
        except FileExistsError:
            print(f"init: Directory {fname.absolute()} exists (this is fine)...")
        print("init: Done init")

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _dump(obj, path):
        """Pickle ``obj`` into the gzip file ``path``."""
        with gzip.open(path, 'wb', compresslevel=1) as fh:
            pickle.dump(obj, fh)

    @staticmethod
    def _load(path):
        """Unpickle the gzip file ``path``.

        NOTE: pickle must only be used on trusted files; these are expected to
        be the ones written by ``stage``.
        """
        with gzip.open(path, 'rb') as fh:
            return pickle.load(fh)

    @staticmethod
    def _read_triples(path):
        """Read a tab-separated triple file into a list of tuples, skipping blank lines."""
        with open(path, encoding='utf-8') as fh:
            raw = fh.read()
        return [tuple(row.strip().split('\t')) for row in raw.split('\n') if row.strip()]

    @staticmethod
    def _read_index(path):
        """Read an ``index<TAB>name`` file into a list of ``(int index, name)`` pairs."""
        with open(path, encoding='utf-8') as fh:
            raw = fh.read()
        pairs = []
        for row in raw.split('\n'):
            if not row.strip():
                continue
            fields = row.strip().split('\t')
            pairs.append((int(fields[0]), fields[1]))
        return pairs

    @staticmethod
    def _num_workers():
        """Number of DataLoader workers: all CPUs but one, never negative."""
        # os.cpu_count() may return None when the count cannot be determined.
        return max((os.cpu_count() or 1) - 1, 0)

    @staticmethod
    def _collate_rows(batch):
        """Stack the sparse rows of every sample of a batch into one sparse matrix.

        A named function (rather than a lambda) so that it can be pickled when
        DataLoader workers are started with spawn/forkserver.
        """
        rows = []
        for sample in batch:
            rows.extend(sample[0])
        return sp.vstack(rows)

    @staticmethod
    def _collate_ranks(batch):
        """Flatten the per-sample rank lists of a batch into a single list."""
        ranks = []
        for sample in batch:
            ranks.extend(sample)
        return ranks

    @staticmethod
    def _drain(loader):
        """Iterate ``loader`` under a progress bar and return the list of batches."""
        batches = []
        with tqdm(total=len(loader)) as pbar:
            for batch in loader:
                batches.append(batch)
                pbar.update()
        return batches

    # ------------------------------------------------------------------
    # pipeline steps
    # ------------------------------------------------------------------
    def stage(self):
        """Build and save entities and relations; independent of N and C.

        Reads the train/valid triples and the two ``.dict`` files from
        ``datadir`` and writes ``kgtrain``, ``kgval``, ``E``, ``R``, ``E2I``,
        ``R2I``, ``Ename`` and ``Rname`` as ``*.pkl.gz`` files into ``metadir``.
        """
        datadir = self.datadir
        metadir = self.metadir

        traindata = os.path.sep.join([datadir, 'train.txt'])
        valdata = os.path.sep.join([datadir, 'valid.txt'])
        relationspath = os.path.sep.join([datadir, 'relations.dict'])
        entitiespath = os.path.sep.join([datadir, 'entities.dict'])

        # convert raw data to native Python objects
        print(f"stage: Loading kgtrain from {traindata}")
        kgtrain = self._read_triples(traindata)
        self._dump(kgtrain, metadir / 'kgtrain.pkl.gz')

        print(f"stage: Loading kgval from {valdata}")
        kgval = self._read_triples(valdata)
        self._dump(kgval, metadir / 'kgval.pkl.gz')

        # R2I := Relation name -> index value
        rel_rows = self._read_index(relationspath)
        Rname = list({name for _, name in rel_rows})
        R2I = {name: index for index, name in rel_rows}

        # E2I := Entity name -> index value
        ent_rows = self._read_index(entitiespath)
        Ename = list({name for _, name in ent_rows})
        E2I = {name: index for index, name in ent_rows}
        # every entity name must appear exactly once in entities.dict
        assert len(Ename) == len(ent_rows)

        k = len(Ename) + len(Rname)
        print(f'stage: Encoding in {k} dimensions')

        print(f"stage: Building entity and relation")

        # One-hot (k x 1) embeddings: relations occupy coordinates
        # [0, len(R)), entities occupy [len(R), k).
        E, R = [None]*len(Ename), [None]*len(Rname)

        for name in Rname:
            val = spmatrix([], [], [], size=(k, 1))
            val[R2I[name]] = 1.
            R[R2I[name]] = val

        for name in Ename:
            val = spmatrix([], [], [], size=(k, 1))
            val[len(R)+E2I[name]] = 1.0
            E[E2I[name]] = val

        staged = (('E', E), ('R', R), ('E2I', E2I), ('R2I', R2I),
                  ('Ename', Ename), ('Rname', Rname))
        for label, obj in staged:
            print(f"stage: Saving {label} to file")
            self._dump(obj, metadir / f'{label}.pkl.gz')
        print("stage: Done")

        print(f"stage: Building idx -> (i,j) map, dimension {k}")
        # (i, j) with i <= j  ->  offset in the row-major packed upper triangle
        # (diagonal included) of a k x k symmetric matrix.
        self.ij2idx = lambda i,j: k*(k-1)//2 - (k-i)*(k-i-1)//2 + j
        print("stage: Done with stage")

    def build(self):
        """Build and save the SVM problem; depends on ``stage`` and N, not on C.

        Writes ``negX.npz`` (corrupted triples) and ``posZ.npz`` (known
        triples, repeated so rows line up with ``negX``) into ``N_<N>/`` and
        the elapsed build time into ``build.log``.
        """
        metadir = self.metadir
        print("build: Building constraints")
        kgtrain = self._load(metadir / 'kgtrain.pkl.gz')
        kgval = self._load(metadir / 'kgval.pkl.gz')
        Ename = self._load(metadir / 'Ename.pkl.gz')
        Rname = self._load(metadir / 'Rname.pkl.gz')
        E = self._load(metadir / 'E.pkl.gz')
        R = self._load(metadir / 'R.pkl.gz')
        E2I = self._load(metadir / 'E2I.pkl.gz')
        R2I = self._load(metadir / 'R2I.pkl.gz')
        N = self.N

        # positive and negative samples over the same, identically ordered triples
        problemdatap = SVMDataset(kgtrain+kgval, Ename, Rname, E2I, R2I, E, R, N, 'pos')
        problemdatan = SVMDataset(kgtrain+kgval, Ename, Rname, E2I, R2I, E, R, N, 'neg')

        workers = self._num_workers()
        loaderp = DataLoader(problemdatap, batch_size=10, collate_fn=self._collate_rows, num_workers=workers)
        loadern = DataLoader(problemdatan, batch_size=10, collate_fn=self._collate_rows, num_workers=workers)

        print(f"build: building positive, negative triples with N={self.N}")
        t0 = time.time()
        Z = self._drain(loaderp)
        X = self._drain(loadern)

        X = sp.vstack(X)
        Z = sp.vstack(Z)
        Xf = self.metadir / f"N_{self.N}" / 'negX.npz'
        Zf = self.metadir / f"N_{self.N}" / 'posZ.npz'
        t1 = time.time()
        with open(self.fname / 'build.log', 'w') as fh:
            print(f"build\t{t1-t0}", file=fh)
        sp.save_npz(Xf, X)
        sp.save_npz(Zf, Z)
        print(f"build: done saving")
        print(f"build: Done with build.")

    def fit(self):
        """Fit a linear SVM on (negative - positive) rows and save its coefficients.

        Each difference row is multiplied by a random sign which is also used
        as its label.  Writes ``fit.log`` (elapsed time) and ``coef.npy`` into
        ``N_<N>/C_<C>/``.
        """
        # max_iter must be an int: scikit-learn's parameter validation rejects 1e5.
        svm = LinearSVC(random_state=0, penalty='l2', verbose=True, C=self.C,
                        max_iter=100000, fit_intercept=True, dual=True)

        Xf = self.metadir / f"N_{self.N}" / 'negX.npz'
        Zf = self.metadir / f"N_{self.N}" / 'posZ.npz'
        data = sp.load_npz(Xf) - sp.load_npz(Zf)

        print(f"fit: data shape ={data.shape}")
        print(f"fit: Fitting SVM")

        n = data.shape[0]
        # NOTE: labels are drawn from the unseeded global numpy RNG, so runs
        # are not reproducible even though the SVM itself uses random_state=0.
        label = np.sign(nr.randn(n))
        data = sp.spdiags(label, 0, n, n) @ data
        # uniform sample weights, each equal to 1/N
        weight = np.ones(n)
        weight /= weight.sum()
        weight *= n/self.N
        with open(self.fname / 'fit.log', 'w') as fh:
            t0 = time.time()
            svm.fit(data, label, sample_weight=weight)
            t1 = time.time()
            print(f'intercept: {svm.intercept_}')
            print(f"fit: Score:        ", svm.score(data, label))
            print(f"fit\t{t1-t0}", file=fh)
            print(f"fit: Saving p,n svm to file")
            print(f"fit: Saving coefs to file")
            np.save(self.fname / 'coef.npy', svm.coef_.flatten())
        print("fit: Done")

    def score(self, dataname):
        """Rank the triples of ``datadir/dataname`` and report MR, MRR and Hits@k.

        Parameters
        ----------
        dataname : str
            File name inside ``datadir`` (e.g. ``test.txt`` or ``valid.txt``).

        Raises
        ------
        ValueError
            If ``dataname`` contains no triples.

        For every triple, ``ScoreDataset`` yields two counts (tail side, head
        side) of unfiltered candidates scoring at or below the true triple.
        The metrics are written to ``score.log`` and printed.
        """
        metadir = self.metadir
        datadir = Path(self.datadir)

        E = self._load(metadir / 'E.pkl.gz')
        R = self._load(metadir / 'R.pkl.gz')
        E2I = self._load(metadir / 'E2I.pkl.gz')
        R2I = self._load(metadir / 'R2I.pkl.gz')

        # NOTE(review): allow_pickle=True is kept from the original; coef.npy as
        # written by fit() is a plain numeric array, so only load trusted files.
        coef = np.load(self.fname / 'coef.npy', allow_pickle=True)
        k = E[0].size[0]
        nentities = len(E)
        nrelations = len(R)

        triples = self._read_triples(datadir / dataname)
        shuffle(triples)

        known = (self._read_triples(datadir / 'train.txt')
                 + self._read_triples(datadir / 'valid.txt')
                 + self._read_triples(datadir / 'test.txt'))

        # known triples expressed through their one-hot coordinates; used to
        # filter candidates that are themselves true triples
        tripleset = set([(E[E2I[h]].I[0], R[R2I[r]].I[0], E[E2I[t]].I[0]) for h, r, t in known])

        scoredataset = ScoreDataset(coef, triples, k, nrelations, nentities, E, R, E2I, R2I, tripleset)
        scoreloader = DataLoader(scoredataset, batch_size=10, collate_fn=self._collate_ranks,
                                 num_workers=self._num_workers())
        X = []
        for batch in self._drain(scoreloader):
            X += batch
        if not X:
            raise ValueError(f"score: no triples found in {datadir / dataname}")

        total = len(X)
        metrics = (
            ('MR', sum(X)/total + 1),
            ('MRR', sum(1./(_+1) for _ in X)/total),
            ('H10', sum(_ < 10 for _ in X) / total),
            ('H_3', sum(_ < 3 for _ in X) / total),
            ('H_1', sum(_ < 1 for _ in X) / total),
        )
        with open(self.fname / 'score.log', 'w') as fh:
            for label, value in metrics:
                print(f'{label}\t{value}', file=fh)
        for label, value in metrics:
            print(f'{label}\t{value}')

class SVMDataset(Dataset):
    """Dataset yielding sparse quadratic-feature rows for known or corrupted triples.

    Every triple ``(h, r, t)`` is encoded as one ``1 x k*(k+1)//2`` sparse row
    over the packed upper triangle of a ``k x k`` symmetric matrix, with
    entries ``+1`` at (h,h), (r,r), (t,t), ``+2`` at (h,r) and ``-2`` at
    (r,t) and (h,t).  Relations use coordinates ``[0, len(R))`` and entities
    ``[len(R), k)``.

    Parameters
    ----------
    kgdata : list of (head, relation, tail) name tuples
    Ename, Rname : lists of entity / relation names
    E2I, R2I : dicts mapping entity / relation name to its index
    E, R : entity / relation embeddings (only ``len(R)`` is used here)
    N : number of corrupted triples per side for each known triple
    posneg : ``'pos'`` or ``'neg'``

    Raises
    ------
    ValueError
        If ``posneg`` is neither ``'pos'`` nor ``'neg'``.

    The index helpers are regular methods (not closures stored on the
    instance) so the dataset can be pickled for DataLoader worker processes
    started with spawn/forkserver.
    """

    # coefficients matching the index order returned by ``hrtidx``
    _VALS = (1, 1, 1, 2, -2, -2)

    def __init__(self, kgdata, Ename, Rname, E2I, R2I, E, R, N, posneg):
        if posneg == 'pos' or posneg == 'neg':
            self.posneg = posneg
        else:
            raise ValueError('Unknown posneg type')
        # knowledge graph for training, plus a set for O(1) membership tests
        self.kgdata = kgdata
        self.kgdataset = set(kgdata)
        # entity and relation names
        self.Ename = Ename
        self.Rname = Rname
        # name -> index maps and embeddings
        self.E2I = E2I
        self.R2I = R2I
        self.E = E
        self.R = R
        # number of bad triples per good triple (per side)
        self.N = N
        # dimension of embedding
        self.k = len(Ename) + len(Rname)

    def ij2idx(self, i, j):
        """Return the packed upper-triangle offset of the symmetric entry (i, j)."""
        if i > j:
            i, j = j, i
        k = self.k
        return k*(k-1)//2 - (k-i)*(k-i-1)//2 + j

    def hrtidx(self, h, r, t):
        """Return the offsets of (h,h), (r,r), (t,t), (h,r), (r,t), (h,t)."""
        return (
            self.ij2idx(h, h),
            self.ij2idx(r, r),
            self.ij2idx(t, t),
            self.ij2idx(h, r),
            self.ij2idx(t, r),
            self.ij2idx(t, h),
        )

    def _row(self, h, r, t):
        """Build the 1 x k*(k+1)//2 sparse feature row of the triple (h, r, t)."""
        k = self.k
        jdx = list(self.hrtidx(h, r, t))
        # duplicate column indices (e.g. when h == t) are summed by csr_matrix
        return sp.csr_matrix((list(self._VALS), ([0]*6, jdx)), shape=(1, k*(k+1)//2))

    def __len__(self):
        return len(self.kgdata)

    def __getitem__(self, item):
        """Return a 1-tuple holding a list of ``2*N`` sparse rows.

        ``'pos'``: ``2*N`` copies of the row of the known triple.
        ``'neg'``: ``N`` head-corrupted rows followed by ``N`` tail-corrupted
        rows, each drawn at random among entities that do not form a known
        triple.  NOTE: sampling loops until ``N`` such entities were drawn, so
        it does not terminate if none exists.
        """
        head, rel, tail = self.kgdata[item]
        known = self.kgdataset
        Ename = self.Ename
        E2I = self.E2I
        N = self.N
        offset = len(self.R)

        hidx, ridx, tidx = offset + E2I[head], self.R2I[rel], offset + E2I[tail]

        if self.posneg == 'pos':
            Z = self._row(hidx, ridx, tidx)
            return ([Z.copy() for _ in range(2*N)],)

        elif self.posneg == 'neg':
            X = []

            # corrupt the head
            nsuccess = 0
            while nsuccess < N:
                candidate = choice(Ename)
                if (candidate, rel, tail) not in known:
                    nsuccess += 1
                    X.append(self._row(E2I[candidate] + offset, ridx, tidx))

            # corrupt the tail
            nsuccess = 0
            while nsuccess < N:
                candidate = choice(Ename)
                if (head, rel, candidate) not in known:
                    nsuccess += 1
                    X.append(self._row(hidx, ridx, E2I[candidate] + offset))

            return (X,)
        else:
            raise ValueError('Uknown posneg type')

class ScoreDataset(Dataset):
    """Dataset yielding, per triple, how many candidate corruptions score at or below it.

    The score of ``(h, r, t)`` is
    ``coef[hh] + coef[rr] + coef[tt] + 2*(coef[hr] - coef[rt] - coef[ht])``
    where the subscripts are offsets into the packed upper triangle of a
    ``k x k`` symmetric matrix.  Relations use coordinates ``[0, nrelations)``
    and entities ``[nrelations, k)``.

    Parameters
    ----------
    coef : indexable of per-offset coefficients (length ``k*(k+1)//2``)
    scoredata : sequence of (head, relation, tail) name triples
    k : embedding dimension
    nrelations, nentities : number of relations / entities
    E, R : entity / relation embeddings (stored, not used for scoring)
    E2I, R2I : dicts mapping entity / relation name to its index
    tripleset : set of known ``(hidx, ridx, tidx)`` coordinate triples that
        are excluded from the candidates

    The scoring helpers are regular methods (not closures stored on the
    instance) so the dataset can be pickled for DataLoader worker processes
    started with spawn/forkserver.
    """

    def __init__(self, coef, scoredata, k, nrelations, nentities, E, R, E2I, R2I, tripleset):
        self.data = scoredata
        self.nentities = nentities
        self.nrelations = nrelations
        self.k = k
        self.E = E
        self.R = R
        self.E2I = E2I
        self.R2I = R2I
        self.tripleset = tripleset
        self.coef = coef

    def _ij2idx(self, i, j):
        """Return the packed upper-triangle offset of the symmetric entry (i, j)."""
        if i > j:
            i, j = j, i
        k = self.k
        return k*(k-1)//2 - (k-i)*(k-i-1)//2 + j

    def score(self, h, r, t, coef):
        """Return the score of the coordinate triple (h, r, t) under ``coef``."""
        ij2idx = self._ij2idx
        diag = coef[ij2idx(h, h)] + coef[ij2idx(r, r)] + coef[ij2idx(t, t)]
        cross = coef[ij2idx(h, r)] - coef[ij2idx(t, r)] - coef[ij2idx(t, h)]
        return diag + 2*cross

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return ``[tail_count, head_count]`` for the ``item``-th triple.

        Each count is the number of entities which, substituted for the tail
        (resp. head), give a triple that is not in ``tripleset`` and whose
        score is less than or equal to the score of the original triple.
        """
        head, rel, tail = self.data[item]
        coef = self.coef
        score = self.score
        known = self.tripleset
        nrelations = self.nrelations
        k = self.k

        hidx = nrelations + self.E2I[head]
        ridx = self.R2I[rel]
        tidx = nrelations + self.E2I[tail]
        score0 = score(hidx, ridx, tidx, coef)

        # candidates replacing the tail
        tail_count = 0
        for cand in range(nrelations, k):
            if (hidx, ridx, cand) in known:
                continue
            if score(hidx, ridx, cand, coef) <= score0:
                tail_count += 1

        # candidates replacing the head
        head_count = 0
        for cand in range(nrelations, k):
            if (cand, ridx, tidx) in known:
                continue
            if score(cand, ridx, tidx, coef) <= score0:
                head_count += 1

        return [tail_count, head_count]

if __name__=='__main__':
    import argparse
    import torch

    # Command-line entry point.
    #
    # Usage:
    #   mksvm data/FB15k-film-sample 15 1.0 --stage --build --fit --score test.txt
    #
    # (Thread pinning via torch.set_num_threads(os.cpu_count()-1) stays disabled.)

    cli = argparse.ArgumentParser()
    cli.add_argument('DATA')
    cli.add_argument('N', type=int)
    cli.add_argument('C', type=float, help="SVM regularization parameter C")
    cli.add_argument('--stage', action='store_true')
    cli.add_argument('--build', action='store_true')
    cli.add_argument('--fit', action='store_true')
    cli.add_argument('--outputdir', default='./output')
    cli.add_argument('--score', help="test.txt or valid.txt")
    opts = cli.parse_args()

    pipeline = DataTorchSVMClass(opts.N, opts.DATA, opts.C, outputdir=opts.outputdir)

    # Run the requested steps in their fixed order: stage, build, fit, score.
    steps = (
        (opts.stage, pipeline.stage),
        (opts.build, pipeline.build),
        (opts.fit, pipeline.fit),
    )
    for requested, run in steps:
        if requested:
            run()
    if opts.score:
        pipeline.score(opts.score)
