frankenjoe
commited on
Commit
•
bce2bd5
1
Parent(s):
bb0dfd2
Update README.md
Browse files
README.md
CHANGED
@@ -91,7 +91,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
|
|
91 |
hidden_states = outputs[0]
|
92 |
hidden_states = torch.mean(hidden_states, dim=1)
|
93 |
logits_age = self.age(hidden_states)
|
94 |
-
logits_gender = self.gender(hidden_states)
|
95 |
|
96 |
return hidden_states, logits_age, logits_gender
|
97 |
|
@@ -138,8 +138,8 @@ def process_func(
|
|
138 |
|
139 |
|
140 |
print(process_func(signal, sampling_rate))
|
141 |
-
# Age child
|
142 |
-
# [[ 0.33793038
|
143 |
|
144 |
print(process_func(signal, sampling_rate, embeddings=True))
|
145 |
# Pooled hidden states of last transformer layer
|
|
|
91 |
hidden_states = outputs[0]
|
92 |
hidden_states = torch.mean(hidden_states, dim=1)
|
93 |
logits_age = self.age(hidden_states)
|
94 |
+
logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
|
95 |
|
96 |
return hidden_states, logits_age, logits_gender
|
97 |
|
|
|
138 |
|
139 |
|
140 |
print(process_func(signal, sampling_rate))
|
141 |
+
# Age child female male
|
142 |
+
# [[ 0.33793038 0.2715511 0.2275236 0.5009253 ]]
|
143 |
|
144 |
print(process_func(signal, sampling_rate, embeddings=True))
|
145 |
# Pooled hidden states of last transformer layer
|