import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
import codecs
import copy
import json
import os
def cartesian_to_polar_matrix(m):
    """Convert row vectors from Cartesian to hyperspherical coordinates.

    Each row ``(m_0, ..., m_{n-1})`` of ``m`` is mapped to
    ``(r, t_1, ..., t_{n-1})``, where ``r`` is the Euclidean norm of the row
    and, for ``k = 1 .. n-1``::

        t_k = arccos(m_{k-1} / ||(m_{k-1}, ..., m_{n-1})||) / pi

    The last angle is extended to the full circle: when ``m_{n-1} < 0`` it is
    replaced by ``2 - t_{n-1}``. Hence ``t_1 .. t_{n-2}`` lie in ``[0, 1]`` and
    ``t_{n-1}`` covers ``[0, 2)`` (up to rounding). All angles are in units of pi.

    An angle whose trailing sub-vector ``(m_{k-1}, ..., m_{n-1})`` is all zeros
    is mathematically undefined. It is set to 0 rather than NaN, so degenerate
    rows such as the zero vector still give finite output.

    Args:
        m: array-like of shape (n_samples, n_dims).

    Returns:
        float ndarray of shape (n_samples, n_dims). Column 0 is the radius and
        columns 1.. are the angles divided by pi.
    """
    m = np.asarray(m, dtype=float)
    n_rows, n_cols = m.shape[0], m.shape[1]
    x = np.zeros((n_rows, n_cols))
    x[:, 0] = np.linalg.norm(m, axis=1)

    # Walk from the last coordinate backwards, accumulating the squared norm
    # of the trailing sub-vector m[:, q:], so each angle costs one pass.
    tail_sq = m[:, -1] ** 2
    for q in range(n_cols - 2, -1, -1):
        tail_sq = tail_sq + m[:, q] ** 2
        tail_norm = np.sqrt(tail_sq)
        defined = tail_norm > 0
        # Divide by 1 where the tail is zero, so no 0/0 warning is raised, and
        # use ratio 1 there, so the angle becomes 0.
        ratio = np.where(defined, m[:, q] / np.where(defined, tail_norm, 1.0), 1.0)
        # Clip guards arccos against |ratio| drifting past 1 through rounding.
        x[:, q + 1] = np.arccos(np.clip(ratio, -1.0, 1.0))

    if n_cols > 1:
        # Map the last angle onto [0, 2*pi) using the sign of the final coordinate.
        neg = m[:, -1] < 0
        x[neg, -1] = 2 * np.pi - x[neg, -1]

    x[:, 1:] /= np.pi
    return x

def makeAngledistribution(samplepoint,dim):
    """Draw random unit directions in R^dim and return their hyperspherical angles.

    Directions come from ``samplepoint`` standard-normal vectors drawn with
    NumPy's global random state. Each vector is scaled to unit length before
    conversion. Normalized isotropic Gaussian vectors are uniformly distributed
    on the unit sphere. The angles do not depend on the radius, so the scaling
    only fixes column 0 of the polar result, which is discarded anyway.

    Args:
        samplepoint: number of directions to sample.
        dim: dimension of the ambient space.

    Returns:
        ndarray of shape (samplepoint, dim - 1). These are the angle columns of
        ``cartesian_to_polar_matrix``, i.e. angles divided by pi.
    """
    gaussian = np.random.normal(0, 1, (samplepoint, dim))
    lengths = np.sqrt(np.sum(np.power(gaussian, 2), axis=1))
    unit_vectors = gaussian / lengths[:, np.newaxis]
    polar = cartesian_to_polar_matrix(unit_vectors)
    # Drop the radius column and keep only the angles.
    return polar[:, 1:]
    
def _fit_best_gmm(a, max_gauss):
    """Fit spherical GMMs with 1..``max_gauss`` components to ``a`` and keep the best.

    The selected model is the one with the largest ``lower_bound_`` reported by
    scikit-learn. On ties, the model with fewer components is kept.

    Args:
        a: samples as a column array of shape (n_samples, 1).
        max_gauss: largest number of mixture components to try (must be >= 1).

    Returns:
        Tuple ``(lower_bound, n_components, weights, means, sds)`` describing the
        selected fit. ``sds`` is the square root of the spherical covariances.
    """
    best = None
    for n_components in range(1, max_gauss + 1):
        gmm = GaussianMixture(n_components=n_components,
                              covariance_type="spherical",
                              max_iter=10000, tol=1e-4).fit(a)
        lb = gmm.lower_bound_
        # NOTE(review): lower_bound_ measures fit on the training data and is
        # likely to grow with n_components, so this criterion may favour the
        # largest mixture. Consider gmm.bic(a) if model size should be penalised.
        if best is None or lb > best[0]:
            best = (lb, n_components, gmm.weights_, gmm.means_,
                    np.sqrt(gmm.covariances_))
    return best


def _plot_mixture(x, weights, means, sds, path):
    """Plot each weighted component ``w * N(x; mu, sd)`` over grid ``x`` and save to ``path``."""
    fig, ax = plt.subplots()
    for w, mu, sd in zip(weights, np.ravel(means), sds):
        ax.plot(x, w * norm.pdf(x, mu, sd))
    fig.savefig(path)
    # Close explicitly so figures do not accumulate across the many plots.
    plt.close(fig)


if __name__ == "__main__":
    samplepoint = 100000
    max_gauss = 20
    figure_dir = './figure_GMM'
    dataset_dir = 'dataset'
    # Create both output directories up front; open() will not create them.
    os.makedirs(figure_dir, exist_ok=True)
    os.makedirs(dataset_dir, exist_ok=True)

    for dim in [2, 3, 5, 10]:
        angles = makeAngledistribution(samplepoint, dim)
        Answer_dict = {}

        # angles has dim-1 columns. The last one (index dim-2) spans [0, 2),
        # while the others span [0, 1], so it gets a wider plotting grid.
        for d in range(dim - 1):
            if d == dim - 2:
                x = np.arange(-0.2, 2.2, 0.00001)
            else:
                x = np.arange(0, 1, 0.00001)

            a = angles[:, d].reshape(-1, 1)
            lb, _, weights, means, sds = _fit_best_gmm(a, max_gauss)
            _plot_mixture(x, weights, means, sds,
                          f'{figure_dir}/GMM_{dim}_{d}.png')

            Answer_dict[d] = {
                'weight': weights.tolist(),
                'mean': np.ravel(means).tolist(),
                'sigma': sds.tolist(),
                'lower_bound': [float(lb)],
            }

        # Write only after every fit has succeeded, so a failure leaves no partial file.
        out_path = f'{dataset_dir}/angle_GMM_parameters_dim{dim}.json'
        with open(out_path, 'w', encoding='utf-8') as fo:
            json.dump(Answer_dict, fo, indent=4)


                

        
