amiriparian commited on
Commit
d995ff8
1 Parent(s): 1a7cb05

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -3
README.md CHANGED
@@ -55,15 +55,14 @@ Further details are available in the corresponding [**paper**](https://arxiv.org
55
  ```python
56
  import torch
57
  import torch.nn as nn
58
- from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor
59
 
60
 
61
 
62
  # CONFIG and MODEL SETUP
63
  model_name = 'amiriparian/HuBERT-EmoSet'
64
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
65
- model = HubertForSequenceClassification.from_pretrained(model_name)
66
- model.classifier = nn.Linear(in_features=256,out_features=6)
67
 
68
  sampling_rate=16000
69
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
55
  ```python
56
  import torch
57
  import torch.nn as nn
58
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
59
 
60
 
61
 
62
  # CONFIG and MODEL SETUP
63
  model_name = 'amiriparian/HuBERT-EmoSet'
64
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
65
+ AutoModelForAudioClassification.from_pretrained(model_name, trust_remote_code=True)
 
66
 
67
  sampling_rate=16000
68
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")