# Spaces: Runtime error
from typing import Any, Dict, Union | |
import torch | |
from packaging import version | |
from torch import nn | |
from transformers import ( | |
Trainer, | |
is_apex_available, | |
) | |
if is_apex_available(): | |
from apex import amp | |
if version.parse(torch.__version__) >= version.parse("1.6"): | |
_is_native_amp_available = True | |
from torch.cuda.amp import autocast | |
class CTCTrainer(Trainer):
    """``transformers.Trainer`` subclass that overrides ``training_step`` for CTC fine-tuning."""

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch (detached from the graph).
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            # Native mixed precision: run the forward pass (and loss) under autocast.
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # With nn.DataParallel the per-device losses are gathered into a
            # vector; reduce it to a scalar so backward() works and the
            # returned value is a single number.
            loss = loss.mean()

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # DeepSpeed already divides by gradient_accumulation_steps inside its
            # own backward(); dividing here too would shrink gradients twice.
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            # Scale the loss so fp16 gradients do not underflow.
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()