论文参考文献

首页 » 常识 » 常识 » 万字详文超越BERT模型的ELECTRA
TUhjnbcbe - 2023/5/29 21:44:00

1、概述

在年11月份,NLP大神Manning联合谷歌做的ELECTRA一经发布,迅速火爆整个NLP圈,其中ELECTRA-small模型参数量仅为BERT-base模型的1/10,性能却依然能与BERT、RoBERTa等模型相媲美,得益于ELECTRA模型的巧妙构思LOSS,在年3月份Google对代码做了开源,下面针对Google放出的ELECTRA做代码做解读,希望通过此文章大家能在自己文本数据、行为序列数据训练一个较好的预训练模型,在业务上提升价值。

2、ELECTRA模型

2、1总体框架

ELECTRA模型(BASE版本)本质是换一种方法来训练BERT模型的参数;BERT模型主要是利用MLM的思想来训练参数,直接把需要预测的词给挖掉了,挖了15%的比例。由于每次训练是一段话中15%的token,导致模型收敛更新较慢,需要的语料也比较庞大。同时为了兼顾处理阅读理解这样的任务,模型加入了NSP,是个二分类任务,判断上下两句是不是互为上下句;而ELECTRA模型主要借助于图像领域gan的思想,利用生成器和判别器思想,如下图所;ELECTRA的预训练可以分为两部分,生成器部分仍然是MLM,结构与BERT类似,利用这个模型对挖掉的15%的词进行预测,并将其进行替换,若替换的词不是原词,则打上被替换的标签,语句的其他词则打上没有替换的标签,判别器部分是训练一个判别模型对所有位置的词进行替换识别,此时预测模型转换成了一个二分类模型。这个转换可以带来效率的提升,对所有位置的词进行预测,收敛速度会快的多,损失函数是利用生成器部分的损失和判别器的损失函数以一个比例数(官方代码是50)相加。

2、2代码框架

年3月份Google开源了ELECTRA模型代码,见代码连接,其主要代码框架如下:

下面对该代码框架做一一说明:

finetune:该文件夹下面的代码主要是对已经训练好的ELECTRA模型做微调的代码例子,如文本分类、NER识别、阅读理解等任务,这个和BERT的任务一致,在这里不做过多累赘。

model/modeling.py:该文件主要是bert模型的实现逻辑以及bert模型的配置读取代码,在ELECTRA模型的预训练阶段生成和判别阶段做调用,另外在做各种finetuneing任务会调用。

model/optimization.py:该文件主要是对优化器的实现,主要是对AdamWeightDecay的实现,可以自己加lamb等优化方法的实现。

model/tokenization.py:该文件主要是WordPiece分词器的实现,可以对英文、中文分词,在将文本转化为tfrecord的时候会用的到。

pretrain/pretrain_data.py:该文件的主要作用是对ELECTRA模型在pretraining对tfrecords文件读取、collections.namedtuple更新上的一些逻辑实现。

pretrain_helpers.py:该文件是pretraining阶段核心功能实现,主要实现逻辑对序列做动态随机mask以及对已经mask的序列做unmask。

util/training_utils.py:该文件主要在训练阶段实现了一个Hook,在训练阶段为了打印更多日志信息。

util/utils.py:该文件主要是一些基础方法,如增删文件夹、序列化反序列化、读取配置文件等。

build_openwebtext_pretraining_dataset.py、build_pretraining_dataset.py:这两个文件功能类似,但是用到的数据源不一样,主要是把文本文件转化为tfrecord文件,tfrecord文件的key包括input_ids、input_mask、segment_ids,生成的tfrecord文件不像bert预训练需要的文件那样,不需要再生成masked_lm_positions,masked_lm_ids,masked_lm_weights,这几个key会在模型pretraining阶段自动生成,与此同时mask也是随机动态的,类似于RoBerta,不像BERT那样固定。里面写了个多线程加速;对于大型的文件,还是使用spark将文本转化较适宜。

configure_finetuning.py:finetuneing阶段的一些超参数配置,google这次放出的代码参数并没有使用tf.flags。

configure_pretraining.py:pretraining阶段的一些超参数配置,google这次放出的代码参数并没有使用tf.flags。

run_finetuning.py:模型做finetuning逻辑,加载已经训练好的ELECTRA模型做微调。

run_pretraining.py:模型pretraining逻辑,也是ELECTRA模型最核心的逻辑,下面会加以详细说明。

2、3pretraining阶段

ELECTRA模型pretraining阶段是最核心的逻辑,代码是在run_pretraining.py里面,下面会加以详细说明,整体阶段理解绘制了一张图,逻辑见如下图:

2.3.1主方法入口

在run_pretraining.py文件中主方法有三个必须的参数见:

defmain():parser=argparse.ArgumentParser(description=__doc__)parser.add_argument(--data-dir,required=True,help=Locationofdatafiles(modelweights,etc).)parser.add_argument(--model-name,required=True,help=Thenameofthemodelbeingfine-tuned.)parser.add_argument(--hparams,default={},help=JSONdictofmodelhyperparameters.)args=parser.parse_args()ifargs.hparams.endswith(.json):hparams=utils.load_json(args.hparams)else:hparams=json.loads(args.hparams)tf.logging.set_verbosity(tf.logging.ERROR)train_or_eval(configure_pretraining.PretrainingConfig(args.model_name,args.data_dir,**hparams))三个必须参数见如下:--data-dir:表示tfrecord文件的地址,一般是以pretrain_data.tfrecord-0-of*这种格式,调用build_pretraining_dataset.py文件生成,默认生成个tfrecord文件,数目可以自己改,此外需要注意的是切词需要制定vocab.txt,训练中文的模型词典指定BERT模型那个vocab.txt词典即可,同理用于英文的模型训练。

--model-name:表示预训练模型的名字一般是electar,可以自己设定。

--hparams:,一般是个json文件,可以传递自己的参数进去,比如你要训练的模型是small、base、big等模型,还有vocab_size,一般中文是,英文是。还有模型训练是否是测试状态等参数,一般我训练中文模型hparams参数是的config.json是:

{model_size:base,vocab_size:}详细的参数可以去看configure_pretraining.py,一般你传进去的参数进去会更新里面的超参数。

程序入口训练模型:

train_or_eval(configure_pretraining.PretrainingConfig(args.model_name,args.data_dir,**hparams))还有一个入口,只see.run()一次,用于测试,见如下:

train_one_step(configure_pretraining.PretrainingConfig(args.model_name,args.data_dir,**hparams))2.3.2数据mask

训练模型主要是代码是PretrainingModel类的定义,在PretrainingModel里面程序首先对输入的tfrecord文件做随机mask,

#Masktheinputmasked_inputs=pretrain_helpers.mask(config,pretrain_data.features_to_inputs(features),config.mask_prob)用于生成含有masked_lm_positions,masked_lm_ids,masked_lm_weights等key的tfrecord文件,随机MASK实现的主要逻辑是调用pretrain_helpers.mask()来实现,其中用到了随机生成多项分布的函数tf.random.categorical,这个函数目的是随机获取masked_lm_positions、masked_lm_weights,再根据masked_lm_positions调用tf.gather_nd做索引截取来获取masked_lm_ids。

2.3.3GeneratorBERT

数据获取之后往下一步走就是生成GeneratorBERT阶段的模型,调用方法见如下:

generator=self._build_transformer(masked_inputs,is_training,bert_config=get_generator_config(config,self._bert_config),embedding_size=(Noneifconfig.untied_generator_embeddingselseembedding_size),untied_embeddings=config.untied_generator_embeddings,name=generator)这里主要用于Generator阶段BERT模型生成,同时生成MLMloss和Fakedata,其中Fakedata非常核心。MLMloss生成见代码,和BERT的逻辑几乎一样:

def_get_masked_lm_output(self,inputs:pretrain_data.Inputs,model):Maskedlanguagemodelingsoftmaxlayer.masked_lm_weights=inputs.masked_lm_weightswithtf.variable_scope(generator_predictions):ifself._config.uniform_generator:logits=tf.zeros(self._bert_config.vocab_size)logits_tiled=tf.zeros(modeling.get_shape_list(inputs.masked_lm_ids)+[self._bert_config.vocab_size])logits_tiled+=tf.reshape(logits,[1,1,self._bert_config.vocab_size])logits=logits_tiledelse:relevant_hidden=pretrain_helpers.gather_positions(model.get_sequence_output(),inputs.masked_lm_positions)hidden=tf.layers.dense(relevant_hidden,units=modeling.get_shape_list(model.get_embedding_table())[-1],activation=modeling.get_activation(self._bert_config.hidden_act),kernel_initializer=modeling.create_initializer(self._bert_config.initializer_range))hidden=modeling.layer_norm(hidden)output_bias=tf.get_variable(output_bias,shape=[self._bert_config.vocab_size],initializer=tf.zeros_initializer())logits=tf.matmul(hidden,model.get_embedding_table(),transpose_b=True)logits=tf.nn.bias_add(logits,output_bias)oh_labels=tf.one_hot(inputs.masked_lm_ids,depth=self._bert_config.vocab_size,dtype=tf.float32)probs=tf.nn.softmax(logits)log_probs=tf.nn.log_softmax(logits)label_log_probs=-tf.reduce_sum(log_probs*oh_labels,axis=-1)numerator=tf.reduce_sum(inputs.masked_lm_weights*label_log_probs)denominator=tf.reduce_sum(masked_lm_weights)+1e-6loss=numerator/denominatorpreds=tf.argmax(log_probs,axis=-1,output_type=tf.int32)MLMOutput=collections.namedtuple(MLMOutput,[logits,probs,loss,per_example_loss,preds])returnMLMOutput(logits=logits,probs=probs,per_example_loss=label_log_probs,loss=loss,preds=preds)Fakedata数据生成逻辑见下面代码,这里调用了unmask函数和上面提到的mask函数作用相反,把原来input_ids随机mask的函数还原回去生成一个input_ids_new,再利用谷生成模型生成的logit取最大索引去还原原来被mask调的input_ids,生成一个updated_input_ids,判断input_ids_new和updated_input_ids是否相等,生成truelabel

def_get_fake_data(self,inputs,mlm_logits):Samplefromthegeneratortocreatecorruptedinput.inputs=pretrain_helpers.unmask(inputs)disallow=tf.one_hot(inputs.masked_lm_ids,depth=self._bert_config.vocab_size,dtype=tf.float32)ifself._config.disallow_correctelseNonesampled_tokens=tf.stop_gradient(pretrain_helpers.sample_from_softmax(mlm_logits/self._config.temperature,disallow=disallow))sampled_tokids=tf.argmax(sampled_tokens,-1,output_type=tf.int32)updated_input_ids,masked=pretrain_helpers.scatter_update(inputs.input_ids,sampled_tokids,inputs.masked_lm_positions)labels=masked*(1-tf.cast(tf.equal(updated_input_ids,inputs.input_ids),tf.int32))updated_inputs=pretrain_data.get_updated_inputs(inputs,input_ids=updated_input_ids)FakedData=collections.namedtuple(FakedData,[inputs,is_fake_tokens,sampled_tokens])returnFakedData(inputs=updated_inputs,is_fake_tokens=labels,sampled_tokens=sampled_tokens)2.3.4DiscriminaBERT

利用上一步生成的Fakedata,作为DiscriminaBERT的输入,见代码:

ifconfig.electra_objective:discriminator=self._build_transformer(fake_data.inputs,is_training,reuse=notconfig.untied_generator,embedding_size=embedding_size)disc_output=self._get_discriminator_output(fake_data.inputs,discriminator,fake_data.is_fake_tokens)获取二分类的损失函数,代码见:

def_get_discriminator_output(self,inputs,discriminator,labels):Discriminatorbinaryclassifier.withtf.variable_scope(discriminator_predictions):hidden=tf.layers.dense(discriminator.get_sequence_output(),units=self._bert_config.hidden_size,activation=modeling.get_activation(self._bert_config.hidden_act),kernel_initializer=modeling.create_initializer(self._bert_config.initializer_range))logits=tf.squeeze(tf.layers.dense(hidden,units=1),-1)weights=tf.cast(inputs.input_mask,tf.float32)labelsf=tf.cast(labels,tf.float32)losses=tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,labels=labelsf)*weightsper_example_loss=(tf.reduce_sum(losses,axis=-1)/(1e-6+tf.reduce_sum(weights,axis=-1)))loss=tf.reduce_sum(losses)/(1e-6+tf.reduce_sum(weights))probs=tf.nn.sigmoid(logits)preds=tf.cast(tf.round((tf.sign(logits)+1)/2),tf.int32)DiscOutput=collections.namedtuple(DiscOutput,[loss,per_example_loss,probs,preds,labels])2.3.5总的损失函数

上面一步骤求出了disc_output.loss也就是sigmod的loss,代码见:

self.total_loss=config.gen_weight*mlm_output.lossself.total_loss+=config.disc_weight*disc_output.loss这里config.gen_weight=1以及config.disc_weight=50,这里sigmod的损失函数设置为50,作者也没给明确的答复。

2.3.6模型优化以及checkpoint

上面一步已经求出了总的损失函数,下一步则是做模型优化训练以及做checkpoint,程序入口在train_or_eval()里面,代码见:

is_per_host=tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2tpu_cluster_resolver=Noneifconfig.use_tpuandconfig.tpu_name:tpu_cluster_resolver=tf.distribute.cluster_resolver.TPUClusterResolver(config.tpu_name,zone=config.tpu_zone,project=config.gcp_project)tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=config.iterations_per_loop,num_shards=(config.num_tpu_coresifconfig.do_trainelseconfig.num_tpu_cores),tpu_job_name=config.tpu_job_name,per_host_input_for_training=is_per_host)run_config=tf.estimator.tpu.RunConfig(cluster=tpu_cluster_resolver,model_dir=config.model_dir,save_checkpoints_steps=config.save_checkpoints_steps,tpu_config=tpu_config)model_fn=model_fn_builder(config=config)estimator=tf.estimator.tpu.TPUEstimator(use_tpu=config.use_tpu,model_fn=model_fn,config=run_config,train_batch_size=config.train_batch_size,eval_batch_size=config.eval_batch_size)可以看到google开源的代码主要是TPU的一些钩子,改成GPU也比较简单,在BERT里面就有GPU相关的钩子,下面就会讲到。

2、4finetuning阶段

ELECTRAfinetuning阶段给出了不少例子,也比较简单,在finetuneing文件下面,这里不做过多的说明,和bert类似,唯一要改的就是把TPU相关的设置改为GPU、CPU即可。

2、5序列训练改进

上面代码主要存在两个问题,第一个是TPU设置的问题,并不是人人都是土豪,还是要适配GPU的训练,第二个就是假如我想训练一个vocab比较大的序列模型,上面模型是训练不动的,loss方面改为负采样的形式。

2.5.1TPU改GPU训练

global_step=tf.train.get_or_create_global_step()optimizer=optimization.AdamWeightDecayOptimizer(learning_rate=learning_rate)train_op=optimizer.apply_gradients(zip(grads,tvars),global_step)update_global_step=tf.assign(global_step,global_step+1,name=update_global_step)output_spec=tf.estimator.EstimatorSpec(mode=mode,predictions=probabilities,loss=total_loss,train_op=tf.group(train_op,update_global_step))run_config=tf.estimator.RunConfig(model_dir=FLAGS.modelpath,save_checkpoints_steps=)bert_config=modeling.BertConfig.from_json_file(FLAGS.bert_config_file)model_fn=model_fn_builder(bert_config=bert_config,num_labels=FLAGS.n_class,is_training=True,init_checkpoint=FLAGS.init_checkpoint,learning_rate=FLAGS.learning_rate,use_one_hot_embeddings=False)estimator=tf.estimator.Estimator(model_fn=model_fn,config=run_config)total_files=glob.glob(/tf*)random.shuffle(total_files)eval_files=total_files.pop()input_fn_train=lambda:input_fn(total_files,FLAGS.batch_size,num_epochs=N)input_fn_eval=lambda:input_fn(eval_files,FLAGS.batch_size,is_training=False)train_spec=tf.estimator.TrainSpec(input_fn=input_fn_train,max_steps=0)eval_spec=tf.estimator.EvalSpec(input_fn=input_fn_eval,steps=None,start_delay_secs=30,throttle_secs=30)tf.estimator.train_and_evaluate(estimator,train_spec,eval_spec)2.5.2负采样改造

主要是对mlmloss做改造:

defget_masked_lm_output(bert_config,input_tensor,output_weights,positions,label_ids,label_weights):withtf.variable_scope(cls/predictions):withtf.variable_scope(transform):input_tensor=tf.layers.dense(input_tensor,units=bert_config.hidden_size,activation=modeling.get_activation(bert_config.hidden_act),kernel_initializer=modeling.create_initializer(bert_config.initializer_range))input_tensor=modeling.layer_norm(input_tensor)#batch*10*embeding_size#Theoutputweightsarethesameastheinputembeddings,butthereis#anoutput-onlybiasforeachtoken.output_bias=tf.get_variable(output_bias,shape=[bert_config.vocab_size],initializer=tf.zeros_initializer())label_ids=tf.reshape(label_ids,[-1,1])label_weights=tf.reshape(label_weights,[-1])per_example_loss=tf.nn.sampled_softmax_loss(weights=output_weights,biases=output_bias,labels=label_ids,inputs=input_tensor,num_sampled=N,num_classes=bert_config.vocab_size)numerator=tf.reduce_sum(label_weights*per_example_loss)denominator=tf.reduce_sum(label_weights)+1e-5loss=numerator/denominatorreturn(loss,per_example_loss)3、总结

前段时间挺忙,有比较多的新idea出来没来的及看,上周末花了一天时间看了下electra源码,并做记录,也看到不少团队做了一些中文electra预训练模型,虽然electra没有达到stateoftheart,和roberta差距可以忽视,但是这种训练方式这是一个很棒的idea,其收敛速度是其他以bert为基础为改造的模型不能比的,在序列建模就有非常重大的研究意义,欢迎一起交流。

4、参考文献

[1]

1
查看完整版本: 万字详文超越BERT模型的ELECTRA