Fine-Tuning Approach
There are multiple approaches to fine-tuning BERT for a target task.
- Further pre-training the base BERT model
- Custom classification layer(s) on top of a trainable base BERT model
- Custom classification layer(s) on top of a non-trainable (frozen) base BERT model
Note that the base BERT model has been pre-trained on only two tasks, as described in the original paper.
3.1 Pre-training BERT ...we pre-train BERT using two unsupervised tasks
- Task #1: Masked LM
- Task #2: Next Sentence Prediction (NSP)
Hence, the base BERT model is only half-baked, and it can be fully baked for the target domain (1st approach). Alternatively, we can use it as part of a custom model, with the base either trainable (2nd approach) or non-trainable (3rd approach).
1st approach
How to Fine-Tune BERT for Text Classification? demonstrated the 1st approach of further pre-training and pointed out that the learning rate is key to avoiding catastrophic forgetting, where the pre-trained knowledge is erased while new knowledge is being learned.
We find that a lower learning rate, such as 2e-5, is necessary to make BERT overcome the catastrophic forgetting problem. With an aggressive learn rate of 4e-4, the training set fails to converge.
This is probably why the BERT paper used 5e-5, 4e-5, 3e-5, and 2e-5 for fine-tuning.
We use a batch size of 32 and fine-tune for 3 epochs over the data for all GLUE tasks. For each task, we selected the best fine-tuning learning rate (among 5e-5, 4e-5, 3e-5, and 2e-5) on the Dev set
Note that the pre-training of the base model itself used a higher learning rate.
The model was trained on 4 cloud TPUs in Pod configuration (16 TPU chips total) for one million steps with a batch size of 256. The sequence length was limited to 128 tokens for 90% of the steps and 512 for the remaining 10%. The optimizer used is Adam with a learning rate of 1e-4, β1 = 0.9 and β2 = 0.999, a weight decay of 0.01, learning rate warmup for 10,000 steps and linear decay of the learning rate after.
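To make the quoted schedule concrete, here is a minimal plain-Python sketch of linear warmup followed by linear decay. It assumes the rate climbs from 0 to the 1e-4 peak over the first 10,000 steps and then decays linearly to exactly 0 at step 1,000,000 (the quote does not state the end point). `pretraining_lr` is a hypothetical helper, not part of any library.

```python
def pretraining_lr(step: int,
                   peak_lr: float = 1e-4,
                   warmup_steps: int = 10_000,
                   total_steps: int = 1_000_000) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay.

    Assumption: the rate reaches exactly 0 at `total_steps` and stays there.
    """
    if step < warmup_steps:
        # Linear warmup from 0 up to peak_lr
        return peak_lr * step / warmup_steps
    # Linear decay from peak_lr down to 0, clamped at 0 after total_steps
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / (total_steps - warmup_steps)


for s in (0, 5_000, 10_000, 505_000, 1_000_000):
    print(s, pretraining_lr(s))
```

Halfway through the decay phase (step 505,000) the rate is back at 5e-5, the top of the fine-tuning range quoted above.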
The 1st approach is described as part of the 3rd approach below.
FYI: TFDistilBertModel is the bare base model; its main layer is named distilbert.
Model: "tf_distil_bert_model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
distilbert (TFDistilBertMain multiple 66362880
=================================================================
Total params: 66,362,880
Trainable params: 66,362,880
Non-trainable params: 0
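The 66,362,880 parameters in this summary can be reproduced by hand. This sketch assumes the standard distilbert-base-uncased configuration (30,522-token vocabulary, 512 positions, hidden size 768, feed-forward size 3,072, 6 Transformer layers, no token-type embeddings and no pooler). The helper names are illustrative only.

```python
VOCAB_SIZE = 30_522     # WordPiece vocabulary of the uncased checkpoint
MAX_POSITIONS = 512     # learned position embeddings
DIM = 768               # hidden size
HIDDEN_DIM = 3_072      # feed-forward inner size
N_LAYERS = 6            # DistilBERT keeps half of BERT-Base's 12 layers


def dense(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def layer_norm(dim: int) -> int:
    """Scale (gamma) and shift (beta) vectors."""
    return 2 * dim


embeddings = VOCAB_SIZE * DIM + MAX_POSITIONS * DIM + layer_norm(DIM)
per_layer = (
    4 * dense(DIM, DIM)          # query, key, value and output projections
    + dense(DIM, HIDDEN_DIM)     # feed-forward expansion
    + dense(HIDDEN_DIM, DIM)     # feed-forward projection back to DIM
    + 2 * layer_norm(DIM)        # post-attention and post-FFN layer norms
)
total = embeddings + N_LAYERS * per_layer
print(f"{total:,}")  # 66,362,880
```

About a third of the parameters (23.4M) sit in the token embedding table alone.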
2nd approach
Huggingface takes the 2nd approach, as in Fine-tuning with native PyTorch/TensorFlow, where TFDistilBertForSequenceClassification adds the custom classification layers pre_classifier and classifier on top of the trainable base distilbert model. The small learning rate requirement applies here as well, to avoid catastrophic forgetting.
import tensorflow as tf
from transformers import TFDistilBertForSequenceClassification
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer, loss=model.compute_loss) # can also use any keras loss fn
# The dataset is already batched, so batch_size is not passed to fit()
model.fit(train_dataset.shuffle(1000).batch(16), epochs=3)
Model: "tf_distil_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
distilbert (TFDistilBertMain multiple 66362880
_________________________________________________________________
pre_classifier (Dense) multiple 590592
_________________________________________________________________
classifier (Dense) multiple 1538
_________________________________________________________________
dropout_59 (Dropout) multiple 0
=================================================================
Total params: 66,955,010
Trainable params: 66,955,010 <--- All parameters are trainable
Non-trainable params: 0
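The 592,130 extra parameters compared with the bare model come from the two dense layers of the head. A quick check, assuming pre_classifier maps 768 → 768 and classifier maps 768 → 2 (two labels, which is what the 1,538 count implies):

```python
BASE_PARAMS = 66_362_880  # distilbert main layer, from the summary above
DIM = 768                 # hidden size of distilbert-base-uncased
NUM_LABELS = 2            # assumption: binary classification


def dense(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


pre_classifier = dense(DIM, DIM)      # 590,592
classifier = dense(DIM, NUM_LABELS)   # 1,538
total = BASE_PARAMS + pre_classifier + classifier
print(f"{total:,}")  # 66,955,010, all trainable in the 2nd approach
```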
Implementation of the 2nd approach
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import (
    DistilBertTokenizerFast,
    TFDistilBertForSequenceClassification,
)
DATA_COLUMN = 'text'
LABEL_COLUMN = 'category_index'
MAX_SEQUENCE_LENGTH = 512
LEARNING_RATE = 5e-5
BATCH_SIZE = 16
NUM_EPOCHS = 3
# --------------------------------------------------------------------------------
# Tokenizer
# --------------------------------------------------------------------------------
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
def tokenize(sentences, max_length=MAX_SEQUENCE_LENGTH, padding='max_length'):
    """Tokenize using the Huggingface tokenizer
    Args:
        sentences: String or list of strings to tokenize
        max_length: Length to pad/truncate each sequence to
        padding: Padding method ['do_not_pad'|'longest'|'max_length']
    """
    return tokenizer(
        sentences,
        truncation=True,
        padding=padding,
        max_length=max_length,
        return_tensors="tf"
    )
# --------------------------------------------------------------------------------
# Load data
# --------------------------------------------------------------------------------
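raw_train = pd.read_csv("./train.csv")
# Number of distinct classes, used to size the classification head
NUM_LABELS = raw_train[LABEL_COLUMN].nunique()
train_data, validation_data, train_label, validation_label = train_test_split(
    raw_train[DATA_COLUMN].tolist(),
    raw_train[LABEL_COLUMN].tolist(),
    test_size=.2,
    shuffle=True
)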
# --------------------------------------------------------------------------------
# Prepare TF dataset
# --------------------------------------------------------------------------------
train_dataset = tf.data.Dataset.from_tensor_slices((
dict(tokenize(train_data)), # Convert BatchEncoding instance to dictionary
train_label
)).shuffle(1000).batch(BATCH_SIZE).prefetch(1)
validation_dataset = tf.data.Dataset.from_tensor_slices((
dict(tokenize(validation_data)),
validation_label
)).batch(BATCH_SIZE).prefetch(1)
# --------------------------------------------------------------------------------
# training
# --------------------------------------------------------------------------------
model = TFDistilBertForSequenceClassification.from_pretrained(
'distilbert-base-uncased',
num_labels=NUM_LABELS
)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(
    x=train_dataset,
    y=None,  # labels are already inside train_dataset
    validation_data=validation_dataset,
    epochs=NUM_EPOCHS,  # batch size comes from .batch(BATCH_SIZE) on the datasets
)
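The 2nd approach compiles with SparseCategoricalCrossentropy(from_logits=True) because the classifier layer of TFDistilBertForSequenceClassification outputs raw logits, whereas the 3rd approach below ends in a softmax layer and therefore uses from_logits=False. The plain-Python sketch below (hypothetical helper names, no TensorFlow) shows that both routes give the same loss: taking the negative log of the softmax probability of the true class equals the log-sum-exp form computed directly from the logits.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert logits to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ce_from_probs(probs: List[float], label: int) -> float:
    """Sparse categorical cross-entropy when the model already outputs probabilities."""
    return -math.log(probs[label])


def ce_from_logits(logits: List[float], label: int) -> float:
    """Same loss computed directly from logits: logsumexp(z) - z[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


logits = [2.0, -1.0]  # made-up classifier output for one example
label = 0
print(ce_from_logits(logits, label), ce_from_probs(softmax(logits), label))
```

Mixing the two (logits into a loss set to from_logits=False, or probabilities into one set to from_logits=True) gives a different loss than intended, so the flag must match the last layer.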
3rd approach
Basics
Please note that the images are taken from A Visual Guide to Using BERT for the First Time and modified.
Tokenizer
The tokenizer returns an instance of BatchEncoding, which can be used like a Python dictionary and passed as input to the BERT model.
Holds the output of the encode_plus() and batch_encode() methods (tokens, attention_masks, etc).
This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes utility methods to map from word/character space to token space.
Parameters
- data (dict) – Dictionary of lists/arrays/tensors returned by the encode/batch_encode methods (‘input_ids’, ‘attention_mask’, etc.).
The data attribute of the class holds the generated tokens, with input_ids and attention_mask elements.
input_ids
The input ids are often the only required parameters to be passed to the model as input. They are token indices, numerical representations of tokens building the sequences that will be used as input by the model.
attention_mask
This argument indicates to the model which tokens should be attended to, and which should not.
If the attention_mask is 0, the token id is ignored. For instance, if a sequence is padded to reach the required sequence length, the padding tokens should be ignored, hence their attention_mask values are 0.
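A minimal plain-Python sketch of how padding to a fixed length produces the attention_mask. It assumes the [PAD] id is 0, as in the bert-base-uncased vocabulary that distilbert-base-uncased shares; the token ids are illustrative, and `pad_and_mask` is a hypothetical helper, not the Huggingface tokenizer (which also handles truncation).

```python
from typing import Dict, List

PAD_ID = 0  # assumed id of [PAD] in the uncased BERT vocabulary


def pad_and_mask(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Pad `token_ids` to `max_length` and build the matching attention mask.

    Real tokens get mask 1 (attended to); padding positions get mask 0 (ignored).
    """
    if len(token_ids) > max_length:
        raise ValueError("sequence longer than max_length; truncate it first")
    n_pad = max_length - len(token_ids)
    return {
        "input_ids": list(token_ids) + [PAD_ID] * n_pad,
        "attention_mask": [1] * len(token_ids) + [0] * n_pad,
    }


encoded = pad_and_mask([101, 7592, 2088, 102], max_length=6)  # illustrative ids
print(encoded)
# {'input_ids': [101, 7592, 2088, 102, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0]}
```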
Special Tokens
BertTokenizer adds special tokens, enclosing a sequence with [CLS] and [SEP]. [CLS] stands for classification, and [SEP] separates sequences. For question answering or paraphrase tasks, [SEP] separates the two sentences to compare.
BertTokenizer
- cls_token (str, optional, defaults to "[CLS]")
The Classifier Token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.
- sep_token (str, optional, defaults to "[SEP]")
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens.
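How a sentence pair is packed with [CLS] and [SEP] can be sketched on already-split tokens. `pack` is a hypothetical helper; besides the token sequence it returns the segment ids (0 for sentence A, 1 for sentence B) that BERT uses for its learned sentence embedding. DistilBERT drops these token-type embeddings, so its tokenizer returns only input_ids and attention_mask.

```python
from typing import List, Optional, Tuple

CLS, SEP = "[CLS]", "[SEP]"


def pack(tokens_a: List[str],
         tokens_b: Optional[List[str]] = None) -> Tuple[List[str], List[int]]:
    """Return [CLS] A [SEP] (B [SEP]) and the matching segment ids."""
    tokens = [CLS] + tokens_a + [SEP]
    segments = [0] * len(tokens)
    if tokens_b:
        tokens += tokens_b + [SEP]
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments


tokens, segments = pack(["how", "old", "are", "you"], ["i", "am", "six"])
print(tokens)    # ['[CLS]', 'how', 'old', 'are', 'you', '[SEP]', 'i', 'am', 'six', '[SEP]']
print(segments)  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
```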
A Visual Guide to Using BERT for the First Time shows the tokenization.
[CLS]
The embedding vector for [CLS] in the output of the base model's final layer is the aggregate sequence representation learned by the base model. Hence, feed the [CLS] embedding vector into the classification layer added on top of the base model.
The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks. Sentence pairs are packed together into a single sequence. We differentiate the sentences in two ways. First, we separate them with a special token ([SEP]). Second, we add a learned embedding to every token indicating whether it belongs to sentence A or sentence B.
The model structure is illustrated below.
Vector size
In the model distilbert-base-uncased, each token is embedded into a vector of size 768. The shape of the output from the base model is (batch_size, max_sequence_length, embedding_vector_size=768). This matches the BERT/BASE model in the BERT paper (as "base" in distilbert-base-uncased indicates).
BERT/BASE (L=12, H=768, A=12, Total Parameters=110M) and BERT/LARGE (L=24, H=1024, A=16, Total Parameters=340M).
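Selecting the [CLS] vector is just indexing position 0 along the sequence axis, which turns a (batch_size, max_sequence_length, 768) output into a (batch_size, 768) matrix; this is what last_hidden_state[:, 0, :] does in the implementation. A plain-Python sketch with nested lists standing in for the tensor (the sequence length is shrunk and the values are made up; `cls_vectors` and `shape` are hypothetical helpers):

```python
from typing import List

Matrix = List[List[float]]


def cls_vectors(last_hidden_state: List[Matrix]) -> Matrix:
    """Equivalent of last_hidden_state[:, 0, :]: the first token's vector for every example."""
    return [sequence[0] for sequence in last_hidden_state]


def shape(x) -> tuple:
    """Shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


batch_size, seq_len, hidden = 2, 4, 768
# Made-up hidden states: each value encodes (example, position) for easy checking
hidden_states = [[[float(b * 10 + t)] * hidden for t in range(seq_len)]
                 for b in range(batch_size)]
cls = cls_vectors(hidden_states)
print(shape(hidden_states), "->", shape(cls))  # (2, 4, 768) -> (2, 768)
```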
Base Model - TFDistilBertModel
Use the TFDistilBertModel class to instantiate the base DistilBERT model without any specific head on top (as opposed to other classes such as TFDistilBertForSequenceClassification that do have an added classification head).
We do not want any task-specific head attached because we simply want the pre-trained weights of the base model to provide a general understanding of the English language, and it will be our job to add our own classification head during the fine-tuning process in order to help the model distinguish toxic from non-toxic comments.
TFDistilBertModel generates an instance of TFBaseModelOutput, whose last_hidden_state parameter is the output of the model's last layer.
TFBaseModelOutput([(
    'last_hidden_state',
    <tf.Tensor: shape=(batch_size, sequence_length, 768), dtype=float32, numpy=array([[[...]]], dtype=float32)>
)])
Parameters
- last_hidden_state (tf.Tensor of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.
Implementation
Python modules
import datetime
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import (
    DistilBertTokenizerFast,
    TFDistilBertModel,
)
Configuration
TIMESTAMP = datetime.datetime.now().strftime("%Y%b%d%H%M").upper()
DATA_COLUMN = 'text'
LABEL_COLUMN = 'category_index'
MAX_SEQUENCE_LENGTH = 512  # Max length allowed for BERT is 512.
# The training data is needed here to count the labels
raw_train = pd.read_csv("./train.csv")
NUM_LABELS = len(raw_train[LABEL_COLUMN].unique())
MODEL_NAME = 'distilbert-base-uncased'
NUM_BASE_MODEL_OUTPUT = 768
# Flag to freeze base model
FREEZE_BASE = True
# Flag to add custom classification heads
USE_CUSTOM_HEAD = True
if not USE_CUSTOM_HEAD:
    # Make the base trainable when no classification head exists.
    FREEZE_BASE = False
BATCH_SIZE = 16
NUM_EPOCHS = 3
LEARNING_RATE = 1e-2 if FREEZE_BASE else 5e-5
L2 = 0.01
Tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
def tokenize(sentences, max_length=MAX_SEQUENCE_LENGTH, padding='max_length'):
    """Tokenize using the Huggingface tokenizer
    Args:
        sentences: String or list of strings to tokenize
        max_length: Length to pad/truncate each sequence to
        padding: Padding method ['do_not_pad'|'longest'|'max_length']
    """
    return tokenizer(
        sentences,
        truncation=True,
        padding=padding,
        max_length=max_length,
        return_tensors="tf"
    )
Input layer
The base model expects input_ids and attention_mask, each of shape (max_sequence_length,). Generate a Keras tensor for each of them with an Input layer.
# Inputs for token indices and attention masks
input_ids = tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int32, name='input_ids')
attention_mask = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32, name='attention_mask')
Base model layer
Generate the output from the base model. The base model returns a TFBaseModelOutput; feed the embedding of [CLS] to the next layer.
# Bare base model without a head, so num_labels is not needed here
base = TFDistilBertModel.from_pretrained(MODEL_NAME)
# Freeze the base model weights.
if FREEZE_BASE:
    for layer in base.layers:
        layer.trainable = False
base.summary()
# [CLS] embedding is last_hidden_state[:, 0, :]
output = base([input_ids, attention_mask]).last_hidden_state[:, 0, :]
Classification layers
if USE_CUSTOM_HEAD:
    # --------------------------------------------------------------------------------
    # Classification layer 01
    # --------------------------------------------------------------------------------
    output = tf.keras.layers.Dropout(
        rate=0.15,
        name="01_dropout",
    )(output)
    output = tf.keras.layers.Dense(
        units=NUM_BASE_MODEL_OUTPUT,
        kernel_initializer='glorot_uniform',
        activation=None,
        name="01_dense_relu_no_regularizer",
    )(output)
    output = tf.keras.layers.BatchNormalization(
        name="01_bn"
    )(output)
    output = tf.keras.layers.Activation(
        "relu",
        name="01_relu"
    )(output)
    # --------------------------------------------------------------------------------
    # Classification layer 02
    # --------------------------------------------------------------------------------
    output = tf.keras.layers.Dense(
        units=NUM_BASE_MODEL_OUTPUT,
        kernel_initializer='glorot_uniform',
        activation=None,
        name="02_dense_relu_no_regularizer",
    )(output)
    output = tf.keras.layers.BatchNormalization(
        name="02_bn"
    )(output)
    output = tf.keras.layers.Activation(
        "relu",
        name="02_relu"
    )(output)
Softmax Layer
output = tf.keras.layers.Dense(
units=NUM_LABELS,
kernel_initializer='glorot_uniform',
kernel_regularizer=tf.keras.regularizers.l2(l2=L2),
activation='softmax',
name="softmax"
)(output)
Final Custom Model
name = f"{TIMESTAMP}_{MODEL_NAME.upper()}"
model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=output, name=name)
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
metrics=['accuracy']
)
model.summary()
(The summary below was produced with MAX_SEQUENCE_LENGTH = 256, hence the (None, 256) input shapes.)
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_ids (InputLayer) [(None, 256)] 0
__________________________________________________________________________________________________
attention_mask (InputLayer) [(None, 256)] 0
__________________________________________________________________________________________________
tf_distil_bert_model (TFDistilB TFBaseModelOutput(la 66362880 input_ids[0][0]
attention_mask[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli (None, 768) 0 tf_distil_bert_model[1][0]
__________________________________________________________________________________________________
01_dropout (Dropout) (None, 768) 0 tf.__operators__.getitem_1[0][0]
__________________________________________________________________________________________________
01_dense_relu_no_regularizer (D (None, 768) 590592 01_dropout[0][0]
__________________________________________________________________________________________________
01_bn (BatchNormalization) (None, 768) 3072 01_dense_relu_no_regularizer[0][0
__________________________________________________________________________________________________
01_relu (Activation) (None, 768) 0 01_bn[0][0]
__________________________________________________________________________________________________
02_dense_relu_no_regularizer (D (None, 768) 590592 01_relu[0][0]
__________________________________________________________________________________________________
02_bn (BatchNormalization) (None, 768) 3072 02_dense_relu_no_regularizer[0][0
__________________________________________________________________________________________________
02_relu (Activation) (None, 768) 0 02_bn[0][0]
__________________________________________________________________________________________________
softmax (Dense) (None, 2) 1538 02_relu[0][0]
==================================================================================================
Total params: 67,551,746
Trainable params: 1,185,794
Non-trainable params: 66,365,952 <--- Base BERT model is frozen
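The counts in this summary can be checked by hand, assuming two labels (as the (None, 2) output shape shows). Each Dense layer has n_in × n_out weights plus n_out biases; each BatchNormalization layer holds four 768-long vectors, of which gamma and beta are trainable while the moving mean and moving variance are not. That is why the non-trainable count is the frozen base plus 2 × 1,536 moving statistics. The helper names are illustrative.

```python
BASE_PARAMS = 66_362_880  # frozen distilbert base (all non-trainable)
DIM = 768                 # width of the custom dense layers
NUM_LABELS = 2            # assumption: binary toxic / non-toxic


def dense(n_in: int, n_out: int) -> int:
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out


# BatchNormalization: gamma and beta are trained; moving mean and variance are not.
BN_TRAINABLE = 2 * DIM
BN_NON_TRAINABLE = 2 * DIM

trainable = 2 * dense(DIM, DIM) + 2 * BN_TRAINABLE + dense(DIM, NUM_LABELS)
non_trainable = BASE_PARAMS + 2 * BN_NON_TRAINABLE
print(f"trainable={trainable:,} non_trainable={non_trainable:,} "
      f"total={trainable + non_trainable:,}")
# trainable=1,185,794 non_trainable=66,365,952 total=67,551,746
```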
Data allocation
# --------------------------------------------------------------------------------
# Split data into training and validation
# --------------------------------------------------------------------------------
raw_train = pd.read_csv("./train.csv")
train_data, validation_data, train_label, validation_label = train_test_split(
raw_train[DATA_COLUMN].tolist(),
raw_train[LABEL_COLUMN].tolist(),
test_size=.2,
shuffle=True
)
# X = dict(tokenize(train_data))
# Y = tf.convert_to_tensor(train_label)
X = tf.data.Dataset.from_tensor_slices((
dict(tokenize(train_data)), # Convert BatchEncoding instance to dictionary
train_label
)).batch(BATCH_SIZE).prefetch(1)
V = tf.data.Dataset.from_tensor_slices((
dict(tokenize(validation_data)), # Convert BatchEncoding instance to dictionary
validation_label
)).batch(BATCH_SIZE).prefetch(1)
Train
# --------------------------------------------------------------------------------
# Train the model
# https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
# Input data x can be a dict mapping input names to the corresponding array/tensors,
# if the model has named inputs. Beware of the "names". y should be consistent with x
# (you cannot have Numpy inputs and tensor targets, or inversely).
# --------------------------------------------------------------------------------
history = model.fit(
    x=X,  # dataset yielding (dictionary of inputs, labels)
    # y=Y,
    y=None,
    epochs=NUM_EPOCHS,
    validation_data=V,  # batch size comes from .batch(BATCH_SIZE) on the datasets
)
To implement the 1st approach, change the configuration as below.
USE_CUSTOM_HEAD = False
Then FREEZE_BASE is changed to False and LEARNING_RATE to 5e-5, which will run further pre-training on the base BERT model.
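The interplay of the flags can be summarized as a small function. This is a hypothetical refactoring of the FREEZE_BASE / USE_CUSTOM_HEAD / LEARNING_RATE lines in the configuration, not part of the original code; it makes explicit that disabling the custom head always unfreezes the base and selects the small 5e-5 rate.

```python
from typing import Tuple


def resolve_config(use_custom_head: bool, freeze_base: bool = True) -> Tuple[bool, float]:
    """Return (freeze_base, learning_rate) following the configuration rules above."""
    if not use_custom_head:
        # Without a custom head, a frozen base would leave only the softmax layer to learn
        freeze_base = False
    # Frozen base: the original uses a large rate (1e-2).
    # Trainable base: small rate to avoid catastrophic forgetting.
    learning_rate = 1e-2 if freeze_base else 5e-5
    return freeze_base, learning_rate


print(resolve_config(use_custom_head=True))   # 3rd approach: (True, 0.01)
print(resolve_config(use_custom_head=False))  # 1st approach: (False, 5e-05)
```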
Saving the model
For the 3rd approach, saving the model causes issues. The save_pretrained method of the Huggingface model cannot be used, as the custom model is not a direct subclass of the Huggingface TFPreTrainedModel.
Keras save_model causes an error with the default save_traces=True, while with save_traces=False a different error occurs when loading the model with Keras load_model.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-71-01d66991d115> in <module>()
----> 1 tf.keras.models.load_model(MODEL_DIRECTORY)
11 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/load.py in _unable_to_call_layer_due_to_serialization_issue(layer, *unused_args, **unused_kwargs)
865 'recorded when the object is called, and used when saving. To manually '
866 'specify the input shape/dtype, decorate the call function with '
--> 867 '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
868
869
ValueError: Cannot call custom layer tf_distil_bert_model of type <class 'tensorflow.python.keras.saving.saved_model.load.TFDistilBertModel'>, because the call function was not serialized to the SavedModel.Please try one of the following methods to fix this issue:
(1) Implement `get_config` and `from_config` in the layer/model class, and pass the object to the `custom_objects` argument when loading the model. For more details, see: https://www.tensorflow.org/guide/keras/save_and_serialize
(2) Ensure that the subclassed model or layer overwrites `call` and not `__call__`. The input shape and dtype will be automatically recorded when the object is called, and used when saving. To manually specify the input shape/dtype, decorate the call function with `@tf.function(input_signature=...)`.
As far as I tested, only the Keras Model save_weights method worked.
Experiments
In my tests on the Toxic Comment Classification Challenge, the 1st approach gave better recall (identifying true toxic comments and true non-toxic comments). The code can be accessed as below. Corrections and suggestions are welcome.
Related