flash_attn package makes it non-portable

#1
by bghira - opened

The flash_attn package only runs on NVIDIA systems, not Apple Silicon or AMD.

Try to avoid hard-coding flash_attention_2, for example:

import torch
from transformers import AutoModelForCausalLM

# EMU_HUB is the model repo id used elsewhere in the model card snippets.
model = AutoModelForCausalLM.from_pretrained(
    EMU_HUB,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",  # NVIDIA-only; leave disabled for portability
    trust_remote_code=True,
)
