#pragma once
#include <torch/all.h>

// Compile-time block/tile dimensions shared by the code that includes this
// header. How they map onto the launch configuration is not visible here.
// NOTE(review): BLOCKHEIGHT4B presumably is the number of 32-bit words needed
// to hold BLOCKWIDTH 4-bit values (256 * 4 / 32 = 32), and BLOCKHEIGHT the
// 3-bit equivalent (256 * 3 / 32 = 24) -- confirm against the kernel sources.
constexpr int BLOCKWIDTH    = 256;
constexpr int BLOCKHEIGHT   = 24;
constexpr int BLOCKHEIGHT4B = 32;

// Declaration of the MatQuant4DequantKernelFaster device kernel. Only the
// prototype lives in this header; the definition is elsewhere.
//
// Template parameters:
//   T1 - element type of the `out` and `scales` buffers.
//   T2 - does not appear in the parameter list, so it cannot be deduced and
//        must be supplied explicitly wherever the kernel is instantiated.
//
// Parameters:
//   mat    - read-only buffer of int (presumably the packed 4-bit quantized
//            matrix, judging by the kernel name -- TODO confirm).
//   out    - the only non-const pointer; presumably receives the dequantized
//            values -- TODO confirm.
//   scales - read-only buffer of T1 (presumably quantization scales).
//   zeros  - read-only buffer of uint8_t (presumably quantization zero
//            points).
//   height - integer dimension; which buffer it describes is not visible
//            here -- verify against the kernel definition.
//   width  - integer dimension; same caveat as `height`.
//
// All pointers are declared __restrict__, so callers must not pass buffers
// that alias one another.
// NOTE(review): expected grid/block layout is not stated in this header;
// check the launch site (BLOCKWIDTH / BLOCKHEIGHT4B are likely involved).
template <typename T1, typename T2>
__global__ void MatQuant4DequantKernelFaster(
    const      int* __restrict__ mat,
                T1* __restrict__ out,
    const       T1* __restrict__ scales,
    const  uint8_t* __restrict__ zeros,
    int height,
    int width
);

// Declaration of the MatQuant4DequantKernelFasterGroup device kernel. Only
// the prototype lives in this header; the definition is elsewhere.
// NOTE(review): the "Group" suffix suggests scales/zeros are stored per group
// of rows rather than per column -- the difference is not visible from this
// prototype, confirm against the kernel definition.
//
// Template parameters:
//   T1 - element type of the `out` and `scales` buffers.
//   T2 - does not appear in the parameter list, so it cannot be deduced and
//        must be supplied explicitly wherever the kernel is instantiated.
//
// Parameters:
//   mat    - read-only buffer of int (presumably the packed 4-bit quantized
//            matrix -- TODO confirm).
//   out    - the only non-const pointer; presumably receives the dequantized
//            values -- TODO confirm.
//   scales - read-only buffer of T1 (presumably quantization scales).
//   zeros  - read-only buffer of uint8_t (presumably quantization zero
//            points).
//   height - integer dimension; which buffer it describes is not visible
//            here -- verify against the kernel definition.
//   width  - integer dimension; same caveat as `height`.
//
// All pointers are declared __restrict__, so callers must not pass buffers
// that alias one another.
template <typename T1, typename T2>
__global__ void MatQuant4DequantKernelFasterGroup(
    const      int* __restrict__ mat,
                T1* __restrict__ out,
    const       T1* __restrict__ scales,
    const  uint8_t* __restrict__ zeros,
    int height,
    int width
);

// Host-side entry point taking four tensors by value and returning nothing.
// Only the prototype is declared here; the definition is elsewhere.
// NOTE(review): presumably launches MatQuant4DequantKernelFaster and writes
// its result into `out` -- confirm in the implementation file.
//
// Parameters (roles inferred from names, verify against the definition):
//   mat    - presumably the packed quantized matrix.
//   out    - presumably the destination tensor, filled in place.
//   scales - presumably the quantization scales.
//   zeros  - presumably the quantization zero points.
//
// Expected dtypes, shapes and devices are not specified in this header.
void matquant4dequant_faster(
  torch::Tensor mat,
  torch::Tensor out,
  torch::Tensor scales,
  torch::Tensor zeros
);

// Host-side entry point taking four tensors by value and returning nothing.
// Only the prototype is declared here; the definition is elsewhere.
// NOTE(review): presumably launches MatQuant4DequantKernelFasterGroup (the
// grouped variant) and writes its result into `out` -- confirm in the
// implementation file.
//
// Parameters (roles inferred from names, verify against the definition):
//   mat    - presumably the packed quantized matrix.
//   out    - presumably the destination tensor, filled in place.
//   scales - presumably the quantization scales (likely per group).
//   zeros  - presumably the quantization zero points (likely per group).
//
// Expected dtypes, shapes and devices are not specified in this header.
void matquant4dequant_faster_group(
  torch::Tensor mat,
  torch::Tensor out,
  torch::Tensor scales,
  torch::Tensor zeros
);
