import math
import torch
import torch.nn as nn
from typing import Union,Optional
from transformers import AutoModelForCausalLM

from types import SimpleNamespace
from collections import OrderedDict
import os
import sys
sys.path.append('../../')

from qeft.quant import lm_pack, make_quant, QuantLinear, QuantLinearReorder
from qeft.utils.misc import find_layers, interpret_dtype, parsing_layers, get_model_config

def _move_to_device(obj, dev):
    """Return ``obj`` placed on ``dev``.

    Objects exposing ``.device`` (e.g. tensors) are moved with ``.to(dev)`` when they
    live elsewhere. Plain tuples and lists are handled element-wise, so tensors
    grouped into one argument are moved too. Anything else is returned unchanged.
    """
    if type(obj) in (tuple, list):
        return type(obj)(_move_to_device(item, dev) for item in obj)
    if hasattr(obj, 'device') and obj.device != dev:
        return obj.to(dev)
    return obj


def model_multigpu(model, gpus, meta=None, model_name=None):
    """Spread the decoder layers of ``model`` over ``gpus`` for inference.

    The layers returned by ``parsing_layers`` are split into contiguous chunks of
    ``ceil(len(layers) / len(gpus))``. Each layer is wrapped so that its inputs are
    moved onto its own device before it runs. The pre-/post-layers, ``lm_head`` and
    the final decoder layer are placed on ``gpus[0]``.

    Args:
        model: model exposing ``lm_head`` and the layers described by ``meta``.
        gpus: list of devices; ``gpus[0]`` also hosts the pre/post layers.
        meta: model layout config; resolved from ``model_name`` when ``None``.
        model_name: used to look up ``meta`` via ``get_model_config``.

    Side effects:
        Mutates ``model`` in place and sets ``model.gpus = gpus``.
    """
    assert meta is not None or model_name is not None, "at least one of 'meta' or 'model_name' argument must not None"

    if meta is None:
        meta = get_model_config(model_name)

    layers, pre_layers, post_layers = parsing_layers(model=model, meta=meta)

    # nn.Module.to() moves parameters in place, so no rebinding is needed.
    for pre_layer in pre_layers:
        pre_layer.to(gpus[0])
    for post_layer in post_layers:
        post_layer.to(gpus[0])

    model.lm_head = model.lm_head.to(gpus[0])

    class MoveModule(nn.Module):
        """Wrap a layer so that all of its inputs are moved onto its device first."""

        def __init__(self, module):
            super().__init__()
            self.module = module
            self.dev = next(iter(self.module.parameters())).device

        def forward(self, *inp, **kwargs):
            # Move every positional and keyword argument, including tuples/lists of
            # tensors, not only the first positional one.
            inp = [_move_to_device(arg, self.dev) for arg in inp]
            kwargs = {key: _move_to_device(value, self.dev) for key, value in kwargs.items()}
            return self.module(*inp, **kwargs)

    if len(layers) > 0:
        pergpu = math.ceil(len(layers) / len(gpus))
        for i in range(len(layers) - 1):
            layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
        # The final layer goes to gpus[0], the same device as post_layers and lm_head.
        layers[-1] = MoveModule(layers[-1].to(gpus[0]))

    model.gpus = gpus

def get_hfmodel(model_name_or_path: str,
                dtype='auto',
                device_map='cpu',
                trust_remote_code=False,
                ):
    """Load a causal LM with ``AutoModelForCausalLM.from_pretrained``, skipping random init.

    While the model loads, ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
    ``normal_`` are replaced with no-ops, for faster model loading. The originals are
    restored afterwards, even if loading raises.

    Args:
        model_name_or_path: HF model id or local path.
        dtype: forwarded as ``torch_dtype``.
        device_map: forwarded as ``device_map``.
        trust_remote_code: forwarded as ``trust_remote_code``.

    Returns:
        The loaded model.
    """
    init_names = ('kaiming_uniform_', 'uniform_', 'normal_')
    originals = {name: getattr(torch.nn.init, name) for name in init_names}

    def skip(*args, **kwargs):
        # The real init functions return the tensor they were given; mirror that.
        if args:
            return args[0]
        return kwargs.get('tensor')

    for name in init_names:
        setattr(torch.nn.init, name, skip)
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
    finally:
        # Restore the global initializers even when loading fails, so the patch
        # never leaks into unrelated code.
        for name, fn in originals.items():
            setattr(torch.nn.init, name, fn)

    return model

def load_model(model_name_or_path,
                checkpoint_path,
                faster: Optional[bool] = True,
                dtype = None,
                device: Optional[Union[int, str, torch.device]] = 'cuda:0',
                cpu_load: Optional[bool] = True,
                ):
    """Load a QEFT checkpoint, dispatching on its format.

    Checkpoints containing a ``'base_path'`` key (as written by save_tunedmodel) are
    loaded with load_tunedmodel. Anything else goes to load_qmodel.
    ``checkpoint_path`` may also be a directory holding ``model.pth``, which is the
    layout save_tunedmodel writes.

    The remaining arguments are forwarded unchanged to the chosen loader.
    """
    resolved_path = checkpoint_path
    if os.path.isdir(checkpoint_path):
        resolved_path = os.path.join(checkpoint_path, 'model.pth')

    # The probe only inspects top-level keys, so map tensors to CPU instead of
    # materialising them on the device they were saved from. The probed dict is a
    # temporary and is released before the real loader reads the file.
    is_tuned = 'base_path' in torch.load(resolved_path, map_location='cpu')

    if is_tuned:
        return load_tunedmodel(model_name_or_path, resolved_path, faster, dtype, device, cpu_load)
    return load_qmodel(model_name_or_path, resolved_path, faster, dtype, device, cpu_load)
        
        

def hfmodel_to_qmodel(model, ckpt, faster: Optional[bool] = True,
                        device: Optional[Union[int, str, torch.device]] = 'cuda:0'):
    """Load a QEFT checkpoint's weights into an already-built model.

    If ``ckpt['packing']`` is true, the layers described by ``ckpt['quantinfos']``
    are first converted with ``make_quant``. After the state dict is loaded, every
    QuantLinear / QuantLinearReorder layer gets ``set_kernel(faster)``. In both
    cases ``ckpt['model_state_dict']`` is loaded with ``strict=False``.

    Args:
        model: model exposing ``.device`` (e.g. a transformers PreTrainedModel).
        ckpt: checkpoint dict as written by ``save_model``.
        faster: forwarded to ``set_kernel`` of the packed layers.
        device: target device. A CPU-resident model is moved there unless ``device``
            is ``None``, ``'auto'`` or a CPU device.

    Returns:
        The same model, possibly moved to ``device``.
    """
    print("Loading model ....")

    if ckpt['packing']:
        make_quant(model, ckpt['quantinfos'])
        model.load_state_dict(ckpt['model_state_dict'], strict=False)

        qlayers = find_layers(model, [QuantLinear, QuantLinearReorder])
        for qlayer in qlayers.values():
            qlayer.set_kernel(faster)
    else:
        model.load_state_dict(ckpt['model_state_dict'], strict=False)

    # model.device is a torch.device, and torch.device == 'cpu' (a str) is always
    # False. Compare the normalised device *type* instead.
    model_on_cpu = torch.device(model.device).type == 'cpu'
    # Check 'auto' before calling torch.device(): torch cannot parse 'auto'.
    keep_in_place = device is None or device == 'auto' or torch.device(device).type == 'cpu'
    if model_on_cpu and not keep_in_place:
        model = model.to(device)

    # Drops only this function's reference; the caller may still hold the checkpoint.
    del ckpt
    import gc; gc.collect()
    torch.cuda.empty_cache()

    print("Done.")
    return model

def load_qmodel(model_name_or_path,
                  checkpoint_path,
                  faster: Optional[bool] = True,
                  dtype = None,
                  device: Optional[Union[int, str, torch.device]] = 'cuda:0',
                  cpu_load: Optional[bool] = True,
                  ):
    """Build a model from ``model_name_or_path`` and load a QEFT quantized checkpoint into it.

    Args:
        model_name_or_path: HF model id or local path of the base architecture.
        checkpoint_path: checkpoint file written by ``save_model``.
        faster: forwarded to hfmodel_to_qmodel.
        dtype: model dtype. ``None`` uses ``ckpt['dtype']``; other values are parsed
            with ``interpret_dtype``.
        device: final device. ``'auto'`` leaves placement to ``device_map`` and skips
            the final ``.to()``. Other values except ``'cpu'`` are converted with
            ``torch.device``.
        cpu_load: build the model on CPU (``device_map='cpu'``) instead of passing
            ``device`` as the device map.

    Returns:
        The loaded model.
    """
    import warnings

    if not isinstance(device, torch.device) and device not in ['auto', 'cpu']:
        device = torch.device(device)
    device_map = 'cpu' if cpu_load else device
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may
    # reject the argparse.Namespace objects save_model stores under 'quantinfos'.
    # Confirm, and pass weights_only=False for trusted checkpoints if needed.
    ckpt = torch.load(checkpoint_path)

    if dtype is None:
        dtype = ckpt['dtype']
    else:
        dtype = interpret_dtype(dtype)

    try:
        import accelerate
    except ImportError:
        accelerate = None  # optional dependency: fall back to a plain load below

    model = None
    if accelerate is not None:
        try:
            with accelerate.init_empty_weights():
                model = get_hfmodel(model_name_or_path,
                                    dtype=dtype,
                                    device_map=device_map)
        except Exception as err:
            # Best-effort fast path: keep the retry, but do not hide why it failed.
            warnings.warn(f"loading under accelerate.init_empty_weights() failed ({err!r}); "
                          "retrying without it")
            model = None
    if model is None:
        model = get_hfmodel(model_name_or_path,
                            dtype=dtype,
                            device_map=device_map)

    model = hfmodel_to_qmodel(model, ckpt, faster, device)
    del ckpt  # release the checkpoint before the final device move

    # torch cannot parse 'auto' as a device, and an 'auto' placement was already
    # decided via device_map, so only move for concrete devices.
    if device != 'auto':
        model = model.to(device=device)

    return model

def replace_oweight(model, ckpt_tuned):
    """Overwrite the ``oweight`` tensors of ``model`` with fine-tuned values.

    ``ckpt_tuned['oweight_state_dict']`` maps module names to tensors. These are
    parameter names with the ``.oweight`` suffix stripped, as produced by
    save_tunedmodel. Every module whose name is a key and which has an ``oweight``
    attribute gets that tensor, cast to the module's current dtype and device.
    Keys without a matching module are ignored.

    Args:
        model: model to update in place.
        ckpt_tuned: dict containing ``'oweight_state_dict'``.

    Returns:
        ``model`` (the same object).
    """
    oweight_state_dict = ckpt_tuned['oweight_state_dict']

    for name, module in model.named_modules():
        # Matching by name and attribute covers every quantized layer type that
        # carries an ``oweight``, not only QuantLinear.
        if name not in oweight_state_dict or not hasattr(module, 'oweight'):
            continue
        target = module.oweight
        # The saved tensor keeps its save-time device, which need not be where this
        # module lives now, so match both device and dtype.
        module.oweight.data = oweight_state_dict[name].data.to(
            device=target.device, dtype=target.dtype)

    # Drops only the local references; the caller may still hold the checkpoint.
    del oweight_state_dict, ckpt_tuned
    import gc; gc.collect()
    torch.cuda.empty_cache()

    return model

def load_tunedmodel(model_name_or_path,
                  checkpoint_path,
                  faster: Optional[bool] = True,
                  dtype=None,
                  device: Optional[Union[int, str, torch.device]] = 'cuda:0',
                  cpu_load: Optional[bool] = True,
                  ):
    """Rebuild a fine-tuned QEFT model: the quantized base model plus tuned ``oweight`` tensors.

    ``checkpoint_path`` is either the checkpoint file itself or a directory that
    contains it as ``model.pth``. The checkpoint's ``'base_path'`` entry names the
    quantized base checkpoint, which is loaded first with load_qmodel. The
    remaining arguments are forwarded to load_qmodel.
    """
    tuned_file = (os.path.join(checkpoint_path, 'model.pth')
                  if os.path.isdir(checkpoint_path)
                  else checkpoint_path)
    tuned_ckpt = torch.load(tuned_file)

    base_model = load_qmodel(model_name_or_path, tuned_ckpt['base_path'],
                             faster, dtype, device, cpu_load)
    # replace_oweight updates base_model in place; its return value is not needed.
    replace_oweight(base_model, tuned_ckpt)
    return base_model

def save_model(model, 
               quantizers,
               save_path,
               packing:bool,
               fake:bool):
    """Save a quantized model as a fake-quantized checkpoint and/or a packed checkpoint.

    Args:
        model: model exposing ``.dtype`` and ``state_dict()``. When ``packing`` is
            true it is packed in place by ``lm_pack``.
        quantizers: mapping of layer name -> quantizer. Each quantizer has ``bits``
            and ``group_size``; ``out_ids`` is also needed when ``fake`` is true.
            The first quantizer's ``bits``/``group_size`` are recorded in the files.
        save_path: path of the packed checkpoint. The fake checkpoint is written
            next to it, with ``_fake`` inserted before the extension
            (``model.pt`` -> ``model_fake.pt``).
        packing: write the packed checkpoint to ``save_path``.
        fake: write the fake-quantized checkpoint.

    Raises:
        ValueError: if ``quantizers`` is empty, or packing is requested for a
            bit-width other than 3 or 4.
    """
    if not quantizers:
        raise ValueError("quantizers must not be empty")

    dtype = model.dtype
    first_quantizer = next(iter(quantizers.values()))
    wbits = first_quantizer.bits
    group_size = first_quantizer.group_size

    if fake:
        # Insert '_fake' before the extension only. str.replace('.pt', ...) would also
        # rewrite '.pt' inside directory names, and it leaves extension-less paths
        # unchanged, so the packed save below would overwrite the fake checkpoint.
        root, ext = os.path.splitext(save_path)
        ckpt_path = f"{root}_fake{ext}"
        model_state_dict = model.state_dict()
        out_ids_dict = {name: quantizers[name].out_ids for name in quantizers}

        torch.save({
            'model_state_dict': model_state_dict,
            'out_ids_dict': out_ids_dict,
            'packing': False,
            'dtype' : dtype,
            'bits' : wbits,
            'group_size' : group_size,
            }, ckpt_path)

        print(f"fake quantized model is saved to {ckpt_path}")
    if packing:
        if wbits not in (3, 4):
            raise ValueError(f"{wbits}bits is not supported.")
        lm_pack(model, quantizers)
        model_state_dict = model.state_dict()
        from argparse import Namespace
        quantinfos = {n: Namespace(**{'bits':quantizers[n].bits, 
                          'sym':getattr(quantizers[n],'sym',False),
                          'group_size':getattr(quantizers[n],'group_size',-1), 
                          'n_out':getattr(quantizers[n],'n_out',0), 
                          'reorder':getattr(quantizers[n], 'reorder', False),
                          }) for n in quantizers}

        torch.save({
            'model_state_dict': model_state_dict,
            'quantinfos': quantinfos,
            'packing': True,
            'dtype' : dtype,
            'bits' : wbits, 
            'group_size' : group_size,
            }, save_path)
        print(f"{wbits}bit quantized packing model is saved to {save_path}")
    
def save_tunedmodel(model,
                  base_path,
                  output_dir):
    """Save only the fine-tuned ``oweight`` tensors of ``model`` to ``output_dir/model.pth``.

    Every parameter whose name contains ``'oweight'`` is stored under its name with
    ``'.oweight'`` removed, cast to ``model.dtype``. The absolute ``base_path`` is
    stored alongside under ``'base_path'``. ``output_dir`` is created if it does not
    exist yet.
    """
    model_state_dict = OrderedDict()
    for name, param in model.named_parameters():
        if 'oweight' in name:
            model_state_dict[name.replace('.oweight', '')] = param.data.to(model.dtype)

    # torch.save does not create parent directories; an empty output_dir means cwd.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, 'model.pth')
    torch.save({'oweight_state_dict': model_state_dict,
                'base_path' : os.path.abspath(base_path)}, save_path)

    print(f"fine-tuned model is saved to {save_path}.")
