teticio commited on
Commit
903650a
1 Parent(s): 96e8f55

improve piepline loading

Browse files
Files changed (1) hide show
  1. audiodiffusion/__init__.py +7 -5
audiodiffusion/__init__.py CHANGED
@@ -43,11 +43,13 @@ class AudioDiffusion:
43
  hop_length=hop_length,
44
  top_db=top_db)
45
  self.model_id = model_id
46
- try: # a bit hacky
47
- self.pipe = LatentAudioDiffusionPipeline.from_pretrained(self.model_id)
48
- except:
49
- self.pipe = AudioDiffusionPipeline.from_pretrained(self.model_id)
50
-
 
 
51
  if cuda:
52
  self.pipe.to("cuda")
53
  self.progress_bar = progress_bar or (lambda _: _)
 
43
  hop_length=hop_length,
44
  top_db=top_db)
45
  self.model_id = model_id
46
+ pipeline = {
47
+ 'LatentAudioDiffusionPipeline': LatentAudioDiffusionPipeline,
48
+ 'AudioDiffusionPipeline': AudioDiffusionPipeline
49
+ }.get(
50
+ DiffusionPipeline.get_config_dict(self.model_id)['_class_name'],
51
+ AudioDiffusionPipeline)
52
+ self.pipe = pipeline.from_pretrained(self.model_id)
53
  if cuda:
54
  self.pipe.to("cuda")
55
  self.progress_bar = progress_bar or (lambda _: _)