import os
from typing import Any, List, Dict
from numbers import Number
import io
import torch
import torch.distributed as dist

def is_distributed_set() -> bool:
    """Return True only when torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

def get_world_size():
    """Return the distributed world size, or 1 when not running distributed.

    Falls back to 1 if torch.distributed is unavailable or has not been
    initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed.

    Falls back to 0 if torch.distributed is unavailable or has not been
    initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

def is_main_process():
    """Return True when the rank reported by ``get_rank()`` is 0."""
    rank = get_rank()
    return rank == 0

def synchronize():
    """
    Barrier across all processes when using distributed training.

    Does nothing when torch.distributed is unavailable, not initialized,
    or the world size is 1.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    # A barrier is only issued when the world size is not 1.
    if distributed_ready and dist.get_world_size() != 1:
        dist.barrier()

def ensure_init_process_group():
    """Initialize the NCCL process group if the environment asks for it.

    Reads ``LOCAL_RANK`` (default -1) and ``WORLD_SIZE`` (default 1) from the
    environment. When ``WORLD_SIZE`` is greater than 1 and no process group is
    initialized yet, selects the CUDA device ``local_rank`` and calls
    ``dist.init_process_group`` with the ``nccl`` backend.

    Returns:
        The local rank read from the environment, or -1 when ``LOCAL_RANK``
        is not set.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Nothing to set up for a single process or an already-initialized group.
    needs_init = world_size > 1 and not dist.is_initialized()
    if not needs_init:
        return local_rank

    # NOTE(review): local_rank falls back to -1 (never None), so this
    # assertion cannot fail as written — confirm the intended check.
    assert local_rank is not None
    print("Init distributed training on local rank {}".format(local_rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    return local_rank

def _broadcast_object(obj: Any, src_rank, device) -> Any:
    """Broadcast one non-tensor object from ``src_rank`` to the other ranks.

    The source rank serializes ``obj`` with ``torch.save`` and broadcasts the
    byte length followed by the bytes themselves. Every other rank receives
    the length, allocates a byte tensor of that size, receives the payload and
    deserializes it with ``torch.load`` onto the CPU.

    Args:
        obj: Object to send; only read on the source rank.
        src_rank: Rank whose ``obj`` is broadcast.
        device: Device on which the communication tensors are placed.

    Returns:
        ``obj`` itself on the source rank, the deserialized copy elsewhere.
    """
    # see FairSeq/distributed/utils
    is_source = src_rank == get_rank()

    if is_source:
        stream = io.BytesIO()
        torch.save(obj, stream)
        payload = torch.ByteTensor(stream.getbuffer()).to(device)
        size = torch.LongTensor([len(payload)]).to(device)
    else:
        # Placeholder that the length broadcast below overwrites.
        size = torch.LongTensor([0]).to(device)

    # Every rank issues the same two collectives in the same order:
    # first the payload length, then the payload.
    dist.broadcast(size, src=src_rank)
    if not is_source:
        payload = torch.ByteTensor(int(size.item())).to(device)
    dist.broadcast(payload, src=src_rank)

    if is_source:
        return obj
    # NOTE(review): torch.load unpickles the received bytes; newer PyTorch
    # releases may default to weights_only=True and reject arbitrary
    # objects — confirm against the torch version in use.
    received = io.BytesIO(payload.cpu().numpy())
    return torch.load(received, map_location="cpu")


def broadcast_objects(obj_list: List[Any], src_rank: int = 0) -> List[Any]:
    """Broadcast every object in ``obj_list`` from ``src_rank``, one at a time.

    The communication device follows the active backend: CUDA for NCCL, CPU
    for GLOO.

    Args:
        obj_list: Objects to broadcast; the list should have the same length
            on every rank.
        src_rank: Rank whose objects are sent.

    Returns:
        A new list with one broadcast result per input element.

    Raises:
        RuntimeError: If the backend is neither NCCL nor GLOO.
    """
    # dist.broadcast_object_list(obj_list, src=src_rank) was noted as
    # "somehow not working", hence the per-object broadcast below.
    backend = dist.get_backend()
    if backend == dist.Backend.NCCL:
        device_type = "cuda"
    elif backend == dist.Backend.GLOO:
        device_type = "cpu"
    else:
        raise RuntimeError(f"Unsupported distributed backend: {backend}")
    device = torch.device(device_type)

    return [_broadcast_object(obj, src_rank, device=device) for obj in obj_list]

def all_reduce_scalar(value: Number, op: str = "sum") -> Number:
    """All-reduce a single scalar value across ranks. NOT for torch tensors.

    When distributed training is not set up, ``value`` is returned unchanged
    (and ``op`` is not validated).

    Args:
        value: Python number to reduce.
        op: One of "sum", "mean", "min", "max" or "product"
            (case-insensitive). "mean" is a sum divided by the world size.

    Returns:
        The reduced value as a Python number (``tensor.item()``).

    Raises:
        RuntimeError: If ``op`` is not recognized or the backend is neither
            NCCL nor GLOO.
    """
    if not is_distributed_set():
        return value

    op = op.lower()
    reduce_ops = {
        "sum": dist.ReduceOp.SUM,
        "mean": dist.ReduceOp.SUM,  # mean = sum, then divide by world size
        "min": dist.ReduceOp.MIN,
        "max": dist.ReduceOp.MAX,
        "product": dist.ReduceOp.PRODUCT,
    }
    if op not in reduce_ops:
        raise RuntimeError(f"Invalid all_reduce_scalar op: {op}")
    dist_op = reduce_ops[op]

    backend = dist.get_backend()
    if backend == torch.distributed.Backend.NCCL:
        device = torch.device("cuda")
    elif backend == torch.distributed.Backend.GLOO:
        device = torch.device("cpu")
    else:
        raise RuntimeError(f"Unsupported distributed backend: {backend}")

    tensor = torch.tensor(value, device=device, requires_grad=False)
    dist.all_reduce(tensor, op=dist_op)
    if op == "mean":
        # Out-of-place division: an in-place "/=" cannot promote the integer
        # tensor created from an int ``value`` to floating point and raises.
        tensor = tensor / get_world_size()
    return tensor.item()


def all_reduce_tensor(tensor: torch.Tensor, op="sum", detach: bool = True) -> torch.Tensor:
    """All-reduce a tensor across ranks without modifying the input.

    When distributed training is not set up, the input tensor itself is
    returned (no clone, no detach, and ``op`` is not validated). Otherwise the
    reduction runs on a clone of ``tensor``.

    Args:
        tensor: Tensor to reduce.
        op: "sum" or "mean" (case-insensitive, matching
            ``all_reduce_scalar``). Other ops are intentionally unsupported.
        detach: If True, the clone is detached before the reduction.

    Returns:
        The reduced tensor. For "mean" on a non-floating-point tensor the
        result is promoted to a floating-point dtype by the division.

    Raises:
        RuntimeError: If ``op`` is neither "sum" nor "mean".
    """
    if not is_distributed_set():
        return tensor

    # Validate before cloning so an invalid op does no wasted work.
    op_key = op.lower()
    if op_key not in ("sum", "mean"):  # intentionally only support sum or mean
        raise RuntimeError(f"Invalid all_reduce_tensor op: {op}")

    ret = tensor.clone()
    if detach:
        ret = ret.detach()

    dist.all_reduce(ret, op=dist.ReduceOp.SUM)
    if op_key == "mean":
        world_size = get_world_size()
        if ret.is_floating_point() or ret.is_complex():
            ret /= world_size
        else:
            # In-place true division raises on integer tensors because the
            # result cannot be cast back; divide out-of-place instead.
            ret = ret / world_size
    return ret


def all_reduce_dict(result: Dict[str, Any], op="sum") -> Dict[str, Any]:
    """All-reduce every value of a dictionary and return a new dictionary.

    Tensor values go through ``all_reduce_tensor`` and numeric values through
    ``all_reduce_scalar``, each with the same ``op``.

    Args:
        result: Mapping whose values are each a tensor or a number.
        op: Reduction name forwarded to the per-value reducer.

    Returns:
        A new dictionary with the same keys and reduced values.

    Raises:
        RuntimeError: If a value is neither a ``torch.Tensor`` nor a ``Number``.
    """
    reduced = {}
    for k, v in result.items():
        # Pick the reducer by value type; anything else is rejected.
        if isinstance(v, torch.Tensor):
            reducer = all_reduce_tensor
        elif isinstance(v, Number):
            reducer = all_reduce_scalar
        else:
            raise RuntimeError(f"Input dictionary for all_reduce_dict should only have "
                               f"either tensor or scalar as their values, got ({k}: {type(v)})")
        reduced[k] = reducer(v, op)
    return reduced