chrlukas commited on
Commit
f3c82e8
1 Parent(s): ec74a13

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -2
README.md CHANGED
@@ -37,7 +37,7 @@ It achieves Unweighed Average Recall (UAR) values of .8001 and .8084 on the deve
37
  - **Paper [optional]:** [More Information Needed]
38
 
39
 
40
- ## Uses
41
 
42
  The following snippet illustrates the usage of the model.
43
  ```python
@@ -49,13 +49,15 @@ import librosa
49
  checkpoint = "chrlukas/flattery_prediction_speech"
50
  processor = AutoFeatureExtractor.from_pretrained(checkpoint)
51
  model = Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint)
 
52
 
53
  # predict flattery in a sentence
54
  example_file = 'example.wav'
55
  # audio must be resampled to 16Hz
56
  y, _ = librosa.load(test_file, sr=16000)
57
  inp = processor(y, sampling_rate=16000, return_tensors='pt')
58
- logits = model(**inp).logits
 
59
  prediction = sigmoid(logits).item()
60
  flattery = prediction >= 0.5
61
  print(f'Flattery detected? {flattery}')
 
37
  - **Paper [optional]:** [More Information Needed]
38
 
39
 
40
+ ## Usage
41
 
42
  The following snippet illustrates the usage of the model.
43
  ```python
 
49
  checkpoint = "chrlukas/flattery_prediction_speech"
50
  processor = AutoFeatureExtractor.from_pretrained(checkpoint)
51
  model = Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint)
52
+ model.eval()
53
 
54
  # predict flattery in a sentence
55
  example_file = 'example.wav'
56
  # audio must be resampled to 16Hz
57
  y, _ = librosa.load(test_file, sr=16000)
58
  inp = processor(y, sampling_rate=16000, return_tensors='pt')
59
+ with torch.no_grad():
60
+ logits = model(**inp).logits
61
  prediction = sigmoid(logits).item()
62
  flattery = prediction >= 0.5
63
  print(f'Flattery detected? {flattery}')