Rolv-Arild committed
Commit: 9a7a0bd
Parent(s): 7fcdd24

Add NST+NPSC dataset script

Files changed:
- run.sh (+1, -3)
- run_speech_recognition_ctc.py (+100, -67)
run.sh CHANGED

@@ -1,8 +1,6 @@
 WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_speech_recognition_ctc.py \
-    --dataset_name="NbAiLab/NST" \
     --model_name_or_path="KBLab/wav2vec2-large-voxrex" \
-    --hub_model_id="NbAiLab/wav2vec2-large-voxrex-nst" \
-    --dataset_config_name="no-close" \
+    --hub_model_id="NbAiLab/wav2vec2-large-voxrex-npsc-nst" \
     --output_dir="./" \
     --overwrite_output_dir \
     --num_train_epochs="15" \
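Note: the --dataset_name and --dataset_config_name flags are dropped from run.sh because the training data is no longer selected on the command line; the new make_dataset() function in run_speech_recognition_ctc.py (see the diff below) builds a combined NST+NPSC dataset inside the script. A minimal sketch of what the removed flags now correspond to, assuming you have access to the NbAiLab datasets on the Hugging Face Hub:

    import datasets

    # Previously selected via --dataset_name="NbAiLab/NST" --dataset_config_name="no-close"
    nst = datasets.load_dataset("NbAiLab/NST", "no-close")
    # Newly added second source, hard-coded in make_dataset()
    npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")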
run_speech_recognition_ctc.py CHANGED

@@ -47,13 +47,11 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
-
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.16.0.dev0")
 
 require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -102,8 +100,8 @@ class ModelArguments:
         default=0.05,
         metadata={
             "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
-
-
+            "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+            "vectors will be masked along the time axis."
         },
     )
     mask_time_length: int = field(
@@ -114,7 +112,7 @@ class ModelArguments:
         default=0.0,
         metadata={
             "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
-
+            "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
         },
     )
     mask_feature_length: int = field(
@@ -129,6 +127,7 @@ class ModelArguments:
         default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
     )
 
+
 @dataclass
 class DataTrainingArguments:
     """
@@ -176,14 +175,14 @@ class DataTrainingArguments:
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-
+            "value if set."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
-
+            "value if set."
         },
     )
     chars_to_ignore: Optional[List[str]] = list_field(
@@ -207,16 +206,16 @@ class DataTrainingArguments:
         default=False,
         metadata={
             "help": "Whether to only do data preprocessing and skip training. "
-
-
-
+            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
+            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
+            "so that the cached datasets can consequently be loaded in distributed training"
         },
     )
     use_auth_token: bool = field(
         default=False,
         metadata={
             "help": "If :obj:`True`, will use the token generated when running"
-
+            ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
         },
     )
     unk_token: str = field(
@@ -235,9 +234,9 @@ class DataTrainingArguments:
         default=None,
         metadata={
             "help": "The target language that should be used be"
-
-
-
+            " passed to the tokenizer for tokenization. Note that"
+            " this is only relevant if the model classifies the"
+            " input audio to a sequence of phoneme sequences."
         },
     )
 
@@ -303,10 +302,10 @@ class DataCollatorCTCWithPadding:
 
 
 def create_vocabulary_from_data(
-
-
-
-
+    datasets: DatasetDict,
+    word_delimiter_token: Optional[str] = None,
+    unk_token: Optional[str] = None,
+    pad_token: Optional[str] = None,
 ):
     # Given training and test labels create vocabulary
     def extract_all_chars(batch):
@@ -344,6 +343,85 @@ def create_vocabulary_from_data(
     return vocab_dict
 
 
+def make_dataset(seed=42):
+    # Pre-processing dataset
+    import re
+
+    def map_nst(entry):
+        text = entry["text"].lower()
+        text = text.replace("(...Vær stille under dette opptaket...)", "")
+        text = re.sub('[áàâ]', 'a', text)
+        text = re.sub('[ä]', 'æ', text)
+        text = re.sub('[éèëê]', 'e', text)
+        text = re.sub('[íìïî]', 'i', text)
+        text = re.sub('[óòöô]', 'o', text)
+        text = re.sub('[ö]', 'ø', text)
+        text = re.sub('[ç]', 'c', text)
+        text = re.sub('[úùüû]', 'u', text)
+        # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
+        text = re.sub('\s+', ' ', text)
+        return {"text": text}
+
+    def filter_nst(entry):
+        if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
+            return False  # Too short
+        if re.match(entry["type"], "pIW|CA"):
+            return False  # Spelling out words
+        return True
+
+    def filter_npsc(entry):
+        # False if there are digits in the text
+        if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
+            return False  # Too short
+        if re.search("\d", entry["text"]):
+            return False
+        return True
+
+    def map_npsc(entry):
+        batch = {"text": entry["text"].lower()}
+        batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
+        batch["text"] = re.sub('[ä]', 'æ', batch["text"])
+        batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
+        batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
+        batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
+        batch["text"] = re.sub('[ö]', 'ø', batch["text"])
+        batch["text"] = re.sub('[ç]', 'c', batch["text"])
+        batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
+        batch["text"] = re.sub('\s', ' ', batch["text"])
+        batch["text"] = re.sub('<ee>', 'eee', batch["text"])
+        batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
+        batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
+        batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
+        # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
+        if "<" in batch["text"]:
+            raise ValueError(batch["text"])
+        return batch
+
+    nst = datasets.load_dataset("NbAiLab/NST", "no-close")
+    npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
+    # TODO NST_hesitate
+
+    split = len(npsc["train"]) / (len(npsc["train"]) + len(npsc["validation"]))  # Use same train/val ratio as NPSC
+    nst_train = nst["train"].train_test_split(train_size=split, seed=seed)
+    nst["train"] = nst_train["train"]
+    nst["validation"] = nst_train["test"]
+
+    nst = nst.filter(filter_nst).map(map_nst).shuffle(seed=seed)
+    npsc = npsc.filter(filter_npsc).map(map_npsc).shuffle(seed=seed)
+
+    npsc_base = npsc.remove_columns([col for col in npsc["train"].column_names if col not in ["text", "audio"]])
+    nst_base = nst.remove_columns([col for col in nst["train"].column_names if col not in ["text", "audio"]])
+
+    combined = {}
+    for split in "train", "validation", "test":
+        probs = np.array([len(nst_base[split]), len(npsc_base[split])])  # Weight by number of examples
+        probs = (probs / probs.sum()).tolist()
+        comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
+        combined[split] = comb
+
+    return datasets.DatasetDict(**combined)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -393,45 +471,10 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
-    # Pre-processing dataset
-    import re
-
-    def map_dataset(entry):
-        text = entry["text"].lower()
-        text = text.replace("(...Vær stille under dette opptaket...)", "")
-        text = re.sub('[áàâ]', 'a', text)
-        text = re.sub('[ä]', 'æ', text)
-        text = re.sub('[éèëê]', 'e', text)
-        text = re.sub('[íìïî]', 'i', text)
-        text = re.sub('[óòöô]', 'o', text)
-        text = re.sub('[ö]', 'ø', text)
-        text = re.sub('[ç]', 'c', text)
-        text = re.sub('[úùüû]', 'u', text)
-        # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
-        text = re.sub('\s+', ' ', text)
-        return {"text": text}
-
-
-    def filter_dataset(entry):
-        if not (len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3):
-            return False  # Too short
-        if re.match(entry["type"], "pIW|CA"):
-            return False  # Spelling out words
-        return True
-
     # 1. First, let's load the dataset
-    raw_datasets =
+    raw_datasets = make_dataset(seed=training_args.seed)
 
     if training_args.do_train:
-        raw_datasets["train"] = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            split=data_args.train_split_name,
-            use_auth_token=data_args.use_auth_token,
-        ).shuffle()
-        raw_datasets["train"] = raw_datasets["train"].filter(filter_dataset)
-        raw_datasets["train"] = raw_datasets["train"].map(map_dataset)
-
         if data_args.audio_column_name not in raw_datasets["train"].column_names:
             raise ValueError(
                 f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
@@ -450,28 +493,18 @@ def main():
             raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
 
     if training_args.do_eval:
-        raw_datasets["eval"] = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            split=data_args.eval_split_name,
-            use_auth_token=data_args.use_auth_token,
-        ).shuffle()
-        raw_datasets["eval"] = raw_datasets["eval"].filter(filter_dataset)
-        raw_datasets["eval"] = raw_datasets["eval"].map(map_dataset)
-
         if data_args.max_eval_samples is not None:
             raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
 
-
     # 2. We remove some special characters from the datasets
     # that make training complicated and do not help in transcribing the speech
     # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
    # that could be easily picked up by the model
-    #chars_to_ignore_regex = (
+    # chars_to_ignore_regex = (
     #    f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
-    #)
+    # )
     chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/]'
-
+
     text_column_name = data_args.text_column_name
 
     def remove_special_characters(batch):
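The filter_nst / filter_npsc checks in the new make_dataset() rely on a length heuristic: with 16 kHz audio, wav2vec2's convolutional feature encoder downsamples by a total stride of 320 samples (about 20 ms per frame), so an utterance yields roughly len(audio) // 320 output frames, and CTC cannot emit more characters than there are frames. A self-contained sketch of that check follows; the 16 kHz assumption and the helper name keep_example are illustrative and not part of the commit:

    import numpy as np

    def keep_example(text: str, audio_array) -> bool:
        # Hypothetical standalone version of the length check in filter_nst/filter_npsc:
        # roughly one wav2vec2 output frame per 320 input samples at 16 kHz.
        n_frames = len(audio_array) // 320
        return len(text) <= n_frames and len(text.strip()) >= 3

    one_second = np.zeros(16_000)                  # 1 s of 16 kHz audio -> ~50 frames
    print(keep_example("hei på deg", one_second))  # True: 10 characters <= 50 frames
    print(keep_example("x" * 100, one_second))     # False: 100 characters > 50 frames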
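The combination step weights NST and NPSC by their number of examples and mixes them with datasets.interleave_datasets, per split. A minimal, self-contained sketch of that logic on toy data; the toy datasets are illustrative stand-ins for the filtered NST and NPSC splits used in the script:

    import numpy as np
    import datasets

    nst_toy = datasets.Dataset.from_dict({"text": [f"nst {i}" for i in range(80)]})
    npsc_toy = datasets.Dataset.from_dict({"text": [f"npsc {i}" for i in range(20)]})

    # Sampling probabilities proportional to dataset size, as in make_dataset()
    probs = np.array([len(nst_toy), len(npsc_toy)])
    probs = (probs / probs.sum()).tolist()  # [0.8, 0.2]

    combined = datasets.interleave_datasets([nst_toy, npsc_toy], probabilities=probs, seed=42)
    print(len(combined), combined[0]["text"])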
|