How to fix GPU out-of-memory errors in PyTorch

Posted by g52tjvyc on 2023-05-29 in Other

I want to train a wav2vec2 model for Persian. I have 2 hours of audio (7k recordings), and I am training with this code:

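# processor, data_collator, compute_metrics, common_voice_train and
# common_voice_test are defined in earlier notebook cells (not shown).
from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC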

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

# Freeze the convolutional feature encoder (freeze_feature_extractor() is the deprecated name).
model.freeze_feature_encoder()

training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/model-output",
    group_by_length=True,
    per_device_train_batch_size=4,
    evaluation_strategy="steps",
    num_train_epochs=30,
    fp16=True,
    gradient_checkpointing=True, 
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

trainer.train()
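The arguments above never set per_device_eval_batch_size, so evaluation falls back to the TrainingArguments default of 8, twice the training batch size of 4. A minimal sketch of evaluation-side overrides that could be merged into that call; the keys are real TrainingArguments parameters, but the values are illustrative assumptions, not tuned settings:

```python
# Illustrative evaluation-side overrides for the TrainingArguments call above.
# Values are assumptions for a ~15 GiB GPU, not tuned recommendations.
eval_memory_overrides = {
    # Default is 8. Assuming the data collator pads each batch to its longest
    # clip, a smaller evaluation batch bounds the peak activation size.
    "per_device_eval_batch_size": 2,
    # Move accumulated predictions (the CTC logits) from GPU to CPU every
    # 10 evaluation steps instead of keeping all of them on the GPU until
    # evaluation finishes.
    "eval_accumulation_steps": 10,
}
```

Usage: pass them alongside the existing arguments, e.g. `TrainingArguments(output_dir=..., per_device_train_batch_size=4, ..., **eval_memory_overrides)`. Training-side settings are left unchanged here.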

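When I run `trainer.train()`, I get this error:
CUDA out of memory. Tried to allocate 1.22 GiB (GPU 0; 14.75 GiB total capacity; 12.59 GiB already allocated; 296.81 MiB free; 13.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Here is the full output: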
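The message suggests max_split_size_mb, but that option targets fragmentation, i.e. reserved memory much larger than allocated memory. In this message reserved memory (13.45 GiB) is less than 1 GiB above allocated memory (12.59 GiB), so the hint will probably not help much on its own. If you still want to try it, a minimal sketch, assuming a Colab-style notebook; the value 128 is an arbitrary illustration:

```python
import os

# PYTORCH_CUDA_ALLOC_CONF is read when PyTorch's CUDA caching allocator
# initializes, so set it at the top of the notebook, before `import torch`
# and before anything touches the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# import torch  # import only after the variable is set
```

If the GPU has already been used in the current session, restart the runtime so the new setting takes effect.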

/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
/usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/processing_wav2vec2.py:155: UserWarning: `as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your labels by using the argument `text` of the regular `__call__` method (either in the same call as your audio inputs, or in a separate call.
warnings.warn(
[ 501/43430 01:23 < 1:59:11, 6.00 it/s, Epoch 0.12/10]
Step    Training Loss   Validation Loss
[50/61 00:24 < 00:05, 2.04 it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 1>:1                                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1664 in train                    │
│                                                                                                  │
│   1661 │   │   inner_training_loop = find_executable_batch_size(                                 │
│   1662 │   │   │   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  │
│   1663 │   │   )                                                                                 │
│ ❱ 1664 │   │   return inner_training_loop(                                                       │
│   1665 │   │   │   args=args,                                                                    │
│   1666 │   │   │   resume_from_checkpoint=resume_from_checkpoint,                                │
│   1667 │   │   │   trial=trial,                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2019 in _inner_training_loop     │
│                                                                                                  │
│   2016 │   │   │   │   │   self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epo  │
│   2017 │   │   │   │   │   self.control = self.callback_handler.on_step_end(args, self.state, s  │
│   2018 │   │   │   │   │                                                                         │
│ ❱ 2019 │   │   │   │   │   self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_k  │
│   2020 │   │   │   │   else:                                                                     │
│   2021 │   │   │   │   │   self.control = self.callback_handler.on_substep_end(args, self.state  │
│   2022                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2300 in _maybe_log_save_evaluate │
│                                                                                                  │
│   2297 │   │   │   │   │   )                                                                     │
│   2298 │   │   │   │   │   metrics.update(dataset_metrics)                                       │
│   2299 │   │   │   else:                                                                         │
│ ❱ 2300 │   │   │   │   metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)                 │
│   2301 │   │   │   self._report_to_hp_search(trial, self.state.global_step, metrics)             │
│   2302 │   │   │                                                                                 │
│   2303 │   │   │   # Run delayed LR scheduler now that metrics are populated                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3029 in evaluate                 │
│                                                                                                  │
│   3026 │   │   start_time = time.time()                                                          │
│   3027 │   │                                                                                     │
│   3028 │   │   eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else se  │
│ ❱ 3029 │   │   output = eval_loop(                                                               │
│   3030 │   │   │   eval_dataloader,                                                              │
│   3031 │   │   │   description="Evaluation",                                                     │
│   3032 │   │   │   # No point gathering the predictions if there are no metrics, otherwise we d  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3210 in evaluation_loop          │
│                                                                                                  │
│   3207 │   │   │   │   │   batch_size = observed_batch_size                                      │
│   3208 │   │   │                                                                                 │
│   3209 │   │   │   # Prediction step                                                             │
│ ❱ 3210 │   │   │   loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_o  │
│   3211 │   │   │   inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inp  │
│   3212 │   │   │                                                                                 │
│   3213 │   │   │   if is_torch_tpu_available():                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3466 in prediction_step          │
│                                                                                                  │
│   3463 │   │   │   else:                                                                         │
│   3464 │   │   │   │   if has_labels or loss_without_labels:                                     │
│   3465 │   │   │   │   │   with self.compute_loss_context_manager():                             │
│ ❱ 3466 │   │   │   │   │   │   loss, outputs = self.compute_loss(model, inputs, return_outputs=  │
│   3467 │   │   │   │   │   loss = loss.mean().detach()                                           │
│   3468 │   │   │   │   │                                                                         │
│   3469 │   │   │   │   │   if isinstance(outputs, dict):                                         │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2767 in compute_loss             │
│                                                                                                  │
│   2764 │   │   │   labels = inputs.pop("labels")                                                 │
│   2765 │   │   else:                                                                             │
│   2766 │   │   │   labels = None                                                                 │
│ ❱ 2767 │   │   outputs = model(**inputs)                                                         │
│   2768 │   │   # Save past state if it exists                                                    │
│   2769 │   │   # TODO: this needs to be fixed and made cleaner later.                            │
│   2770 │   │   if self.args.past_index >= 0:                                                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1684   │
│ in forward                                                                                       │
│                                                                                                  │
│   1681 │   │                                                                                     │
│   1682 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return  │
│   1683 │   │                                                                                     │
│ ❱ 1684 │   │   outputs = self.wav2vec2(                                                          │
│   1685 │   │   │   input_values,                                                                 │
│   1686 │   │   │   attention_mask=attention_mask,                                                │
│   1687 │   │   │   output_attentions=output_attentions,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1320   │
│ in forward                                                                                       │
│                                                                                                  │
│   1317 │   │   │   hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention  │
│   1318 │   │   )                                                                                 │
│   1319 │   │                                                                                     │
│ ❱ 1320 │   │   encoder_outputs = self.encoder(                                                   │
│   1321 │   │   │   hidden_states,                                                                │
│   1322 │   │   │   attention_mask=attention_mask,                                                │
│   1323 │   │   │   output_attentions=output_attentions,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:798 in │
│ forward                                                                                          │
│                                                                                                  │
│    795 │   │   │   │   │   │   attention_mask,                                                   │
│    796 │   │   │   │   │   )                                                                     │
│    797 │   │   │   │   else:                                                                     │
│ ❱  798 │   │   │   │   │   layer_outputs = layer(                                                │
│    799 │   │   │   │   │   │   hidden_states, attention_mask=attention_mask, output_attentions=  │
│    800 │   │   │   │   │   )                                                                     │
│    801 │   │   │   │   hidden_states = layer_outputs[0]                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:679 in │
│ forward                                                                                          │
│                                                                                                  │
│    676 │                                                                                         │
│    677 │   def forward(self, hidden_states, attention_mask=None, output_attentions=False):       │
│    678 │   │   attn_residual = hidden_states                                                     │
│ ❱  679 │   │   hidden_states, attn_weights, _ = self.attention(                                  │
│    680 │   │   │   hidden_states, attention_mask=attention_mask, output_attentions=output_atten  │
│    681 │   │   )                                                                                 │
│    682 │   │   hidden_states = self.dropout(hidden_states)                                       │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:596 in │
│ forward                                                                                          │
│                                                                                                  │
│    593 │   │   │   attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + at  │
│    594 │   │   │   attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)      │
│    595 │   │                                                                                     │
│ ❱  596 │   │   attn_weights = nn.functional.softmax(attn_weights, dim=-1)                        │
│    597 │   │                                                                                     │
│    598 │   │   if layer_head_mask is not None:                                                   │
│    599 │   │   │   if layer_head_mask.size() != (self.num_heads,):                               │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:1843 in softmax                   │
│                                                                                                  │
│   1840 │   if dim is None:                                                                       │
│   1841 │   │   dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)                       │
│   1842 │   if dtype is None:                                                                     │
│ ❱ 1843 │   │   ret = input.softmax(dim)                                                          │
│   1844 │   else:                                                                                 │
│   1845 │   │   ret = input.softmax(dim, dtype=dtype)                                             │
│   1846 │   return ret                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 6.86 GiB (GPU 0; 14.75 GiB total capacity; 5.06 GiB already
allocated; 5.60 GiB free; 8.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try 
setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
PYTORCH_CUDA_ALLOC_CONF
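The traceback shows the crash happens inside `self.evaluate()` (the first evaluation, right after step 500 with eval_steps=500), in the softmax over the attention weights, which line 594 of modeling_wav2vec2.py reshapes to `(bsz * num_heads, tgt_len, src_len)`. That tensor grows linearly with batch size and quadratically with the padded clip length. A rough estimator, assuming 16 kHz audio, the default Wav2Vec2Config convolution layout, the 12 attention heads of wav2vec2-base, and float32 softmax output (which is what autocast produces for softmax when fp16=True); it counts only this one tensor, not the rest of the forward pass:

```python
# Rough size of one attention-probability tensor in wav2vec2-base.
# Assumed: default Wav2Vec2Config conv_kernel / conv_stride, 12 heads,
# float32 softmax output, every clip padded to the same length.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
NUM_HEADS = 12
BYTES_PER_ELEMENT = 4  # float32


def num_frames(num_samples: int) -> int:
    """Frames produced by the convolutional feature encoder for one clip."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


def attention_gib(batch_size: int, seconds: float, sample_rate: int = 16_000) -> float:
    """GiB for a (batch * heads, T, T) attention tensor at the given padded length."""
    t = num_frames(int(seconds * sample_rate))
    return batch_size * NUM_HEADS * t * t * BYTES_PER_ELEMENT / 2**30


for batch in (8, 4, 1):
    print(f"batch={batch}: {attention_gib(batch, 60):.2f} GiB for 60 s clips")
```

Plugging in the duration of the longest clip in common_voice_test gives a sense of how much a single evaluation batch needs. If the number is several GiB, lowering per_device_eval_batch_size or dropping very long clips from the evaluation set (for example with the datasets library's `Dataset.filter`) are the usual levers.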
