Spaces:
Paused
์๋ ์์ฑ ์ธ์[[automatic-speech-recognition]]
[[open-in-colab]]
์๋ ์์ฑ ์ธ์(Automatic Speech Recognition, ASR)์ ์์ฑ ์ ํธ๋ฅผ ํ ์คํธ๋ก ๋ณํํ์ฌ ์์ฑ ์ ๋ ฅ ์ํ์ค๋ฅผ ํ ์คํธ ์ถ๋ ฅ์ ๋งคํํฉ๋๋ค. Siri์ Alexa์ ๊ฐ์ ๊ฐ์ ์ด์์คํดํธ๋ ASR ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ผ์์ ์ผ๋ก ์ฌ์ฉ์๋ฅผ ๋๊ณ ์์ผ๋ฉฐ, ํ์ ์ค ๋ผ์ด๋ธ ์บก์ ๋ฐ ๋ฉ๋ชจ ์์ฑ๊ณผ ๊ฐ์ ์ ์ฉํ ์ฌ์ฉ์ ์นํ์ ์์ฉ ํ๋ก๊ทธ๋จ๋ ๋ง์ด ์์ต๋๋ค.
์ด ๊ฐ์ด๋์์ ์๊ฐํ ๋ด์ฉ์ ์๋์ ๊ฐ์ต๋๋ค:
- MInDS-14 ๋ฐ์ดํฐ ์ธํธ์์ Wav2Vec2๋ฅผ ๋ฏธ์ธ ์กฐ์ ํ์ฌ ์ค๋์ค๋ฅผ ํ ์คํธ๋ก ๋ณํํฉ๋๋ค.
- ๋ฏธ์ธ ์กฐ์ ํ ๋ชจ๋ธ์ ์ถ๋ก ์ ์ฌ์ฉํฉ๋๋ค.
Data2VecAudio, Hubert, M-CTC-T, SEW, SEW-D, UniSpeech, UniSpeechSat, Wav2Vec2, Wav2Vec2-Conformer, WavLM
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
pip install transformers datasets evaluate jiwer
Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ๋ฉด ๋ชจ๋ธ์ ์ ๋ก๋ํ๊ณ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ ์ ์์ต๋๋ค. ํ ํฐ์ ์ ๋ ฅํ์ฌ ๋ก๊ทธ์ธํ์ธ์.
>>> from huggingface_hub import notebook_login
>>> notebook_login()
MInDS-14 ๋ฐ์ดํฐ ์ธํธ ๊ฐ์ ธ์ค๊ธฐ[[load-minds-14-dataset]]
๋จผ์ , ๐ค Datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ MInDS-14 ๋ฐ์ดํฐ ์ธํธ์ ์ผ๋ถ๋ถ์ ๊ฐ์ ธ์ค์ธ์. ์ด๋ ๊ฒ ํ๋ฉด ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ์ ๋ํ ํ๋ จ์ ์๊ฐ์ ๋ค์ด๊ธฐ ์ ์ ๋ชจ๋ ๊ฒ์ด ์๋ํ๋์ง ์คํํ๊ณ ๊ฒ์ฆํ ์ ์์ต๋๋ค.
>>> from datasets import load_dataset, Audio
>>> minds = load_dataset("PolyAI/minds14", name="en-US", split="train[:100]")
[~Dataset.train_test_split
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ์ train
์ ํ๋ จ ์ธํธ์ ํ
์คํธ ์ธํธ๋ก ๋๋์ธ์:
>>> minds = minds.train_test_split(test_size=0.2)
๊ทธ๋ฆฌ๊ณ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ํ์ธํ์ธ์:
>>> minds
DatasetDict({
train: Dataset({
features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
num_rows: 16
})
test: Dataset({
features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
num_rows: 4
})
})
๋ฐ์ดํฐ ์ธํธ์๋ lang_id
์ english_transcription
๊ณผ ๊ฐ์ ์ ์ฉํ ์ ๋ณด๊ฐ ๋ง์ด ํฌํจ๋์ด ์์ง๋ง, ์ด ๊ฐ์ด๋์์๋ audio
์ transcription
์ ์ด์ ์ ๋ง์ถ ๊ฒ์
๋๋ค. ๋ค๋ฅธ ์ด์ [~datasets.Dataset.remove_columns
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์ ๊ฑฐํ์ธ์:
>>> minds = minds.remove_columns(["english_transcription", "intent_class", "lang_id"])
์์๋ฅผ ๋ค์ ํ๋ฒ ํ์ธํด๋ณด์ธ์:
>>> minds["train"][0]
{'audio': {'array': array([-0.00024414, 0. , 0. , ..., 0.00024414,
0.00024414, 0.00024414], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
'sampling_rate': 8000},
'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
๋ ๊ฐ์ ํ๋๊ฐ ์์ต๋๋ค:
audio
: ์ค๋์ค ํ์ผ์ ๊ฐ์ ธ์ค๊ณ ๋ฆฌ์ํ๋งํ๊ธฐ ์ํด ํธ์ถํด์ผ ํ๋ ์์ฑ ์ ํธ์ 1์ฐจ์array(๋ฐฐ์ด)
transcription
: ๋ชฉํ ํ ์คํธ
์ ์ฒ๋ฆฌ[[preprocess]]
๋ค์์ผ๋ก ์ค๋์ค ์ ํธ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํ Wav2Vec2 ํ๋ก์ธ์๋ฅผ ๊ฐ์ ธ์ต๋๋ค:
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
MInDS-14 ๋ฐ์ดํฐ ์ธํธ์ ์ํ๋ง ๋ ์ดํธ๋ 8000kHz์ด๋ฏ๋ก(๋ฐ์ดํฐ ์ธํธ ์นด๋์์ ํ์ธ), ์ฌ์ ํ๋ จ๋ Wav2Vec2 ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด ๋ฐ์ดํฐ ์ธํธ๋ฅผ 16000kHz๋ก ๋ฆฌ์ํ๋งํด์ผ ํฉ๋๋ค:
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
>>> minds["train"][0]
{'audio': {'array': array([-2.38064706e-04, -1.58618059e-04, -5.43987835e-06, ...,
2.78103951e-04, 2.38446111e-04, 1.18740834e-04], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
'sampling_rate': 16000},
'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
์์ 'transcription'์์ ๋ณผ ์ ์๋ฏ์ด ํ ์คํธ๋ ๋๋ฌธ์์ ์๋ฌธ์๊ฐ ์์ฌ ์์ต๋๋ค. Wav2Vec2 ํ ํฌ๋์ด์ ๋ ๋๋ฌธ์ ๋ฌธ์์ ๋ํด์๋ง ํ๋ จ๋์ด ์์ผ๋ฏ๋ก ํ ์คํธ๊ฐ ํ ํฌ๋์ด์ ์ ์ดํ์ ์ผ์นํ๋์ง ํ์ธํด์ผ ํฉ๋๋ค:
>>> def uppercase(example):
... return {"transcription": example["transcription"].upper()}
>>> minds = minds.map(uppercase)
์ด์ ๋ค์ ์์ ์ ์ํํ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ๋ง๋ค์ด๋ณด๊ฒ ์ต๋๋ค:
audio
์ด์ ํธ์ถํ์ฌ ์ค๋์ค ํ์ผ์ ๊ฐ์ ธ์ค๊ณ ๋ฆฌ์ํ๋งํฉ๋๋ค.- ์ค๋์ค ํ์ผ์์
input_values
๋ฅผ ์ถ์ถํ๊ณ ํ๋ก์ธ์๋กtranscription
์ด์ ํ ํฐํํฉ๋๋ค.
>>> def prepare_dataset(batch):
... audio = batch["audio"]
... batch = processor(audio["array"], sampling_rate=audio["sampling_rate"], text=batch["transcription"])
... batch["input_length"] = len(batch["input_values"][0])
... return batch
์ ์ฒด ๋ฐ์ดํฐ ์ธํธ์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ์ ์ฉํ๋ ค๋ฉด ๐ค Datasets [~datasets.Dataset.map
] ํจ์๋ฅผ ์ฌ์ฉํ์ธ์. num_proc
๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ก์ธ์ค ์๋ฅผ ๋๋ฆฌ๋ฉด map
์ ์๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค. [~datasets.Dataset.remove_columns
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ํ์ํ์ง ์์ ์ด์ ์ ๊ฑฐํ์ธ์:
>>> encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names["train"], num_proc=4)
๐ค Transformers์๋ ์๋ ์์ฑ ์ธ์์ฉ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ๊ฐ ์์ผ๋ฏ๋ก ์์ ๋ฐฐ์น๋ฅผ ์์ฑํ๋ ค๋ฉด [DataCollatorWithPadding
]์ ์กฐ์ ํด์ผ ํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ๋ ํ
์คํธ์ ๋ ์ด๋ธ์ ๋ฐฐ์น์์ ๊ฐ์ฅ ๊ธด ์์์ ๊ธธ์ด์ ๋์ ์ผ๋ก ํจ๋ฉํ์ฌ ๊ธธ์ด๋ฅผ ๊ท ์ผํ๊ฒ ํฉ๋๋ค. tokenizer
ํจ์์์ padding=True
๋ฅผ ์ค์ ํ์ฌ ํ
์คํธ๋ฅผ ํจ๋ฉํ ์ ์์ง๋ง, ๋์ ํจ๋ฉ์ด ๋ ํจ์จ์ ์
๋๋ค.
๋ค๋ฅธ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ์ ๋ฌ๋ฆฌ ์ด ํน์ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ๋ input_values
์ labels
์ ๋ํด ๋ค๋ฅธ ํจ๋ฉ ๋ฐฉ๋ฒ์ ์ ์ฉํด์ผ ํฉ๋๋ค.
>>> import torch
>>> from dataclasses import dataclass, field
>>> from typing import Any, Dict, List, Optional, Union
>>> @dataclass
... class DataCollatorCTCWithPadding:
... processor: AutoProcessor
... padding: Union[bool, str] = "longest"
... def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
... # ์
๋ ฅ๊ณผ ๋ ์ด๋ธ์ ๋ถํ ํฉ๋๋ค
... # ๊ธธ์ด๊ฐ ๋ค๋ฅด๊ณ , ๊ฐ๊ฐ ๋ค๋ฅธ ํจ๋ฉ ๋ฐฉ๋ฒ์ ์ฌ์ฉํด์ผ ํ๊ธฐ ๋๋ฌธ์
๋๋ค
... input_features = [{"input_values": feature["input_values"][0]} for feature in features]
... label_features = [{"input_ids": feature["labels"]} for feature in features]
... batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
... labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")
... # ํจ๋ฉ์ ๋ํด ์์ค์ ์ ์ฉํ์ง ์๋๋ก -100์ผ๋ก ๋์ฒดํฉ๋๋ค
... labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
... batch["labels"] = labels
... return batch
์ด์ DataCollatorForCTCWithPadding
์ ์ธ์คํด์คํํฉ๋๋ค:
>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")
ํ๊ฐํ๊ธฐ[[evaluate]]
ํ๋ จ ์ค์ ํ๊ฐ ์งํ๋ฅผ ํฌํจํ๋ฉด ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ๊ฐํ๋ ๋ฐ ๋์์ด ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ๐ค Evaluate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉด ํ๊ฐ ๋ฐฉ๋ฒ์ ๋น ๋ฅด๊ฒ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ์ด ์์ ์์๋ ๋จ์ด ์ค๋ฅ์จ(Word Error Rate, WER) ํ๊ฐ ์งํ๋ฅผ ๊ฐ์ ธ์ต๋๋ค. (ํ๊ฐ ์งํ๋ฅผ ๋ถ๋ฌ์ค๊ณ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ๐ค Evaluate ๋๋ฌ๋ณด๊ธฐ๋ฅผ ์ฐธ์กฐํ์ธ์):
>>> import evaluate
>>> wer = evaluate.load("wer")
๊ทธ๋ฐ ๋ค์ ์์ธก๊ฐ๊ณผ ๋ ์ด๋ธ์ [~evaluate.EvaluationModule.compute
]์ ์ ๋ฌํ์ฌ WER์ ๊ณ์ฐํ๋ ํจ์๋ฅผ ๋ง๋ญ๋๋ค:
>>> import numpy as np
>>> def compute_metrics(pred):
... pred_logits = pred.predictions
... pred_ids = np.argmax(pred_logits, axis=-1)
... pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
... pred_str = processor.batch_decode(pred_ids)
... label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
... wer = wer.compute(predictions=pred_str, references=label_str)
... return {"wer": wer}
์ด์ compute_metrics
ํจ์๋ฅผ ์ฌ์ฉํ ์ค๋น๊ฐ ๋์์ผ๋ฉฐ, ํ๋ จ์ ์ค์ ํ ๋ ์ด ํจ์๋ก ๋๋์์ฌ ๊ฒ์
๋๋ค.
ํ๋ จํ๊ธฐ[[train]]
[Trainer
]๋ก ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ฒ์ด ์ต์ํ์ง ์๋ค๋ฉด, ์ฌ๊ธฐ์์ ๊ธฐ๋ณธ ํํ ๋ฆฌ์ผ์ ํ์ธํด๋ณด์ธ์!
์ด์ ๋ชจ๋ธ ํ๋ จ์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค! [AutoModelForCTC
]๋ก Wav2Vec2๋ฅผ ๊ฐ์ ธ์ค์ธ์. ctc_loss_reduction
๋งค๊ฐ๋ณ์๋ก CTC ์์ค์ ์ ์ฉํ ์ถ์(reduction) ๋ฐฉ๋ฒ์ ์ง์ ํ์ธ์. ๊ธฐ๋ณธ๊ฐ์ธ ํฉ๊ณ ๋์ ํ๊ท ์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ ์ข์ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค:
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer
>>> model = AutoModelForCTC.from_pretrained(
... "facebook/wav2vec2-base",
... ctc_loss_reduction="mean",
... pad_token_id=processor.tokenizer.pad_token_id,
... )
์ด์ ์ธ ๋จ๊ณ๋ง ๋จ์์ต๋๋ค:
- [
TrainingArguments
]์์ ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ํ์ธ์.output_dir
์ ๋ชจ๋ธ์ ์ ์ฅํ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ๋ ์ ์ผํ ํ์ ๋งค๊ฐ๋ณ์์ ๋๋ค.push_to_hub=True
๋ฅผ ์ค์ ํ์ฌ ๋ชจ๋ธ์ Hub์ ์ ๋ก๋ ํ ์ ์์ต๋๋ค(๋ชจ๋ธ์ ์ ๋ก๋ํ๋ ค๋ฉด Hugging Face์ ๋ก๊ทธ์ธํด์ผ ํฉ๋๋ค). [Trainer
]๋ ๊ฐ ์ํญ๋ง๋ค WER์ ํ๊ฐํ๊ณ ํ๋ จ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค. - ๋ชจ๋ธ, ๋ฐ์ดํฐ ์ธํธ, ํ ํฌ๋์ด์ , ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ,
compute_metrics
ํจ์์ ํจ๊ป [Trainer
]์ ํ๋ จ ์ธ์๋ฅผ ์ ๋ฌํ์ธ์. - [
~Trainer.train
]์ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ธ์.
>>> training_args = TrainingArguments(
... output_dir="my_awesome_asr_mind_model",
... per_device_train_batch_size=8,
... gradient_accumulation_steps=2,
... learning_rate=1e-5,
... warmup_steps=500,
... max_steps=2000,
... gradient_checkpointing=True,
... fp16=True,
... group_by_length=True,
... evaluation_strategy="steps",
... per_device_eval_batch_size=8,
... save_steps=1000,
... eval_steps=1000,
... logging_steps=25,
... load_best_model_at_end=True,
... metric_for_best_model="wer",
... greater_is_better=False,
... push_to_hub=True,
... )
>>> trainer = Trainer(
... model=model,
... args=training_args,
... train_dataset=encoded_minds["train"],
... eval_dataset=encoded_minds["test"],
... tokenizer=processor.feature_extractor,
... data_collator=data_collator,
... compute_metrics=compute_metrics,
... )
>>> trainer.train()
ํ๋ จ์ด ์๋ฃ๋๋ฉด ๋ชจ๋๊ฐ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์๋๋ก [~transformers.Trainer.push_to_hub
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ Hub์ ๊ณต์ ํ์ธ์:
>>> trainer.push_to_hub()
์๋ ์์ฑ ์ธ์์ ์ํด ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ ์์ธํ ์์ ๋ ์์ด ์๋ ์์ฑ ์ธ์์ ์ํ ๋ธ๋ก๊ทธ ํฌ์คํธ์ ๋ค๊ตญ์ด ์๋ ์์ฑ ์ธ์์ ์ํ ํฌ์คํธ๋ฅผ ์ฐธ์กฐํ์ธ์.
์ถ๋ก ํ๊ธฐ[[inference]]
์ข์์, ์ด์ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ผ๋ ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค!
์ถ๋ก ์ ์ฌ์ฉํ ์ค๋์ค ํ์ผ์ ๊ฐ์ ธ์ค์ธ์. ํ์ํ ๊ฒฝ์ฐ ์ค๋์ค ํ์ผ์ ์ํ๋ง ๋น์จ์ ๋ชจ๋ธ์ ์ํ๋ง ๋ ์ดํธ์ ๋ง๊ฒ ๋ฆฌ์ํ๋งํ๋ ๊ฒ์ ์์ง ๋ง์ธ์!
>>> from datasets import load_dataset, Audio
>>> dataset = load_dataset("PolyAI/minds14", "en-US", split="train")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> audio_file = dataset[0]["audio"]["path"]
์ถ๋ก ์ ์ํด ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ํํด๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [pipeline
]์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค. ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์๋ ์์ฑ ์ธ์์ ์ํ pipeline
์ ์ธ์คํด์คํํ๊ณ ์ค๋์ค ํ์ผ์ ์ ๋ฌํ์ธ์:
>>> from transformers import pipeline
>>> transcriber = pipeline("automatic-speech-recognition", model="stevhliu/my_awesome_asr_minds_model")
>>> transcriber(audio_file)
{'text': 'I WOUD LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'}
ํ ์คํธ๋ก ๋ณํ๋ ๊ฒฐ๊ณผ๊ฐ ๊ฝค ๊ด์ฐฎ์ง๋ง ๋ ์ข์ ์๋ ์์ต๋๋ค! ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ผ๋ ค๋ฉด ๋ ๋ง์ ์์ ๋ก ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ธ์!
pipeline
์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ์ฌํํ ์๋ ์์ต๋๋ค:
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("stevhliu/my_awesome_asr_mind_model")
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
์ ๋ ฅ์ ๋ชจ๋ธ์ ์ ๋ฌํ๊ณ ๋ก์ง์ ๋ฐํํ์ธ์:
>>> from transformers import AutoModelForCTC
>>> model = AutoModelForCTC.from_pretrained("stevhliu/my_awesome_asr_mind_model")
>>> with torch.no_grad():
... logits = model(**inputs).logits
๊ฐ์ฅ ๋์ ํ๋ฅ ์ input_ids
๋ฅผ ์์ธกํ๊ณ , ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์์ธก๋ input_ids
๋ฅผ ๋ค์ ํ
์คํธ๋ก ๋์ฝ๋ฉํ์ธ์:
>>> import torch
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription
['I WOUL LIKE O SET UP JOINT ACOUNT WTH Y PARTNER']