frankenjoe committed
Commit 5cb94f8
1 Parent(s): e51b6f2

Create README.md

Files changed (1)
  1. README.md  +149 -0

README.md ADDED
---
language: en
datasets:
- agender
- mozillacommonvoice
- timit
- voxceleb2
inference: true
tags:
- speech
- audio
- wav2vec2
- audio-classification
- age-recognition
- gender-recognition
license: cc-by-nc-sa-4.0
---

# Model for Age and Gender Recognition based on Wav2vec 2.0 (24 layers)

The model expects a raw audio signal as input and outputs predictions
for age in a range of approximately 0...1 (0...100 years)
and gender expressing the probability for being child, female, or male.
In addition, it also provides the pooled states of the last transformer layer.
The model was created by fine-tuning
[Wav2Vec2-Large-Robust](https://huggingface.co/facebook/wav2vec2-large-robust)
on [aGender](https://paperswithcode.com/dataset/agender),
[Mozilla Common Voice](https://commonvoice.mozilla.org/),
[TIMIT](https://catalog.ldc.upenn.edu/LDC93s1), and
[VoxCeleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html).
For this version of the model, we trained all 24 transformer layers.
An [ONNX](https://onnx.ai/) export of the model is available from
[doi:10.5281/zenodo.7761387](https://doi.org/10.5281/zenodo.7761387).
Further details are given in the associated [paper](https://arxiv.org/abs/2306.16962)
and [tutorial](https://github.com/audeering/w2v2-age-gender-how-to).

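If you want to run the ONNX export without `torch` and `transformers`, a minimal sketch with `onnxruntime` could look like the following. The file name `model.onnx` and the assumption that the graph takes a 16 kHz signal directly are not documented here, so inspect `session.get_inputs()` and `session.get_outputs()` of the downloaded export and compare with the Zenodo documentation.

```python
import numpy as np
import onnxruntime

# path to the export downloaded from Zenodo (file name is an assumption)
session = onnxruntime.InferenceSession('model.onnx')

# 1 s of silence at 16 kHz, matching the dummy signal used below
signal = np.zeros((1, 16000), dtype=np.float32)

# read the tensor names from the session instead of hard-coding them
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: signal})

# the outputs are expected to mirror the PyTorch model below:
# pooled hidden states, age prediction, gender logits
for spec, value in zip(session.get_outputs(), outputs):
    print(spec.name, value.shape)
```
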
# Usage

```python
import numpy as np
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)


class ModelHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config, num_labels):

        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):

        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age and gender classifier."""

    def __init__(self, config):

        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(
            self,
            input_values,
    ):

        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = self.gender(hidden_states)

        return hidden_states, logits_age, logits_gender


# load model from hub
device = 'cpu'
model_name = 'audeering/wav2vec2-large-robust-24-ft-age-gender'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)

# dummy signal
sampling_rate = 16000
signal = np.zeros((1, sampling_rate), dtype=np.float32)


def process_func(
    x: np.ndarray,
    sampling_rate: int,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict age and gender or extract embeddings from raw audio signal."""

    # run through processor to normalize signal
    # always returns a batch, so we just get the first entry
    # then we put it on the device
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through model
    with torch.no_grad():
        y = model(y)
        if embeddings:
            y = y[0]
        else:
            y = torch.hstack([y[1], y[2]])

    # convert to numpy
    y = y.detach().cpu().numpy()

    return y


print(process_func(signal, sampling_rate))
#  Age        child       female     male
# [[ 0.3079211 -1.6096017 -2.1094327  3.1461434]]

print(process_func(signal, sampling_rate, embeddings=True))
# Pooled hidden states of last transformer layer
# [[-0.00752167  0.0065819  -0.00746342 ...  0.00663632  0.00848748
#    0.00599211]]
```
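
To turn the raw outputs of the first call into interpretable values, the age prediction (roughly 0...1, see above) can be scaled to years, and a softmax over the three gender logits yields class probabilities. A minimal post-processing sketch, assuming this layout and reusing `process_func` and `signal` from the example above:

```python
def postprocess(logits: np.ndarray) -> dict:
    r"""Map raw outputs to age in years and gender probabilities.

    Assumes the layout returned by ``process_func``:
    column 0 is the age prediction (roughly 0...1, i.e. age / 100),
    columns 1-3 are the logits for child, female, and male.
    """
    age = float(logits[0, 0]) * 100

    # softmax over the three gender logits
    gender_logits = logits[0, 1:]
    exp = np.exp(gender_logits - gender_logits.max())
    probs = exp / exp.sum()

    return {
        'age': age,
        'child': float(probs[0]),
        'female': float(probs[1]),
        'male': float(probs[2]),
    }


print(postprocess(process_func(signal, sampling_rate)))
# roughly {'age': 30.8, 'child': 0.008, 'female': 0.005, 'male': 0.986}
# for the dummy signal above
```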