ahassoun's picture
Upload 3018 files
ee6e328
|
raw
history blame
16.6 kB

์ž๋™ ์Œ์„ฑ ์ธ์‹[[automatic-speech-recognition]]

[[open-in-colab]]

์ž๋™ ์Œ์„ฑ ์ธ์‹(Automatic Speech Recognition, ASR)์€ ์Œ์„ฑ ์‹ ํ˜ธ๋ฅผ ํ…์ŠคํŠธ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์Œ์„ฑ ์ž…๋ ฅ ์‹œํ€€์Šค๋ฅผ ํ…์ŠคํŠธ ์ถœ๋ ฅ์— ๋งคํ•‘ํ•ฉ๋‹ˆ๋‹ค. Siri์™€ Alexa์™€ ๊ฐ™์€ ๊ฐ€์ƒ ์–ด์‹œ์Šคํ„ดํŠธ๋Š” ASR ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ผ์ƒ์ ์œผ๋กœ ์‚ฌ์šฉ์ž๋ฅผ ๋•๊ณ  ์žˆ์œผ๋ฉฐ, ํšŒ์˜ ์ค‘ ๋ผ์ด๋ธŒ ์บก์…˜ ๋ฐ ๋ฉ”๋ชจ ์ž‘์„ฑ๊ณผ ๊ฐ™์€ ์œ ์šฉํ•œ ์‚ฌ์šฉ์ž ์นœํ™”์  ์‘์šฉ ํ”„๋กœ๊ทธ๋žจ๋„ ๋งŽ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ ์†Œ๊ฐœํ•  ๋‚ด์šฉ์€ ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค:

  1. MInDS-14 ๋ฐ์ดํ„ฐ ์„ธํŠธ์—์„œ Wav2Vec2๋ฅผ ๋ฏธ์„ธ ์กฐ์ •ํ•˜์—ฌ ์˜ค๋””์˜ค๋ฅผ ํ…์ŠคํŠธ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
  2. ๋ฏธ์„ธ ์กฐ์ •ํ•œ ๋ชจ๋ธ์„ ์ถ”๋ก ์— ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ ์„ค๋ช…ํ•˜๋Š” ์ž‘์—…์€ ๋‹ค์Œ ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜์— ์˜ํ•ด ์ง€์›๋ฉ๋‹ˆ๋‹ค:

Data2VecAudio, Hubert, M-CTC-T, SEW, SEW-D, UniSpeech, UniSpeechSat, Wav2Vec2, Wav2Vec2-Conformer, WavLM

์‹œ์ž‘ํ•˜๊ธฐ ์ „์— ํ•„์š”ํ•œ ๋ชจ๋“  ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”:

pip install transformers datasets evaluate jiwer

Hugging Face ๊ณ„์ •์— ๋กœ๊ทธ์ธํ•˜๋ฉด ๋ชจ๋ธ์„ ์—…๋กœ๋“œํ•˜๊ณ  ์ปค๋ฎค๋‹ˆํ‹ฐ์— ๊ณต์œ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ† ํฐ์„ ์ž…๋ ฅํ•˜์—ฌ ๋กœ๊ทธ์ธํ•˜์„ธ์š”.

>>> from huggingface_hub import notebook_login

>>> notebook_login()

MInDS-14 ๋ฐ์ดํ„ฐ ์„ธํŠธ ๊ฐ€์ ธ์˜ค๊ธฐ[[load-minds-14-dataset]]

๋จผ์ €, ๐Ÿค— Datasets ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ MInDS-14 ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ์ผ๋ถ€๋ถ„์„ ๊ฐ€์ ธ์˜ค์„ธ์š”. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ๋Œ€ํ•œ ํ›ˆ๋ จ์— ์‹œ๊ฐ„์„ ๋“ค์ด๊ธฐ ์ „์— ๋ชจ๋“  ๊ฒƒ์ด ์ž‘๋™ํ•˜๋Š”์ง€ ์‹คํ—˜ํ•˜๊ณ  ๊ฒ€์ฆํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

>>> from datasets import load_dataset, Audio

>>> minds = load_dataset("PolyAI/minds14", name="en-US", split="train[:100]")

[~Dataset.train_test_split] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ train์„ ํ›ˆ๋ จ ์„ธํŠธ์™€ ํ…Œ์ŠคํŠธ ์„ธํŠธ๋กœ ๋‚˜๋ˆ„์„ธ์š”:

>>> minds = minds.train_test_split(test_size=0.2)

๊ทธ๋ฆฌ๊ณ  ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ํ™•์ธํ•˜์„ธ์š”:

>>> minds
DatasetDict({
    train: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 16
    })
    test: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 4
    })
})

๋ฐ์ดํ„ฐ ์„ธํŠธ์—๋Š” lang_id์™€ english_transcription๊ณผ ๊ฐ™์€ ์œ ์šฉํ•œ ์ •๋ณด๊ฐ€ ๋งŽ์ด ํฌํ•จ๋˜์–ด ์žˆ์ง€๋งŒ, ์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” audio์™€ transcription์— ์ดˆ์ ์„ ๋งž์ถœ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ์—ด์€ [~datasets.Dataset.remove_columns] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ œ๊ฑฐํ•˜์„ธ์š”:

>>> minds = minds.remove_columns(["english_transcription", "intent_class", "lang_id"])

์˜ˆ์‹œ๋ฅผ ๋‹ค์‹œ ํ•œ๋ฒˆ ํ™•์ธํ•ด๋ณด์„ธ์š”:

>>> minds["train"][0]
{'audio': {'array': array([-0.00024414,  0.        ,  0.        , ...,  0.00024414,
          0.00024414,  0.00024414], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
  'sampling_rate': 8000},
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
 'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}

๋‘ ๊ฐœ์˜ ํ•„๋“œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค:

  • audio: ์˜ค๋””์˜ค ํŒŒ์ผ์„ ๊ฐ€์ ธ์˜ค๊ณ  ๋ฆฌ์ƒ˜ํ”Œ๋งํ•˜๊ธฐ ์œ„ํ•ด ํ˜ธ์ถœํ•ด์•ผ ํ•˜๋Š” ์Œ์„ฑ ์‹ ํ˜ธ์˜ 1์ฐจ์› array(๋ฐฐ์—ด)
  • transcription: ๋ชฉํ‘œ ํ…์ŠคํŠธ

์ „์ฒ˜๋ฆฌ[[preprocess]]

๋‹ค์Œ์œผ๋กœ ์˜ค๋””์˜ค ์‹ ํ˜ธ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ธฐ ์œ„ํ•œ Wav2Vec2 ํ”„๋กœ์„ธ์„œ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค:

>>> from transformers import AutoProcessor

>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")

MInDS-14 ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ์ƒ˜ํ”Œ๋ง ๋ ˆ์ดํŠธ๋Š” 8000kHz์ด๋ฏ€๋กœ(๋ฐ์ดํ„ฐ ์„ธํŠธ ์นด๋“œ์—์„œ ํ™•์ธ), ์‚ฌ์ „ ํ›ˆ๋ จ๋œ Wav2Vec2 ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ 16000kHz๋กœ ๋ฆฌ์ƒ˜ํ”Œ๋งํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
>>> minds["train"][0]
{'audio': {'array': array([-2.38064706e-04, -1.58618059e-04, -5.43987835e-06, ...,
          2.78103951e-04,  2.38446111e-04,  1.18740834e-04], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
  'sampling_rate': 16000},
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
 'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}

์œ„์˜ 'transcription'์—์„œ ๋ณผ ์ˆ˜ ์žˆ๋“ฏ์ด ํ…์ŠคํŠธ๋Š” ๋Œ€๋ฌธ์ž์™€ ์†Œ๋ฌธ์ž๊ฐ€ ์„ž์—ฌ ์žˆ์Šต๋‹ˆ๋‹ค. Wav2Vec2 ํ† ํฌ๋‚˜์ด์ €๋Š” ๋Œ€๋ฌธ์ž ๋ฌธ์ž์— ๋Œ€ํ•ด์„œ๋งŒ ํ›ˆ๋ จ๋˜์–ด ์žˆ์œผ๋ฏ€๋กœ ํ…์ŠคํŠธ๊ฐ€ ํ† ํฌ๋‚˜์ด์ €์˜ ์–ดํœ˜์™€ ์ผ์น˜ํ•˜๋Š”์ง€ ํ™•์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

>>> def uppercase(example):
...     return {"transcription": example["transcription"].upper()}


>>> minds = minds.map(uppercase)

์ด์ œ ๋‹ค์Œ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•  ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ค์–ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

  1. audio ์—ด์„ ํ˜ธ์ถœํ•˜์—ฌ ์˜ค๋””์˜ค ํŒŒ์ผ์„ ๊ฐ€์ ธ์˜ค๊ณ  ๋ฆฌ์ƒ˜ํ”Œ๋งํ•ฉ๋‹ˆ๋‹ค.
  2. ์˜ค๋””์˜ค ํŒŒ์ผ์—์„œ input_values๋ฅผ ์ถ”์ถœํ•˜๊ณ  ํ”„๋กœ์„ธ์„œ๋กœ transcription ์—ด์„ ํ† ํฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
>>> def prepare_dataset(batch):
...     audio = batch["audio"]
...     batch = processor(audio["array"], sampling_rate=audio["sampling_rate"], text=batch["transcription"])
...     batch["input_length"] = len(batch["input_values"][0])
...     return batch

์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•˜๋ ค๋ฉด ๐Ÿค— Datasets [~datasets.Dataset.map] ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”. num_proc ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ”„๋กœ์„ธ์Šค ์ˆ˜๋ฅผ ๋Š˜๋ฆฌ๋ฉด map์˜ ์†๋„๋ฅผ ๋†’์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. [~datasets.Dataset.remove_columns] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ•„์š”ํ•˜์ง€ ์•Š์€ ์—ด์„ ์ œ๊ฑฐํ•˜์„ธ์š”:

>>> encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names["train"], num_proc=4)

๐Ÿค— Transformers์—๋Š” ์ž๋™ ์Œ์„ฑ ์ธ์‹์šฉ ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ๊ฐ€ ์—†์œผ๋ฏ€๋กœ ์˜ˆ์ œ ๋ฐฐ์น˜๋ฅผ ์ƒ์„ฑํ•˜๋ ค๋ฉด [DataCollatorWithPadding]์„ ์กฐ์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ๋Š” ํ…์ŠคํŠธ์™€ ๋ ˆ์ด๋ธ”์„ ๋ฐฐ์น˜์—์„œ ๊ฐ€์žฅ ๊ธด ์š”์†Œ์˜ ๊ธธ์ด์— ๋™์ ์œผ๋กœ ํŒจ๋”ฉํ•˜์—ฌ ๊ธธ์ด๋ฅผ ๊ท ์ผํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค. tokenizer ํ•จ์ˆ˜์—์„œ padding=True๋ฅผ ์„ค์ •ํ•˜์—ฌ ํ…์ŠคํŠธ๋ฅผ ํŒจ๋”ฉํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, ๋™์  ํŒจ๋”ฉ์ด ๋” ํšจ์œจ์ ์ž…๋‹ˆ๋‹ค.

๋‹ค๋ฅธ ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ์™€ ๋‹ฌ๋ฆฌ ์ด ํŠน์ • ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ๋Š” input_values์™€ labels์— ๋Œ€ํ•ด ๋‹ค๋ฅธ ํŒจ๋”ฉ ๋ฐฉ๋ฒ•์„ ์ ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

>>> import torch

>>> from dataclasses import dataclass, field
>>> from typing import Any, Dict, List, Optional, Union


>>> @dataclass
... class DataCollatorCTCWithPadding:
...     processor: AutoProcessor
...     padding: Union[bool, str] = "longest"

...     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
...         # ์ž…๋ ฅ๊ณผ ๋ ˆ์ด๋ธ”์„ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค
...         # ๊ธธ์ด๊ฐ€ ๋‹ค๋ฅด๊ณ , ๊ฐ๊ฐ ๋‹ค๋ฅธ ํŒจ๋”ฉ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค
...         input_features = [{"input_values": feature["input_values"][0]} for feature in features]
...         label_features = [{"input_ids": feature["labels"]} for feature in features]

...         batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")

...         labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

...         # ํŒจ๋”ฉ์— ๋Œ€ํ•ด ์†์‹ค์„ ์ ์šฉํ•˜์ง€ ์•Š๋„๋ก -100์œผ๋กœ ๋Œ€์ฒดํ•ฉ๋‹ˆ๋‹ค
...         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

...         batch["labels"] = labels

...         return batch

์ด์ œ DataCollatorForCTCWithPadding์„ ์ธ์Šคํ„ด์Šคํ™”ํ•ฉ๋‹ˆ๋‹ค:

>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")

ํ‰๊ฐ€ํ•˜๊ธฐ[[evaluate]]

ํ›ˆ๋ จ ์ค‘์— ํ‰๊ฐ€ ์ง€ํ‘œ๋ฅผ ํฌํ•จํ•˜๋ฉด ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค. ๐Ÿค— Evaluate ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ํ‰๊ฐ€ ๋ฐฉ๋ฒ•์„ ๋น ๋ฅด๊ฒŒ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ์ž‘์—…์—์„œ๋Š” ๋‹จ์–ด ์˜ค๋ฅ˜์œจ(Word Error Rate, WER) ํ‰๊ฐ€ ์ง€ํ‘œ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค. (ํ‰๊ฐ€ ์ง€ํ‘œ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ๊ณ„์‚ฐํ•˜๋Š” ๋ฐฉ๋ฒ•์€ ๐Ÿค— Evaluate ๋‘˜๋Ÿฌ๋ณด๊ธฐ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”):

>>> import evaluate

>>> wer = evaluate.load("wer")

๊ทธ๋Ÿฐ ๋‹ค์Œ ์˜ˆ์ธก๊ฐ’๊ณผ ๋ ˆ์ด๋ธ”์„ [~evaluate.EvaluationModule.compute]์— ์ „๋‹ฌํ•˜์—ฌ WER์„ ๊ณ„์‚ฐํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค:

>>> import numpy as np


>>> def compute_metrics(pred):
...     pred_logits = pred.predictions
...     pred_ids = np.argmax(pred_logits, axis=-1)

...     pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

...     pred_str = processor.batch_decode(pred_ids)
...     label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

...     wer = wer.compute(predictions=pred_str, references=label_str)

...     return {"wer": wer}

์ด์ œ compute_metrics ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•  ์ค€๋น„๊ฐ€ ๋˜์—ˆ์œผ๋ฉฐ, ํ›ˆ๋ จ์„ ์„ค์ •ํ•  ๋•Œ ์ด ํ•จ์ˆ˜๋กœ ๋˜๋Œ์•„์˜ฌ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

ํ›ˆ๋ จํ•˜๊ธฐ[[train]]

[Trainer]๋กœ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๋Š” ๊ฒƒ์ด ์ต์ˆ™ํ•˜์ง€ ์•Š๋‹ค๋ฉด, ์—ฌ๊ธฐ์—์„œ ๊ธฐ๋ณธ ํŠœํ† ๋ฆฌ์–ผ์„ ํ™•์ธํ•ด๋ณด์„ธ์š”!

์ด์ œ ๋ชจ๋ธ ํ›ˆ๋ จ์„ ์‹œ์ž‘ํ•  ์ค€๋น„๊ฐ€ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! [AutoModelForCTC]๋กœ Wav2Vec2๋ฅผ ๊ฐ€์ ธ์˜ค์„ธ์š”. ctc_loss_reduction ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ CTC ์†์‹ค์— ์ ์šฉํ•  ์ถ•์†Œ(reduction) ๋ฐฉ๋ฒ•์„ ์ง€์ •ํ•˜์„ธ์š”. ๊ธฐ๋ณธ๊ฐ’์ธ ํ•ฉ๊ณ„ ๋Œ€์‹  ํ‰๊ท ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ๋” ์ข‹์€ ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค:

>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer

>>> model = AutoModelForCTC.from_pretrained(
...     "facebook/wav2vec2-base",
...     ctc_loss_reduction="mean",
...     pad_token_id=processor.tokenizer.pad_token_id,
... )

์ด์ œ ์„ธ ๋‹จ๊ณ„๋งŒ ๋‚จ์•˜์Šต๋‹ˆ๋‹ค:

  1. [TrainingArguments]์—์„œ ํ›ˆ๋ จ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ •์˜ํ•˜์„ธ์š”. output_dir์€ ๋ชจ๋ธ์„ ์ €์žฅํ•  ๊ฒฝ๋กœ๋ฅผ ์ง€์ •ํ•˜๋Š” ์œ ์ผํ•œ ํ•„์ˆ˜ ๋งค๊ฐœ๋ณ€์ˆ˜์ž…๋‹ˆ๋‹ค. push_to_hub=True๋ฅผ ์„ค์ •ํ•˜์—ฌ ๋ชจ๋ธ์„ Hub์— ์—…๋กœ๋“œ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(๋ชจ๋ธ์„ ์—…๋กœ๋“œํ•˜๋ ค๋ฉด Hugging Face์— ๋กœ๊ทธ์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค). [Trainer]๋Š” ๊ฐ ์—ํญ๋งˆ๋‹ค WER์„ ํ‰๊ฐ€ํ•˜๊ณ  ํ›ˆ๋ จ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.
  2. ๋ชจ๋ธ, ๋ฐ์ดํ„ฐ ์„ธํŠธ, ํ† ํฌ๋‚˜์ด์ €, ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ, compute_metrics ํ•จ์ˆ˜์™€ ํ•จ๊ป˜ [Trainer]์— ํ›ˆ๋ จ ์ธ์ˆ˜๋ฅผ ์ „๋‹ฌํ•˜์„ธ์š”.
  3. [~Trainer.train]์„ ํ˜ธ์ถœํ•˜์—ฌ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜์„ธ์š”.
>>> training_args = TrainingArguments(
...     output_dir="my_awesome_asr_mind_model",
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=2,
...     learning_rate=1e-5,
...     warmup_steps=500,
...     max_steps=2000,
...     gradient_checkpointing=True,
...     fp16=True,
...     group_by_length=True,
...     evaluation_strategy="steps",
...     per_device_eval_batch_size=8,
...     save_steps=1000,
...     eval_steps=1000,
...     logging_steps=25,
...     load_best_model_at_end=True,
...     metric_for_best_model="wer",
...     greater_is_better=False,
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=encoded_minds["train"],
...     eval_dataset=encoded_minds["test"],
...     tokenizer=processor.feature_extractor,
...     data_collator=data_collator,
...     compute_metrics=compute_metrics,
... )

>>> trainer.train()

ํ›ˆ๋ จ์ด ์™„๋ฃŒ๋˜๋ฉด ๋ชจ๋‘๊ฐ€ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก [~transformers.Trainer.push_to_hub] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ Hub์— ๊ณต์œ ํ•˜์„ธ์š”:

>>> trainer.push_to_hub()

์ž๋™ ์Œ์„ฑ ์ธ์‹์„ ์œ„ํ•ด ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๋Š” ๋” ์ž์„ธํ•œ ์˜ˆ์ œ๋Š” ์˜์–ด ์ž๋™ ์Œ์„ฑ ์ธ์‹์„ ์œ„ํ•œ ๋ธ”๋กœ๊ทธ ํฌ์ŠคํŠธ์™€ ๋‹ค๊ตญ์–ด ์ž๋™ ์Œ์„ฑ ์ธ์‹์„ ์œ„ํ•œ ํฌ์ŠคํŠธ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

์ถ”๋ก ํ•˜๊ธฐ[[inference]]

์ข‹์•„์š”, ์ด์ œ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ–ˆ์œผ๋‹ˆ ์ถ”๋ก ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!

์ถ”๋ก ์— ์‚ฌ์šฉํ•  ์˜ค๋””์˜ค ํŒŒ์ผ์„ ๊ฐ€์ ธ์˜ค์„ธ์š”. ํ•„์š”ํ•œ ๊ฒฝ์šฐ ์˜ค๋””์˜ค ํŒŒ์ผ์˜ ์ƒ˜ํ”Œ๋ง ๋น„์œจ์„ ๋ชจ๋ธ์˜ ์ƒ˜ํ”Œ๋ง ๋ ˆ์ดํŠธ์— ๋งž๊ฒŒ ๋ฆฌ์ƒ˜ํ”Œ๋งํ•˜๋Š” ๊ฒƒ์„ ์žŠ์ง€ ๋งˆ์„ธ์š”!

>>> from datasets import load_dataset, Audio

>>> dataset = load_dataset("PolyAI/minds14", "en-US", split="train")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> audio_file = dataset[0]["audio"]["path"]

์ถ”๋ก ์„ ์œ„ํ•ด ๋ฏธ์„ธ ์กฐ์ •๋œ ๋ชจ๋ธ์„ ์‹œํ—˜ํ•ด๋ณด๋Š” ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์€ [pipeline]์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ž๋™ ์Œ์„ฑ ์ธ์‹์„ ์œ„ํ•œ pipeline์„ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๊ณ  ์˜ค๋””์˜ค ํŒŒ์ผ์„ ์ „๋‹ฌํ•˜์„ธ์š”:

>>> from transformers import pipeline

>>> transcriber = pipeline("automatic-speech-recognition", model="stevhliu/my_awesome_asr_minds_model")
>>> transcriber(audio_file)
{'text': 'I WOUD LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'}

ํ…์ŠคํŠธ๋กœ ๋ณ€ํ™˜๋œ ๊ฒฐ๊ณผ๊ฐ€ ๊ฝค ๊ดœ์ฐฎ์ง€๋งŒ ๋” ์ข‹์„ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค! ๋” ๋‚˜์€ ๊ฒฐ๊ณผ๋ฅผ ์–ป์œผ๋ ค๋ฉด ๋” ๋งŽ์€ ์˜ˆ์ œ๋กœ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜์„ธ์š”!

pipeline์˜ ๊ฒฐ๊ณผ๋ฅผ ์ˆ˜๋™์œผ๋กœ ์žฌํ˜„ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค:

์˜ค๋””์˜ค ํŒŒ์ผ๊ณผ ํ…์ŠคํŠธ๋ฅผ ์ „์ฒ˜๋ฆฌํ•˜๊ณ  PyTorch ํ…์„œ๋กœ `input`์„ ๋ฐ˜ํ™˜ํ•  ํ”„๋กœ์„ธ์„œ๋ฅผ ๊ฐ€์ ธ์˜ค์„ธ์š”:
>>> from transformers import AutoProcessor

>>> processor = AutoProcessor.from_pretrained("stevhliu/my_awesome_asr_mind_model")
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

์ž…๋ ฅ์„ ๋ชจ๋ธ์— ์ „๋‹ฌํ•˜๊ณ  ๋กœ์ง“์„ ๋ฐ˜ํ™˜ํ•˜์„ธ์š”:

>>> from transformers import AutoModelForCTC

>>> model = AutoModelForCTC.from_pretrained("stevhliu/my_awesome_asr_mind_model")
>>> with torch.no_grad():
...     logits = model(**inputs).logits

๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์˜ input_ids๋ฅผ ์˜ˆ์ธกํ•˜๊ณ , ํ”„๋กœ์„ธ์„œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์˜ˆ์ธก๋œ input_ids๋ฅผ ๋‹ค์‹œ ํ…์ŠคํŠธ๋กœ ๋””์ฝ”๋”ฉํ•˜์„ธ์š”:

>>> import torch

>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription
['I WOUL LIKE O SET UP JOINT ACOUNT WTH Y PARTNER']