Multi-label classification with BERT


Posted on Fri 29 January 2021

Transformers and their offspring (BERT, GPT-3, T5, ...) have revolutionized NLP and, in my opinion, seem to take us closer to Artificial General Intelligence. Thanks to Hugging Face 🤗, fine-tuning transformers on your dataset now boils down to a couple of lines of Python code.

[Figure: BERT's family]

Some use cases still require a bit of tinkering, like the multi-label classification problem I had to tackle recently. In this article I'm sharing my TensorFlow implementation, which is portable to any other transformer you may prefer. If you are more of a PyTorch aficionado, feel free to have a look at this blog post instead.

Today's post is not theoretical. If you want to dive into the inner workings of transformers, there are plenty of good presentations out there, notably The Illustrated Transformer, Transformers from Scratch, and The Annotated Transformer.

In my experience exploring the Transformers literature, the Transformer itself is the most complex component and the one that took me the longest to understand. With a solid understanding of the Transformer, the models built on top of it (like BERT, the GPT family, and maaaaaannnny others) are fairly easy to grasp.

In this blog post I fine-tune DistilBERT (a smaller version of BERT with very similar performance) on the Toxic Comment Classification Challenge. This challenge consists of tagging Wikipedia comments according to several "toxic behavior" labels. The task is a multi-label classification problem because a single comment can have zero, one, or up to six tags.

As you'll see below, I simply fine-tuned the model on a GPU (thanks to Colab) and achieved very good performance in less than an hour.

Import and prepare the dataset

In [1]:
!pip uninstall -y -q kaggle && pip install -q kaggle 
# Needed to get the latest version of the Kaggle CLI

from getpass import getpass
import os

# We'll use the Kaggle-CLI to download the dataset
# To create an authentication token on Kaggle check
# You'll also have to accept the competition rules here: 

os.environ["KAGGLE_USERNAME"] = getpass(prompt='Kaggle username: ')
os.environ["KAGGLE_KEY"] = getpass(prompt='Token: ')

!kaggle --version
!kaggle competitions download -c jigsaw-toxic-comment-classification-challenge
!unzip && unzip
Kaggle username: ··········
Token: ··········

The above cell just downloads and unpacks the data. Now let's load and prepare the dataset.

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split

dataset = pd.read_csv("train.csv")
texts = list(dataset["comment_text"])
label_names = dataset.drop(["id", "comment_text"], axis=1).columns
labels = dataset[label_names].values

train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

sample_idx = 23
print(f'Sample: "{train_texts[sample_idx]}"')
print(f"Labels: {pd.Series(train_labels[sample_idx], label_names).to_dict()}")
Sample: "what the fuck who deleted the spider loc and hot rod sections fucking wikipedia stupid ass ignorant people can we get it back ?"
Labels: {'toxic': 1, 'severe_toxic': 0, 'obscene': 1, 'threat': 0, 'insult': 1, 'identity_hate': 0}

A minimalistic Exploratory Data Analysis

Here we'll just look at the distribution of text lengths. It will help us choose a reasonable cut-off for the number of words, in order to speed up training (BERT's maximum input length is 512 tokens).

In [3]:
import seaborn as sns

text_lengths = [len(t.split()) for t in train_texts]
ax = sns.histplot(data=text_lengths, kde=True, stat="density")
ax.set_title("Texts length distribution (number of words)");

Let's choose a cutoff of 200 words, since most texts are shorter than this.
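Rather than reading the cut-off off the histogram, you can check how many texts it keeps intact. A minimal sketch, where `share_within_cutoff` is a hypothetical helper and `example_lengths` stands in for the `text_lengths` list computed above. Keep in mind that the tokenizer counts sub-word tokens rather than whitespace-separated words, so a 200-token limit will truncate somewhat more texts than this word-based estimate suggests.

```python
import numpy as np


def share_within_cutoff(lengths, cutoff):
    """Return the fraction of texts whose length is at most `cutoff`."""
    lengths = np.asarray(lengths)
    return float(np.mean(lengths <= cutoff))


# Hypothetical word counts, standing in for `text_lengths` from the cell above
example_lengths = [12, 45, 230, 80, 199, 200, 510, 33]
print(f"{share_within_cutoff(example_lengths, 200):.1%} of texts fit in 200 words")
# 75.0% of texts fit in 200 words

# A high percentile of the lengths is another way to pick a cut-off
print(np.percentile(example_lengths, 95))
```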

As you'll see below, the labels are quite imbalanced, which gives us an idea of the order of magnitude of what a decent accuracy could be.

In [4]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)

# Labels distribution barplot
labels_ratio = dataset[label_names].mean()
labels_ratio.plot(kind="bar", ax=ax1)
ax1.yaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))
ax1.set_ylim(0, 1.1 * labels_ratio.max())  # leave headroom for the annotations
for p in ax1.patches:
    # Write each bar's percentage just above it
    ax1.annotate(f"{p.get_height():.2%}", (p.get_x() + 0.005, p.get_height() + 0.002))
ax1.set_title("Labels ratio");

# Labels correlation heatmap, drawn explicitly on the second panel
sns.heatmap(dataset[label_names].corr(), ax=ax2)
ax2.set_title("Labels correlation");

Note that the labels are strongly correlated. This should make training easier, since the patterns learned by the transformer for one label are likely to generalize across labels.

For benchmarking purposes, it is always useful to know the performance of a dummy classifier that always predicts the label ratio as the label probability. Note that for multi-label classification we distinguish between the accuracy (all predicted labels of a sample are correct, a.k.a. exact match) and the binary accuracy (the fraction of individual labels that are correct). The latter is always at least as large as the former, because getting all the labels of a sample right is harder than getting any single one right. Below, I compute both, along with the log loss and the average precision.
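To make the difference between the two accuracies concrete, here is a toy NumPy example on made-up labels (not the Kaggle data). It also shows why the dummy classifier's binary accuracy is simply one minus the average label ratio: every label ratio in this dataset is below 0.5, so once rounded, its predictions are all zeros.

```python
import numpy as np

# Toy multi-label targets: 4 samples x 3 labels (hypothetical data)
y_true = np.array([[1, 0, 0],
                   [0, 0, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
# An all-zeros predictor, i.e. what the rounded prior-based dummy outputs
y_pred = np.zeros_like(y_true)

# Exact match: a sample counts only if all of its labels are right
exact_match = np.all(y_pred == y_true, axis=1).mean()
# Binary accuracy: every (sample, label) cell counts on its own
binary_acc = (y_pred == y_true).mean()

print(exact_match)  # 0.25 (only the second sample is fully correct)
print(binary_acc)   # 8 correct cells out of 12
# For the all-zeros predictor, binary accuracy equals 1 - mean label prevalence
print(1 - y_true.mean())
```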

In [23]:
import numpy as np  
from sklearn.dummy import DummyClassifier
from sklearn.metrics import log_loss, average_precision_score

pd.set_option("display.precision", 3)

# The "prior" strategy ignores the texts and only uses the label frequencies
dummy = DummyClassifier(strategy="prior").fit(train_texts, train_labels)
y_pred = dummy.predict(test_texts)
y_prob = dummy.predict_proba(test_texts)
y_prob = np.array(y_prob)[:, :, 1].T

def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> pd.Series:
    """Compute several performance metrics for multi-label classification. """
    y_pred = y_prob.round()
    metrics = dict()
    metrics["Multi-label accuracy"] = np.all(y_pred == y_true, axis=1).mean()
    metrics["Binary accuracy"] = (y_pred == y_true).mean()
    metrics["Loss"] = log_loss(y_true, y_prob)
    metrics["Average Precision"] = average_precision_score(y_true, y_prob)
    return pd.Series(metrics)

evaluation = compute_metrics(test_labels, y_prob).to_frame(name="Dummy")
Multi-label accuracy 0.898
Binary accuracy 0.963
Loss 0.302
Average Precision 0.037
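One caveat on the "Loss" row: given a multi-label indicator matrix, scikit-learn's `log_loss` renormalizes each row of predicted probabilities so that it sums to one and only scores the positive labels. It is therefore not the per-label binary cross-entropy that Keras's `BinaryCrossentropy` reports for DistilBERT further down, so those loss values are not directly comparable. A sketch of a per-label binary cross-entropy that could replace `log_loss` in `compute_metrics` if you want comparable numbers (the helper `mean_binary_cross_entropy` is ours, not a scikit-learn function, and the values in this post were not recomputed with it):

```python
import numpy as np


def mean_binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Binary cross-entropy averaged over all (sample, label) pairs.

    Each label is scored independently, as with sigmoid outputs; probabilities
    are clipped to [eps, 1 - eps] to avoid taking log(0).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1 - eps)
    losses = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    return float(losses.mean())


# Toy check: 2 samples x 2 labels (hypothetical values)
y_true_toy = np.array([[1, 0], [0, 0]])
y_prob_toy = np.array([[0.9, 0.2], [0.1, 0.4]])
print(round(mean_binary_cross_entropy(y_true_toy, y_prob_toy), 4))
```

In the notebook, `mean_binary_cross_entropy(test_labels, y_prob)` would take the place of `log_loss(y_true, y_prob)`.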

Setting a baseline

Let's set up a baseline with a rather classic NLP modelling technique: a TF-IDF vectorization step followed by a regularized logistic regression. Despite its simplicity, this approach often gives good results.

In [24]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline

# We take into account unigrams and bigrams that appear in at least 10 documents
# of the train set, but in no more than 50 % of them
tfidf = TfidfVectorizer(min_df=10, max_df=0.5, ngram_range=(1, 2))

# We set the (L2) regularization weight to 1/2 (inverse of C)
classifier = OneVsRestClassifier(LogisticRegression(C=2., max_iter=1000))

baseline = make_pipeline(tfidf, classifier).fit(train_texts, train_labels)

y_prob = baseline.predict_proba(test_texts)

evaluation["Baseline"] = compute_metrics(test_labels, y_prob)
                      Dummy  Baseline
Multi-label accuracy  0.898     0.919
Binary accuracy       0.963     0.981
Loss                  0.302     0.281
Average Precision     0.037     0.631

So the baseline improves on the dummy classifier: modestly in accuracy, and by a wide margin in average precision (0.631 versus 0.037). Let's see how transformers perform on this task.
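The baseline relies on scikit-learn's `OneVsRestClassifier`: when the target is a label-indicator matrix, it fits one independent binary classifier per label column, and `predict_proba` directly returns one probability per label. That is why, unlike for the dummy classifier, no reshaping of the probabilities was needed. A small sketch on random toy data (not the comments dataset) makes this visible:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

rng = np.random.default_rng(0)
# Toy data: 100 samples, 5 features, 3 binary labels (hypothetical)
X = rng.normal(size=(100, 5))
Y = (X[:, :3] + rng.normal(scale=0.5, size=(100, 3)) > 0).astype(int)

clf = OneVsRestClassifier(LogisticRegression()).fit(X, Y)

print(len(clf.estimators_))        # 3: one fitted LogisticRegression per label
print(clf.predict_proba(X).shape)  # (100, 3): one probability per sample and label
```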

Fine-tuning DistilBERT

The Transformers package provides pre-trained transformer-based models, plus the corresponding pre-processing and tokenizing functions (the tokenizers even have optimized implementations in Rust!).

In [7]:
!pip install -q transformers > /dev/null

import transformers
print(f"Transformers package version: {transformers.__version__}")
Transformers package version: 4.2.2
In [8]:
import tensorflow as tf
from transformers import TFDistilBertForSequenceClassification, \
    DistilBertConfig, DistilBertTokenizerFast

MODEL_NAME = 'distilbert-base-uncased'
MAX_LENGTH = 200  # Truncate inputs to 200 tokens (special tokens included) to speed up training

# The configuration is not needed if you don't have to customize the 
# network architecture. Here we will need it to replace the output of the model
# with a multi-label prediction layer (i.e. sigmoid activations + binary cross-entropy
# instead of softmax + categorical cross-entropy of multi-class classification)
config = DistilBertConfig.from_pretrained(MODEL_NAME)

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

train_encodings = tokenizer(train_texts, truncation=True, padding=True, 
                            max_length=MAX_LENGTH, return_tensors="tf")
test_encodings = tokenizer(test_texts, truncation=True, padding=True, 
                           max_length=MAX_LENGTH, return_tensors="tf")

# Create TensorFlow datasets to feed the model for training and evaluation
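# Only "input_ids" is fed to the model defined below, so we keep just that key
train_dataset = tf.data.Dataset.from_tensor_slices(
    ({"input_ids": train_encodings["input_ids"]}, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices(
    ({"input_ids": test_encodings["input_ids"]}, test_labels))

# Tokenizer output example
sample_text = "I have changed the headers to small letters, since I was basically..."
tokenizer.decode(tokenizer(sample_text)["input_ids"])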

'[CLS] i have changed the headers to small letters, since i was basically... [SEP]'

The tokenizer prepends a "[CLS]" special token to each text, whose final representation will be used for classification, and appends a "[SEP]" token.
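Because the special tokens count toward `max_length`, a text truncated with `MAX_LENGTH = 200` keeps at most 198 word-piece tokens of its own. The pure-Python sketch below mimics that wrapping, truncation and padding logic for a single text; `wrap_for_bert` is a hypothetical helper written for illustration, not part of the Transformers API, whose tokenizers do all of this for you.

```python
from typing import List


def wrap_for_bert(tokens: List[str], max_length: int) -> List[str]:
    """Hypothetical helper mimicking single-sequence BERT input formatting.

    Truncates the content so that [CLS] + content + [SEP] fits in max_length,
    then right-pads with [PAD] up to max_length.
    """
    budget = max_length - 2  # room left once [CLS] and [SEP] are added
    wrapped = ["[CLS]"] + tokens[:budget] + ["[SEP]"]
    return wrapped + ["[PAD]"] * (max_length - len(wrapped))


print(wrap_for_bert(["i", "have", "changed"], max_length=6))
# ['[CLS]', 'i', 'have', 'changed', '[SEP]', '[PAD]']
print(wrap_for_bert(["a", "b", "c", "d", "e"], max_length=5))
# ['[CLS]', 'a', 'b', 'c', '[SEP]']
```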

Note that because BERT's vocabulary reserves a set of unused tokens, you can also take into account custom vocabularies that may be specific to your dataset (e.g. domain-specific tags, or unusual characters with a specific meaning). To do so, you just have to add the following line:

tokenizer.add_special_tokens({"additional_special_tokens": ["[unused1]"]})

and map your custom symbols to "[unused1]" (or "[unused2]", ..., up to "[unused993]" in the bert-base-uncased vocabulary). We won't need it here.
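The mapping step itself can be as simple as a string replacement applied before tokenization. A minimal sketch, where the markers `<MATH>` and `<URL>` and the helper `map_custom_symbols` are hypothetical examples; every `[unusedN]` string you use must also be registered with `add_special_tokens` as in the line above, so that the tokenizer keeps it as a single token instead of splitting it into pieces.

```python
def map_custom_symbols(text: str, mapping: dict) -> str:
    """Replace each custom symbol with its reserved BERT token string."""
    for symbol, reserved_token in mapping.items():
        text = text.replace(symbol, reserved_token)
    return text


# Hypothetical domain-specific markers mapped to reserved vocabulary slots
SYMBOL_MAP = {"<MATH>": "[unused1]", "<URL>": "[unused2]"}

print(map_custom_symbols("See <URL> for the proof of <MATH>.", SYMBOL_MAP))
# See [unused2] for the proof of [unused1].
```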

Now that the tokenizer is available, we have to customize the output of the BERT model for our multi-label problem.
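The key change is the activation of the output layer. A softmax couples the labels, since its outputs must sum to one, whereas a sigmoid scores each label independently, so several labels can exceed 0.5 at the same time. A quick NumPy illustration on made-up logits for the six labels:

```python
import numpy as np

# Made-up logits for one comment over the six toxicity labels
logits = np.array([2.0, -3.0, 1.5, -4.0, 1.0, -2.5])

softmax = np.exp(logits) / np.exp(logits).sum()  # labels compete for probability mass
sigmoid = 1 / (1 + np.exp(-logits))              # each label is scored on its own

print(softmax.sum())          # 1, up to floating-point rounding
print((sigmoid > 0.5).sum())  # 3 labels predicted at once (positions 0, 2 and 4)
```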

In [9]:
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.layers import Input, Dropout, Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

transformer_model = TFDistilBertForSequenceClassification.from_pretrained(
    MODEL_NAME, output_hidden_states=False
)

# Keep only the pre-trained DistilBERT encoder; its classification head is replaced below
bert = transformer_model.layers[0]

# The input is a dictionary of word identifiers 
input_ids = Input(shape=(MAX_LENGTH,), name='input_ids', dtype='int32')
inputs = {'input_ids': input_ids}

# Here we select the representation of the first token ([CLS]) for classification
# (a.k.a. "pooled representation")
bert_model = bert(inputs)[0][:, 0, :] 

# Add a dropout layer and the output layer
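dropout = Dropout(config.dropout, name='pooled_output')
# Note: training=False keeps this dropout layer inactive, even during fit();
# remove the argument to apply dropout while training
pooled_output = dropout(bert_model, training=False)
output = Dense(
    len(label_names),  # one output per label
    kernel_initializer=TruncatedNormal(stddev=config.initializer_range),
    activation="sigmoid",  # Choose a sigmoid for multi-label classification
    name='output'
)(pooled_output)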
model = Model(inputs=inputs, outputs=output, name='BERT_MultiLabel')
model.summary()
Model: "BERT_MultiLabel"
Layer (type)                 Output Shape              Param #   
input_ids (InputLayer)       [(None, 200)]             0         
distilbert (TFDistilBertMain TFBaseModelOutput(last_hi 66362880  
tf.__operators__.getitem (Sl (None, 768)               0         
pooled_output (Dropout)      (None, 768)               0         
output (Dense)               (None, 6)                 4614      
Total params: 66,367,494
Trainable params: 66,367,494
Non-trainable params: 0

So we will fine-tune ~66M parameters on our dataset. That may sound like a lot, and one may be concerned with overfitting risks, but BERT has proven very robust to fine-tuning.
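The parameter counts in the summary can be checked by hand: the new dense head maps the 768-dimensional [CLS] representation to 6 outputs, so it has 768 × 6 weights plus 6 biases, and all the other parameters come from the pre-trained DistilBERT layer.

```python
hidden_size = 768               # size of DistilBERT's [CLS] representation (from the summary)
num_labels = 6                  # one sigmoid output per toxicity label
distilbert_params = 66_362_880  # reported for the distilbert layer in the summary

head_params = hidden_size * num_labels + num_labels  # weights + biases
print(head_params)                      # 4614
print(distilbert_params + head_params)  # 66367494
```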

Now we can train the model in a few lines of code!

In [20]:
def multi_label_accuracy(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """For multi-label classification, one has to define a custom
    accuracy function because neither tf.keras.metrics.Accuracy nor
    tf.keras.metrics.CategoricalAccuracy evaluates the number of 
    exact matches.

    >>> from tensorflow.keras import metrics
    >>> y_true = tf.convert_to_tensor([[1., 1.]])
    >>> y_pred = tf.convert_to_tensor([[1., 0.]])
    >>> metrics.Accuracy()(y_true, y_pred).numpy()
    0.5
    >>> metrics.CategoricalAccuracy()(y_true, y_pred).numpy()
    1.0
    >>> multi_label_accuracy(y_true, y_pred).numpy()
    0.0
    """
    y_pred = tf.math.round(y_pred)
    exact_matches = tf.math.reduce_all(y_pred == y_true, axis=1)
    exact_matches = tf.cast(exact_matches, tf.float32)
    return tf.math.reduce_mean(exact_matches)

loss = BinaryCrossentropy()
optimizer = Adam(5e-5)
metrics = [
    multi_label_accuracy,
    "binary_accuracy",
    AUC(name="average_precision", curve="PR", multi_label=True)
]
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
# The batch size is set on the datasets themselves
training_history = model.fit(
    train_dataset.shuffle(1000).batch(16), epochs=2,
    validation_data=test_dataset.batch(16)
)
Epoch 1/2
7979/7979 [==============================] - 1112s 138ms/step - loss: 0.0308 - multi_label_accuracy: 0.9397 - binary_accuracy: 0.9877 - average_precision: 0.7518 - val_loss: 0.0442 - val_multi_label_accuracy: 0.9236 - val_binary_accuracy: 0.9837 - val_average_precision: 0.6748
Epoch 2/2
7979/7979 [==============================] - 1098s 138ms/step - loss: 0.0255 - multi_label_accuracy: 0.9486 - binary_accuracy: 0.9896 - average_precision: 0.8037 - val_loss: 0.0459 - val_multi_label_accuracy: 0.9220 - val_binary_accuracy: 0.9835 - val_average_precision: 0.6720

The model improves on the accuracy of both the dummy predictor and the baseline model, although the margin over the baseline is small.

Let's run the evaluate method on the test set once more, as a sanity check.

In [28]:
benchmarks = model.evaluate(test_dataset.batch(16), return_dict=True)
evaluation["DistilBERT"] = [
    benchmarks[k] for k in 
    ["multi_label_accuracy", "binary_accuracy", "loss", "average_precision"]
]
                      Dummy  Baseline  DistilBERT
Multi-label accuracy  0.898     0.919       0.922
Binary accuracy       0.963     0.981       0.984
Loss                  0.302     0.281       0.046
Average Precision     0.037     0.631       0.672

The average precision went up noticeably, from 0.631 for the baseline to 0.672. This model is pretty good. This is usually the point at which you would start a deeper analysis (other metrics, confusion matrix, optimal thresholds, individual examples, feature importance, ...), but that is beyond the scope of this post.

We can now save our trained model. The first few lines of the following cell are specific to Colab: they mount your Google Drive as a local drive, for persistence (very convenient 😊).

In [43]:
from google.colab import drive

drive.mount('/gdrive')  # A new tab will open and you will have to accept
# the conditions and copy paste the token below 
BASE_PATH = "/gdrive/My Drive/toxic_comments_transformer"

if not os.path.exists(BASE_PATH):
    os.makedirs(BASE_PATH)

model.save(f"{BASE_PATH}/fine_tuned_distilbert")

Here's an example of how you would use your shiny BERT model in production, and a measure of the typical latency that we should expect.

In [44]:
from time import time
from tensorflow.keras.models import load_model
from transformers import DistilBertTokenizerFast

# Mimicking a production scenario: load the model and tokenizer
model = load_model(f"{BASE_PATH}/fine_tuned_distilbert", 
                   custom_objects={"multi_label_accuracy": multi_label_accuracy})
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

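def score_text(text, model=model, tokenizer=tokenizer):
    """Return the label probabilities predicted for a single text."""
    padded_encodings = tokenizer.encode_plus(
        text,
        max_length=MAX_LENGTH,  # truncates if len(s) > max_length
        truncation=True,
        padding="max_length",  # the model expects exactly MAX_LENGTH ids
        return_tensors="tf",
    )
    return model(padded_encodings["input_ids"]).numpy()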

score_text("dummy")  # running a dummy prediction to work around the extra latency
# of the first prediction of a loaded TensorFlow model.

text = """I am a nice Wikipedia user, I mean no harm, 
I will not insult anybody or be offensive anyhow."""

t0 = time()
scores = score_text(text)[0]
latency = time() - t0

scores = pd.Series(scores, label_names, name="scores")
print(scores.to_string())
print(f"\nLatency: {latency:.3f} seconds")
toxic          2.861e-03
severe_toxic   8.741e-05
obscene        3.906e-04
threat         6.823e-04
insult         5.880e-05
identity_hate  2.378e-05

Latency: 0.070 seconds

Pretty fast! We can now use our shiny fine-tuned transformer in production 🎉
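A single call timed with `time()` gives a rough idea, but latency varies from one call to the next. A sketch of a slightly more robust measurement, assuming a zero-argument callable such as `lambda: score_text(text)` (the helper `measure_latency` and its parameters are ours): it makes warm-up calls, times repeated runs with `time.perf_counter`, and reports the median and an approximate 95th percentile.

```python
import statistics
import time
from typing import Callable, Dict


def measure_latency(fn: Callable[[], object], n_runs: int = 50,
                    warmup: int = 1) -> Dict[str, float]:
    """Time repeated calls to fn and summarize the latency in seconds."""
    for _ in range(warmup):
        fn()  # warm-up calls absorb one-off costs such as graph tracing
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    timings.sort()
    p95_index = min(len(timings) - 1, int(0.95 * len(timings)))
    return {"median": statistics.median(timings), "p95": timings[p95_index]}


# Stand-in workload; in the notebook this would be lambda: score_text(text)
stats = measure_latency(lambda: sum(range(10_000)), n_runs=20)
print(f"median: {stats['median'] * 1000:.2f} ms, p95: {stats['p95'] * 1000:.2f} ms")
```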
