
Sota Sts

Eugene Low · Jul 17, 2020 · 6 mins read

State of the Art Semantic Textual Similarity

With Google's newly released BERT, there are lots of model extensions available online (e.g. huggingface/transformers and UKPLab/sentence-transformers). This article focuses on plug-and-play code snippets, and I will provide sample code that you can modify to your needs. But keep in mind that these models are supposed to be fine-tuned for your specific use cases.

After that we will set up a facebookresearch/faiss index so the deployment can scale.

I attempted to use the default BERT model from huggingface/transformers, but there is a problem: the model does not return good sentence embeddings out of the box (read the return type of https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel for more). A fine-tuning process is absolutely required, but usually there is either no great dataset, or it requires too deep an understanding and large computing resources, which I do not possess. Besides, because I am in Hong Kong, most of my tasks require Chinese and English multilingual models, so I chose bert-base-multilingual-cased from https://huggingface.co/transformers/pretrained_models.html.

Installation

pip install transformers torch hanziconv

Traditional to Simplified Chinese

In Hong Kong most applications are in traditional Chinese, so I used the tokenizer to test whether the traditional and simplified forms of a character get the same encoding:

tokenizer.encode('爱')
tokenizer.encode('愛')
> [101, 5383, 102]
> [101, 3910, 102]

The encodings are different.

Since my dataset is a mix of traditional and simplified Chinese, I decided to use the package berniey/hanziconv for the conversion.

from hanziconv import HanziConv as t2mc

t2mc.toSimplified('愛')
> '爱'
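
Since the tokenizer treats the two forms differently, I normalize every sentence to simplified Chinese before tokenizing, so both written forms map to the same token ids. A small sketch (the normalize helper is my own name; tokenizer is the multilingual tokenizer loaded in the next section):

def normalize(text):
    # convert any traditional characters to simplified before tokenizing
    return t2mc.toSimplified(text)

# with the bert-base-multilingual-cased tokenizer from the next section:
# tokenizer.encode(normalize('愛')) == tokenizer.encode('爱')  # -> True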

Using bert-base-multilingual-cased

Loading bert-base-multilingual-cased from huggingface/transformers:

from transformers import BertModel, BertTokenizer
import torch

bert_model_shortcut = 'bert-base-multilingual-cased'

tokenizer = BertTokenizer.from_pretrained(bert_model_shortcut)
model = BertModel.from_pretrained(bert_model_shortcut)

encoded_tokens = tokenizer.encode("Here is some text to encode", add_special_tokens=True)

input_ids = torch.tensor([encoded_tokens])

with torch.no_grad():
    # the first element of the output is the last hidden state, shape (1, seq_len, 768)
    last_hidden_states = model(input_ids)[0]

last_hidden_states.shape
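
For reference (this is not from the original pipeline, just a naive baseline): one common way to turn these token-level hidden states into a single sentence vector is to mean-pool over the sequence dimension, then compare vectors with cosine similarity.

import torch.nn.functional as F

def naive_sentence_vector(text):
    # naive sentence vector: mean of the last-layer token embeddings
    ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
    with torch.no_grad():
        hidden = model(ids)[0]      # shape (1, seq_len, 768)
    return hidden.mean(dim=1)[0]    # shape (768,)

v1 = naive_sentence_vector('I am a cat')
v2 = naive_sentence_vector('I am a persian cat')
similarity = F.cosine_similarity(v1.unsqueeze(0), v2.unsqueeze(0)).item()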

It turns out the model was not trained with semantic textual similarity in mind.

This becomes a problem for us, because we want something that can measure similarity between texts.

Therefore after some searching I found UKPLab/sentence-transformers.

Using UKPLab/sentence-transformers’s distiluse-base-multilingual-cased

How does this STS library work?

  1. First, start with a model: you can choose a model for transfer learning from the sentence-transformers list of available models.
  2. Using the model
    1. Apply it to a text to get an embedding (a vector): model.encode([text])[0]
    2. Two texts are semantically similar if their embeddings are close to each other, e.g. they have a high cosine similarity (see the similarity sketch after the training code below).
  3. Training procedure
    1. To call model.fit(...) we need a dataset.
    2. The dataset is in the format of sentence_transformers.datasets.SentencesDataset
      1. Its constructor takes a list of examples, each built with sentence_transformers.readers.InputExample from three items:
      2. guid: a unique id, which you can generate yourself, e.g. with Python's random or uuid library
      3. texts: a list of sentences, usually two
      4. label: the target similarity; a positive value means the pair is closely related, a negative value means a negative relationship, which the model will try not to replicate.

Putting it together in code:
from sentence_transformers import SentenceTransformer, losses
from torch.utils.data import DataLoader
from sentence_transformers.datasets import SentencesDataset

from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

import math

model = SentenceTransformer('distiluse-base-multilingual-cased')

def embed(text):
    # encode one sentence into a dense embedding vector (a 1-D float32 array)
    return model.encode([text])[0]

sentence_embedding = embed('I am a cat')


input_example_list = [
    InputExample(
        guid='1234',
        texts=[
            'I am a cat',
            'I am a persian cat'
        ],
        label=1.0
    ),
    InputExample(
        guid='1235',
        texts=[
            'I am a cat',
            'I am a husky'
        ],
        label=-1.0
    ),
]
sentence_dataset = SentencesDataset(examples=input_example_list, model=model)
def train(sentence_dataset, batch_size=32, num_epochs=10,):
    data_loader = DataLoader(sentence_dataset, shuffle=True, batch_size=batch_size)
    evaluator = EmbeddingSimilarityEvaluator(data_loader)

    warm_up_steps = math.ceil(
        len(sentence_dataset) * num_epochs / batch_size * 0.1)  # warm up on 10% of the training steps

    train_loss = losses.CosineSimilarityLoss(model=model)

    model.fit(
        train_objectives=[(data_loader, train_loss)],
        evaluator=evaluator,
        epochs=num_epochs,
        warmup_steps=warm_up_steps
    )
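
To illustrate step 2 of the list above, here is a minimal usage sketch (the helper and variable names are my own) that encodes a few sentences and compares them with plain cosine similarity:

import numpy as np

def cosine_similarity(a, b):
    # cosine similarity between two 1-D embedding vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

emb_cat = model.encode(['I am a cat'])[0]
emb_persian = model.encode(['I am a persian cat'])[0]
emb_husky = model.encode(['I am a husky'])[0]

cosine_similarity(emb_cat, emb_persian)  # the pair labelled as related above
cosine_similarity(emb_cat, emb_husky)    # the pair labelled as unrelated above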

Indexing the embedding for production

Because these embeddings are large, comparing them one by one is extremely slow. If we have a large dataset of candidate items to match, we will face problems. Therefore I sought out an indexer library and came across facebookresearch/faiss.

This library performs the indexing, and you can save the result to a file for easy transfer. I suggest making your program query a separate, long-running indexer process instead of loading the file every time you want to do a search, because reading a previously built index from disk is very slow.

import faiss
import numpy as np

# distiluse-base-multilingual-cased embeddings have 512 dimensions
DIMENSION = sentence_embedding.shape[0]
NLIST = 2  # number of IVF cells; training needs at least NLIST vectors, so use more in production

quantizer = faiss.IndexFlatL2(DIMENSION)
indexer = faiss.IndexIVFFlat(quantizer, DIMENSION, NLIST)

# faiss wants float32 matrices and int64 ids, so keep your own mapping
# from the integer ids back to strings such as 'ID123' and 'ID124'
data = np.array([embed('I am a cat'), embed('I am a persian cat')], dtype='float32')
ids = np.array([123, 124], dtype='int64')

# first step of training (cluster the vectors into NLIST cells)
indexer.train(data)

# insert data together with its integer ids
indexer.add_with_ids(data, ids)

# search: the second argument is the number of results
query = np.array([embed('I am a kitten')], dtype='float32')
distances, result_ids = indexer.search(query, 20)
# result_ids[0] holds the matching ids (123 or 124 in our case, -1 for empty slots)
# and distances[0] the corresponding L2 distances

# to save use `faiss.write_index(indexer, FILENAME)`
# to read use `indexer = faiss.read_index(FILENAME)`

As an example, I tried to wrap this into a class so anyone can use it easily (their library is still a bit messy).
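
A minimal sketch of such a wrapper, with hypothetical names of my own, could look like this: load the saved index once, keep a mapping from the integer ids back to your string labels, and answer every query from memory.

import faiss
import numpy as np

class FaissSearcher:
    """Hypothetical wrapper: load a saved index once and serve queries from memory."""

    def __init__(self, index_path, id_to_label):
        self.index = faiss.read_index(index_path)  # load the previously saved index once
        self.id_to_label = id_to_label             # e.g. {123: 'ID123', 124: 'ID124'}

    def search(self, embedding, k=20):
        query = np.asarray([embedding], dtype='float32')
        distances, ids = self.index.search(query, k)
        return [(self.id_to_label[i], float(d))
                for d, i in zip(distances[0], ids[0]) if i != -1]

# usage sketch:
# faiss.write_index(indexer, 'sts.index')
# searcher = FaissSearcher('sts.index', {123: 'ID123', 124: 'ID124'})
# searcher.search(model.encode(['I am a kitten'])[0])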

Be careful when importing faiss and pytorch: there is a bug for which the order of the imports matters. I didn't file a bug report because on most machines it does not occur.

A smarter searching program can be created if you utilize these tools properly :D. Hope you enjoyed the reading!

Written by Eugene Low