#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequant.h"
#include "quantization_new/gemm/gemm_cuda.h"
#include "quantization_new/gemv/gemv_cuda.h"

// Python bindings for this extension. The module name comes from the
// TORCH_EXTENSION_NAME macro, which the build must define. Each m.def call
// exposes one C++ entry point, declared in the headers included above, under
// the given Python name. Argument lists and tensor requirements are defined by
// the implementations, not here.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // 4-bit -> fp16/bf16 dequantization entry points.
  m.def("matquant4dequant_faster", &matquant4dequant_faster,
        "Dequantize a 4-bit weight matrix to an fp16 or bf16 weight matrix (faster version)");
  // Given its own docstring so help() distinguishes it from the function above.
  // NOTE(review): presumably this variant takes group-wise quantization
  // parameters — confirm against its implementation before documenting further.
  m.def("matquant4dequant_faster_group", &matquant4dequant_faster_group,
        "Dequantize a 4-bit weight matrix to an fp16 or bf16 weight matrix (faster version, group variant)");

  // New AWQ-style quantized kernels (GEMM and GEMV entry points from the
  // quantization_new/gemm and quantization_new/gemv headers).
  m.def("gemm_forward_cuda_new", &gemm_forward_cuda_new, "New quantized GEMM kernel.");
  m.def("gemv_forward_cuda_new", &gemv_forward_cuda_new, "New quantized GEMV kernel.");
}
