PyTorch Fairseq custom model training error: problem running fairseq-train with a simple LSTM architecture

2ic8powd · asked 12 months ago

I am trying to train a custom sequence-to-sequence model with Fairseq's **fairseq-train** command. I implemented my own SimpleLSTM architecture in Google Colab, and although Fairseq appears to pick up the model file, training keeps failing with the error below.

Data preparation commands:

!pip install fairseq

!git clone https://github.com/pytorch/fairseq.git
cd /content/fairseq/examples/translation
!chmod +x /content/fairseq/examples/translation/prepare-iwslt14.sh
!/content/fairseq/examples/translation/prepare-iwslt14.sh
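
For reference, fairseq's IWSLT'14 translation example also binarizes the tokenized output with fairseq-preprocess before training, and fairseq-train is then pointed at the resulting data-bin directory. A minimal sketch; the exact paths are assumptions based on where prepare-iwslt14.sh writes its output:

!fairseq-preprocess --source-lang de --target-lang en \
  --trainpref /content/fairseq/examples/translation/iwslt14.tokenized.de-en/train \
  --validpref /content/fairseq/examples/translation/iwslt14.tokenized.de-en/valid \
  --testpref /content/fairseq/examples/translation/iwslt14.tokenized.de-en/test \
  --destdir /content/data-bin/iwslt14.tokenized.de-en \
  --workers 2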

Training command:

!fairseq-train /content/fairseq/examples/translation/iwslt14.tokenized.de-en \
  --arch=tutorial_simple_lstm \
  --encoder-dropout=0.2 \
  --decoder-dropout=0.2 \
  --optimizer=adam \
  --lr=0.005 \
  --lr-shrink=0.5 \
  --max-tokens=12000

My model:

I have seen in other threads that this error occurs when the model file is not placed inside Fairseq's models folder, but I have already done that as follows:

%%writefile /content/fairseq/fairseq/models/tutorial_simple_lstm.py

import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqEncoder, FairseqDecoder
import torch

class SimpleLSTMEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1):
        super().__init__(dictionary)
        self.args = args
        self.embed_tokens = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())
        self.dropout = nn.Dropout(p=dropout)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1, bidirectional=False, batch_first=True)

    def forward(self, src_tokens, src_lengths):
        if self.args.left_pad_source:
            src_tokens = utils.convert_padding_direction(src_tokens, padding_idx=self.dictionary.pad(), left_to_right=True)
        x = self.embed_tokens(src_tokens)
        x = self.dropout(x)
        x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
        _outputs, (final_hidden, _final_cell) = self.lstm(x)
        return {'final_hidden': final_hidden.squeeze(0)}

    def reorder_encoder_out(self, encoder_out, new_order):
        final_hidden = encoder_out['final_hidden']
        return {'final_hidden': final_hidden.index_select(0, new_order)}

class SimpleLSTMDecoder(FairseqDecoder):
    def __init__(self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128, dropout=0.1):
        super().__init__(dictionary)
        self.embed_tokens = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())
        self.dropout = nn.Dropout(p=dropout)
        self.lstm = nn.LSTM(encoder_hidden_dim + embed_dim, hidden_dim, num_layers=1, bidirectional=False)
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out):
        bsz, tgt_len = prev_output_tokens.size()
        final_encoder_hidden = encoder_out['final_hidden']
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout(x)
        x = torch.cat([x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)], dim=2)
        initial_state = (final_encoder_hidden.unsqueeze(0), torch.zeros_like(final_encoder_hidden).unsqueeze(0))
        output, _ = self.lstm(x.transpose(0, 1), initial_state)
        x = output.transpose(0, 1)
        x = self.output_projection(x)
        return x, None

from fairseq.models import FairseqEncoderDecoderModel, register_model

@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):

    @staticmethod
    def add_args(parser):
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N')
        parser.add_argument('--encoder-hidden-dim', type=int, metavar='N')
        parser.add_argument('--encoder-dropout', type=float, default=0.1)
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N')
        parser.add_argument('--decoder-hidden-dim', type=int, metavar='N')
        parser.add_argument('--decoder-dropout', type=float, default=0.1)

    @classmethod
    def build_model(cls, args, task):
        encoder = SimpleLSTMEncoder(
            args=args,
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_dim=args.encoder_hidden_dim,
            dropout=args.encoder_dropout,
        )
        decoder = SimpleLSTMDecoder(
            dictionary=task.target_dictionary,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            hidden_dim=args.decoder_hidden_dim,
            dropout=args.decoder_dropout,
        )
        model = SimpleLSTMModel(encoder, decoder)
        print(model)
        return model

from fairseq.models import register_model_architecture

@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)

Error:

2023-09-16 11:37:17.444106: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-16 11:37:18.331492: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2023-09-16 11:37:19 | INFO | numexpr.utils | NumExpr defaulting to 2 threads.
2023-09-16 11:37:20 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX
usage: fairseq-train
       [-h]
       [--no-progress-bar]
       [--log-interval LOG_INTERVAL]
       [--log-format {json,none,simple,tqdm}]
       [--log-file LOG_FILE]
       [--aim-repo AIM_REPO]
       [--aim-run-hash AIM_RUN_HASH]
       [--tensorboard-logdir TENSORBOARD_LOGDIR]
       [--wandb-project WANDB_PROJECT]
       [--azureml-logging]
       [--seed SEED]
       [--cpu]
       [--tpu]
       [--bf16]
       [--memory-efficient-bf16]
       [--fp16]
       [--memory-efficient-fp16]
       [--fp16-no-flatten-grads]
       [--fp16-init-scale FP16_INIT_SCALE]
       [--fp16-scale-window FP16_SCALE_WINDOW]
       [--fp16-scale-tolerance FP16_SCALE_TOLERANCE]
       [--on-cpu-convert-precision]
       [--min-loss-scale MIN_LOSS_SCALE]
       [--threshold-loss-scale THRESHOLD_LOSS_SCALE]
       [--amp]
       [--amp-batch-retries AMP_BATCH_RETRIES]
       [--amp-init-scale AMP_INIT_SCALE]
       [--amp-scale-window AMP_SCALE_WINDOW]
       [--user-dir USER_DIR]
       [--empty-cache-freq EMPTY_CACHE_FREQ]
       [--all-gather-list-size ALL_GATHER_LIST_SIZE]
       [--model-parallel-size MODEL_PARALLEL_SIZE]
       [--quantization-config-path QUANTIZATION_CONFIG_PATH]
       [--profile]
       [--reset-logging]
       [--suppress-crashes]
       [--use-plasma-view]
       [--plasma-path PLASMA_PATH]
       [--criterion {adaptive_loss,composite_loss,cross_entropy,ctc,fastspeech2,hubert,label_smoothed_cross_entropy,latency_augmented_label_smoothed_cross_entropy,label_smoothed_cross_entropy_with_alignment,label_smoothed_cross_entropy_with_ctc,legacy_masked_lm_loss,masked_lm,model,nat_loss,sentence_prediction,sentence_prediction_adapters,sentence_ranking,tacotron2,speech_to_unit,speech_to_spectrogram,speech_unit_lm_criterion,wav2vec,vocab_parallel_cross_entropy}]
       [--tokenizer {moses,nltk,space}]
       [--bpe {byte_bpe,bytes,characters,fastbpe,gpt2,bert,hf_byte_bpe,sentencepiece,subword_nmt}]
       [--optimizer {adadelta,adafactor,adagrad,adam,adamax,composite,cpu_adam,lamb,nag,sgd}]
       [--lr-scheduler {cosine,fixed,inverse_sqrt,manual,pass_through,polynomial_decay,reduce_lr_on_plateau,step,tri_stage,triangular}]
       [--scoring {bert_score,sacrebleu,bleu,chrf,meteor,wer}]
       [--task TASK]
       [--num-workers NUM_WORKERS]
       [--skip-invalid-size-inputs-valid-test]
       [--max-tokens MAX_TOKENS]
       [--batch-size BATCH_SIZE]
       [--required-batch-size-multiple REQUIRED_BATCH_SIZE_MULTIPLE]
       [--required-seq-len-multiple REQUIRED_SEQ_LEN_MULTIPLE]
       [--dataset-impl {raw,lazy,cached,mmap,fasta,huffman}]
       [--data-buffer-size DATA_BUFFER_SIZE]
       [--train-subset TRAIN_SUBSET]
       [--valid-subset VALID_SUBSET]
       [--combine-valid-subsets]
       [--ignore-unused-valid-subsets]
       [--validate-interval VALIDATE_INTERVAL]
       [--validate-interval-updates VALIDATE_INTERVAL_UPDATES]
       [--validate-after-updates VALIDATE_AFTER_UPDATES]
       [--fixed-validation-seed FIXED_VALIDATION_SEED]
       [--disable-validation]
       [--max-tokens-valid MAX_TOKENS_VALID]
       [--batch-size-valid BATCH_SIZE_VALID]
       [--max-valid-steps MAX_VALID_STEPS]
       [--curriculum CURRICULUM]
       [--gen-subset GEN_SUBSET]
       [--num-shards NUM_SHARDS]
       [--shard-id SHARD_ID]
       [--grouped-shuffling]
       [--update-epoch-batch-itr UPDATE_EPOCH_BATCH_ITR]
       [--update-ordered-indices-seed]
       [--distributed-world-size DISTRIBUTED_WORLD_SIZE]
       [--distributed-num-procs DISTRIBUTED_NUM_PROCS]
       [--distributed-rank DISTRIBUTED_RANK]
       [--distributed-backend DISTRIBUTED_BACKEND]
       [--distributed-init-method DISTRIBUTED_INIT_METHOD]
       [--distributed-port DISTRIBUTED_PORT]
       [--device-id DEVICE_ID]
       [--distributed-no-spawn]
       [--ddp-backend {c10d,fully_sharded,legacy_ddp,no_c10d,pytorch_ddp,slowmo}]
       [--ddp-comm-hook {none,fp16}]
       [--bucket-cap-mb BUCKET_CAP_MB]
       [--fix-batches-to-gpus]
       [--find-unused-parameters]
       [--gradient-as-bucket-view]
       [--fast-stat-sync]
       [--heartbeat-timeout HEARTBEAT_TIMEOUT]
       [--broadcast-buffers]
       [--slowmo-momentum SLOWMO_MOMENTUM]
       [--slowmo-base-algorithm SLOWMO_BASE_ALGORITHM]
       [--localsgd-frequency LOCALSGD_FREQUENCY]
       [--nprocs-per-node NPROCS_PER_NODE]
       [--pipeline-model-parallel]
       [--pipeline-balance PIPELINE_BALANCE]
       [--pipeline-devices PIPELINE_DEVICES]
       [--pipeline-chunks PIPELINE_CHUNKS]
       [--pipeline-encoder-balance PIPELINE_ENCODER_BALANCE]
       [--pipeline-encoder-devices PIPELINE_ENCODER_DEVICES]
       [--pipeline-decoder-balance PIPELINE_DECODER_BALANCE]
       [--pipeline-decoder-devices PIPELINE_DECODER_DEVICES]
       [--pipeline-checkpoint {always,never,except_last}]
       [--zero-sharding {none,os}]
       [--no-reshard-after-forward]
       [--fp32-reduce-scatter]
       [--cpu-offload]
       [--use-sharded-state]
       [--not-fsdp-flatten-parameters]
       [--arch ARCH]
       [--max-epoch MAX_EPOCH]
       [--max-update MAX_UPDATE]
       [--stop-time-hours STOP_TIME_HOURS]
       [--clip-norm CLIP_NORM]
       [--sentence-avg]
       [--update-freq UPDATE_FREQ]
       [--lr LR]
       [--stop-min-lr STOP_MIN_LR]
       [--use-bmuf]
       [--skip-remainder-batch]
       [--save-dir SAVE_DIR]
       [--restore-file RESTORE_FILE]
       [--continue-once CONTINUE_ONCE]
       [--finetune-from-model FINETUNE_FROM_MODEL]
       [--reset-dataloader]
       [--reset-lr-scheduler]
       [--reset-meters]
       [--reset-optimizer]
       [--optimizer-overrides OPTIMIZER_OVERRIDES]
       [--save-interval SAVE_INTERVAL]
       [--save-interval-updates SAVE_INTERVAL_UPDATES]
       [--keep-interval-updates KEEP_INTERVAL_UPDATES]
       [--keep-interval-updates-pattern KEEP_INTERVAL_UPDATES_PATTERN]
       [--keep-last-epochs KEEP_LAST_EPOCHS]
       [--keep-best-checkpoints KEEP_BEST_CHECKPOINTS]
       [--no-save]
       [--no-epoch-checkpoints]
       [--no-last-checkpoints]
       [--no-save-optimizer-state]
       [--best-checkpoint-metric BEST_CHECKPOINT_METRIC]
       [--maximize-best-checkpoint-metric]
       [--patience PATIENCE]
       [--checkpoint-suffix CHECKPOINT_SUFFIX]
       [--checkpoint-shard-count CHECKPOINT_SHARD_COUNT]
       [--load-checkpoint-on-all-dp-ranks]
       [--write-checkpoints-asynchronously]
       [--store-ema]
       [--ema-decay EMA_DECAY]
       [--ema-start-update EMA_START_UPDATE]
       [--ema-seed-model EMA_SEED_MODEL]
       [--ema-update-freq EMA_UPDATE_FREQ]
       [--ema-fp32]
fairseq-train: error: argument --arch/-a: invalid choice: 'tutorial_simple_lstm' (choose from 's2t_berard', 's2t_berard_256_3_3', 's2t_berard_512_3_2', 's2t_berard_512_5_3', 'transformer_tiny', 'transformer', 'transformer_iwslt_de_en', 'transformer_wmt_en_de', 'transformer_vaswani_wmt_en_de_big', 'transformer_vaswani_wmt_en_fr_big', 'transformer_wmt_en_de_big', 'transformer_wmt_en_de_big_t2t', 'convtransformer', 'convtransformer_espnet', 's2t_transformer', 's2t_transformer_s', 's2t_transformer_xs', 's2t_transformer_sp', 's2t_transformer_m', 's2t_transformer_mp', 's2t_transformer_l', 's2t_transformer_lp', 'wav2vec', 'wav2vec2', 'wav2vec_ctc', 'wav2vec_seq2seq', 'xm_transformer', 's2t_conformer', 'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr', 'tacotron_2', 'tts_transformer', 'fastspeech2', 'lstm', 'lstm_wiseman_iwslt_de_en', 'lstm_luong_wmt_en_de', 'lstm_lm', 'fconv_lm', 'fconv_lm_dauphin_wikitext103', 'fconv_lm_dauphin_gbw', 'hubert', 'hubert_ctc', 'lightconv', 'lightconv_iwslt_de_en', 'lightconv_wmt_en_de', 'lightconv_wmt_en_de_big', 'lightconv_wmt_en_fr_big', 'lightconv_wmt_zh_en_big', 'lightconv_lm', 'lightconv_lm_gbw', 'fconv_self_att', 'fconv_self_att_wp', 'nonautoregressive_transformer', 'nonautoregressive_transformer_wmt_en_de', 'nacrf_transformer', 'iterative_nonautoregressive_transformer', 'iterative_nonautoregressive_transformer_wmt_en_de', 'cmlm_transformer', 'cmlm_transformer_wmt_en_de', 'levenshtein_transformer', 'levenshtein_transformer_wmt_en_de', 'levenshtein_transformer_vaswani_wmt_en_de_big', 'levenshtein_transformer_wmt_en_de_big', 'insertion_transformer', 'transformer_lm', 'transformer_lm_big', 'transformer_lm_baevski_wiki103', 'transformer_lm_wiki103', 'transformer_lm_baevski_gbw', 'transformer_lm_gbw', 'transformer_lm_gpt', 'transformer_lm_gpt2_small', 'transformer_lm_gpt2_tiny', 'transformer_lm_gpt2_medium', 'transformer_lm_gpt2_big', 'transformer_lm_gpt2_big_wide', 'transformer_lm_gpt2_bigger', 'transformer_lm_gpt3_small', 'transformer_lm_gpt3_medium', 'transformer_lm_gpt3_large', 'transformer_lm_gpt3_xl', 'transformer_lm_gpt3_2_7', 'transformer_lm_gpt3_6_7', 'transformer_lm_gpt3_13', 'transformer_lm_gpt3_175', 'transformer_ulm', 'transformer_ulm_big', 'transformer_ulm_tiny', 'roberta', 'roberta_prenorm', 'roberta_base', 'roberta_large', 'xlm', 'roberta_enc_dec', 'transformer_from_pretrained_xlm', 'transformer_align', 'transformer_wmt_en_de_big_align', 'xmod_base_13', 'xmod_base_30', 'xmod_base_60', 'xmod_base_75', 'xmod_base', 'xmod_large_prenorm', 'bart_large', 'bart_base', 'mbart_large', 'mbart_base', 'mbart_base_wmt20', 's2ut_transformer', 's2ut_transformer_fisher', 's2spect_transformer', 's2spect_transformer_fisher', 's2ut_conformer', 'masked_lm', 'bert_base', 'bert_large', 'xlm_base', 'multilingual_transformer', 'multilingual_transformer_iwslt_de_en', 'hf_gpt2', 'hf_gpt2_medium', 'hf_gpt2_large', 'hf_gpt2_xl', 'dummy_model', 'model_parallel_roberta', 'model_parallel_roberta_v1', 'model_parallel_roberta_postnorm', 'model_parallel_roberta_base', 'model_parallel_roberta_large', 'transformer_iwslt_de_en_pipeline_parallel', 'transformer_wmt_en_de_big_pipeline_parallel', 'transformer_lm_megatron', 'transformer_lm_megatron_11b')

I suspect the error comes either from the execution environment (Google Colab), from the installed libraries, or from the model code itself not being implemented correctly.


x4shl7ld1#

Solution:

!git clone https://github.com/pytorch/fairseq.git
%cd fairseq
!pip install --editable .

Install the cloned repository in editable mode. The fairseq installed from PyPI with pip install fairseq and the cloned repository are two separate copies of the code: writing tutorial_simple_lstm.py into /content/fairseq/fairseq/models/ only changes the clone, while fairseq-train keeps importing the pip-installed package, so the new architecture is never registered and --arch=tutorial_simple_lstm is rejected. Installing the clone with pip install --editable . makes fairseq-train load models directly from the repository, including the new model file.
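
Before launching training again, you can confirm the registration from Python. This is a minimal sketch; MODEL_REGISTRY and ARCH_MODEL_REGISTRY are fairseq's internal registry dicts in fairseq.models, and the check assumes the editable install is the copy of fairseq actually being imported:

# Verify that the custom model and architecture were registered.
# Importing fairseq.models triggers the automatic import of every module
# in fairseq/models/, including tutorial_simple_lstm.py.
from fairseq.models import MODEL_REGISTRY, ARCH_MODEL_REGISTRY

print('simple_lstm' in MODEL_REGISTRY)                 # name from @register_model
print('tutorial_simple_lstm' in ARCH_MODEL_REGISTRY)   # name passed to --arch

As an alternative to modifying the fairseq package at all, the --user-dir flag (it appears in the usage output above) lets fairseq-train load custom models from an external module directory instead.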
