Support accelerate for GLM
#2
by
larrylawl
- opened
Support accelerate for GLM. Example code to run 8-bit inference:
from transformers import AutoModelForSeq2SeqLM
# Explicit accelerate device map placing every GLM submodule on GPU 0.
# The 48 transformer layers are generated programmatically instead of
# being enumerated by hand; the resulting dict is identical.
device_map = {
    'glm.word_embeddings': 0,
    'glm.transformer.embedding_dropout': 0,
    'glm.transformer.position_embeddings': 0,
    'glm.transformer.block_position_embeddings': 0,
    # layers 0..47 inclusive, all pinned to device 0
    **{f'glm.transformer.layers.{layer_idx}': 0 for layer_idx in range(48)},
    'glm.transformer.final_layernorm': 0,
}
# ours
# Load the GLM checkpoint from the Hugging Face Hub using the explicit
# device_map above so accelerate dispatches all modules to GPU 0.
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path,
trust_remote_code=True,  # GLM uses custom modeling code hosted on the Hub
revision="6adb492",  # pin the remote-code revision for reproducibility
device_map=device_map,
load_in_8bit=True,  # 8-bit weights -- presumably requires bitsandbytes; verify install
)
model.eval()  # inference mode: disables dropout
larrylawl changed the pull request title from "[WIP] Support accelerate for GLM" to "Support accelerate for GLM".