frankenjoe commited on
Commit
bce2bd5
1 Parent(s): bb0dfd2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -91,7 +91,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
91
  hidden_states = outputs[0]
92
  hidden_states = torch.mean(hidden_states, dim=1)
93
  logits_age = self.age(hidden_states)
94
- logits_gender = self.gender(hidden_states)
95
 
96
  return hidden_states, logits_age, logits_gender
97
 
@@ -138,8 +138,8 @@ def process_func(
138
 
139
 
140
  print(process_func(signal, sampling_rate))
141
- # Age child female male
142
- # [[ 0.33793038 -0.17247453 -0.34937087 0.43983212]]
143
 
144
  print(process_func(signal, sampling_rate, embeddings=True))
145
  # Pooled hidden states of last transformer layer
 
91
  hidden_states = outputs[0]
92
  hidden_states = torch.mean(hidden_states, dim=1)
93
  logits_age = self.age(hidden_states)
94
+ logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
95
 
96
  return hidden_states, logits_age, logits_gender
97
 
 
138
 
139
 
140
  print(process_func(signal, sampling_rate))
141
+ # Age child female male
142
+ # [[ 0.33793038 0.2715511 0.2275236 0.5009253 ]]
143
 
144
  print(process_func(signal, sampling_rate, embeddings=True))
145
  # Pooled hidden states of last transformer layer