FLAN-UL2 performance INT8 worse than BF16
#27
by nelsonspbr
I am running inference following https://huggingface.co/google/flan-ul2#running-the-model. I tested both the INT8 (`load_in_8bit=True`) and BF16 (`torch_dtype=torch.bfloat16`) loading methods. After running some experiments, INT8 is ~3x slower than BF16. For reference, these are the most frequently executed CUDA kernels in the INT8 run:
ampere_int32_i16832gemm_int8_256x128_ldg16_stages_64x3_nt
ampere_sgemm_128x32_tn
ampere_int32_i16832gemm_int8_128x128_ldg16_stages_64x3_nt
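For completeness, this is roughly how I load the model in each configuration, following the model card snippet (the prompt below is just a placeholder, and in practice I only load one variant at a time):

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")

# INT8 path: weights quantized to 8-bit via bitsandbytes (LLM.int8())
model_int8 = T5ForConditionalGeneration.from_pretrained(
    "google/flan-ul2", device_map="auto", load_in_8bit=True
)

# BF16 path: plain bfloat16 weights, no quantization
model_bf16 = T5ForConditionalGeneration.from_pretrained(
    "google/flan-ul2", device_map="auto", torch_dtype=torch.bfloat16
)

# Placeholder prompt; timings were measured over my own inputs
inputs = tokenizer(
    "Translate to German: How old are you?", return_tensors="pt"
).input_ids.to("cuda")
print(tokenizer.decode(model_bf16.generate(inputs, max_length=64)[0]))
```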
Is this "INT8" mode actually mixed precision under the hood? Could that explain why it is slower?
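In case it helps anyone reproduce the kernel breakdown, something like `torch.profiler` gives a similar per-kernel view (just a sketch, not necessarily the exact setup I used to collect the names above):

```python
from torch.profiler import ProfilerActivity, profile

# Profile one generate() call and list CUDA kernels by total GPU time.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model_int8.generate(inputs, max_length=64)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```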