bwang0911 commited on
Commit
e59450e
1 Parent(s): 8bb104b

fix: make sure data and adapter on same device (#11)

Browse files

- fix: make sure data and adapter on same device (08577bc2e88cb6d2e7ffa9fb2c45ba7c16c02836)

Files changed (1) hide show
  1. custom_st.py +1 -2
custom_st.py CHANGED
@@ -55,7 +55,6 @@ class Transformer(nn.Module):
55
 
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
  self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
58
- self.device = next(self.auto_model.parameters()).device
59
 
60
  self._lora_adaptations = config.lora_adaptations
61
  if (
@@ -111,7 +110,7 @@ class Transformer(nn.Module):
111
  num_examples = len(features['input_ids'])
112
 
113
  adapter_mask = torch.full(
114
- (num_examples,), task_id, dtype=torch.int32, device=self.device
115
  )
116
 
117
  lora_arguments = (
 
55
 
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
  self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
 
58
 
59
  self._lora_adaptations = config.lora_adaptations
60
  if (
 
110
  num_examples = len(features['input_ids'])
111
 
112
  adapter_mask = torch.full(
113
+ (num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
114
  )
115
 
116
  lora_arguments = (