fix device and use auto model (#10)
Browse files- fix device and use auto model (23523c609254249482a222c3c9b67afdb6d0dbf2)
- custom_st.py +2 -1
custom_st.py
CHANGED
@@ -55,6 +55,7 @@ class Transformer(nn.Module):
|
|
55 |
|
56 |
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
57 |
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
|
|
|
58 |
|
59 |
self._lora_adaptations = config.lora_adaptations
|
60 |
if (
|
@@ -116,7 +117,7 @@ class Transformer(nn.Module):
|
|
116 |
lora_arguments = (
|
117 |
{"adapter_mask": adapter_mask} if adapter_mask is not None else {}
|
118 |
)
|
119 |
-
output_states = self.forward(**features, **lora_arguments, return_dict=False)
|
120 |
output_tokens = output_states[0]
|
121 |
features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
|
122 |
return features
|
|
|
55 |
|
56 |
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
57 |
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
|
58 |
+
self.device = next(self.auto_model.parameters()).device
|
59 |
|
60 |
self._lora_adaptations = config.lora_adaptations
|
61 |
if (
|
|
|
117 |
lora_arguments = (
|
118 |
{"adapter_mask": adapter_mask} if adapter_mask is not None else {}
|
119 |
)
|
120 |
+
output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
|
121 |
output_tokens = output_states[0]
|
122 |
features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
|
123 |
return features
|