State of the Art Semantic Textual Similarity
Since Google released BERT, many extensions of the model have become available online (e.g. huggingface/transformers and UKPLab/sentence-transformers). This article focuses on plug-and-play code snippets, and I provide sample code that you can modify to your needs. Keep in mind, though, that these models are meant to be fine-tuned for your specific use case.
After that, we will set up a facebookresearch/faiss index so that the deployment can scale.
I first tried the default BERT model from huggingface/transformers, but there is a problem: the model does not return a good sentence embedding out of the box (see the return type at https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel for more). Fine-tuning is really required, but usually either there is no good dataset, or the process demands a deep understanding and large computing resources, which I do not possess. Besides, because I am in Hong Kong, most of my tasks require a multilingual model covering Chinese and English, so I chose bert-base-multilingual-cased from https://huggingface.co/transformers/pretrained_models.html.
Installation
pip install transformers torch hanziconv sentence-transformers faiss-cpu
Traditional to Simplified Chinese
In Hong Kong most applications use Traditional Chinese, so I used the tokenizer to check whether a character is encoded the same way in its simplified and traditional forms.
tokenizer.encode('爱')
tokenizer.encode('愛')
> [101, 5383, 102]
> [101, 3910, 102]
The two forms are encoded differently.
Since my dataset is a mix of Traditional and Simplified Chinese, I decided to use the berniey/hanziconv package to convert everything to Simplified Chinese.
from hanziconv import HanziConv as t2mc
t2mc.toSimplified('愛')
> '爱'
Using bert-base-multilingual-cased
The bert-base-multilingual-cased model comes from huggingface/transformers.
from transformers import BertModel, BertTokenizer
import torch
bert_model_shortcut = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert_model_shortcut)
model = BertModel.from_pretrained(bert_model_shortcut)
encoded_tokens = tokenizer.encode("Here is some text to encode", add_special_tokens=True)
input_ids = torch.tensor([encoded_tokens])
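with torch.no_grad():
    # The first element of the output is the last hidden state,
    # with shape (batch, tokens, hidden_size): one vector per token.
    last_hidden_states = model(input_ids)[0]
last_hidden_states.shape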
It turns out that the model was not trained with semantic textual similarity in mind.
This is a problem for us, because we want something that can measure the similarity between texts.
So, after some searching, I found UKPLab/sentence-transformers.
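To make the gap concrete: the BERT call above returns one vector per token, while similarity search needs one vector per sentence. One common way to bridge this, and one of the pooling options sentence-transformers offers, is mean pooling over the token vectors. The snippet below is only a minimal sketch in plain Python; mean_pool, toy_states and toy_mask are hypothetical names, and the numbers are made up rather than real hidden states.

```python
def mean_pool(token_vectors, attention_mask):
    """Average the vectors of real tokens, skipping padded positions."""
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


# Toy stand-in for one sentence's last hidden states:
# 3 token positions, 2 dimensions, and the third position is padding.
toy_states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
toy_mask = [1, 1, 0]

sentence_vector = mean_pool(toy_states, toy_mask)
print(sentence_vector)  # [2.0, 3.0]
```

Pooling alone does not make the vectors good for similarity; the model still has to be trained for that purpose, which is what the next section relies on.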
Using UKPLab/sentence-transformers’s distiluse-base-multilingual-cased
How does this STS library work?
- First, choose a pretrained model to start the transfer learning from; sentence-transformers publishes a list of available models.
- Using the model
    - Apply it to a text to get an embedding (a vector): model.encode([text])[0]
    - Two texts are closely related if their embeddings are close, for example if they have a small mean squared error. This means the texts are semantically similar.
- Training procedure
    - To use model.fit(...) we need a dataset in the format of sentence_transformers.datasets.SentencesDataset.
    - Its constructor requires a list of data points, and each data point is built with sentence_transformers.readers.InputExample, which takes 3 items:
        - uid (the guid argument): a unique id; you can make it up yourself, for example with Python's random library.
        - sentences (the texts argument): an array of sentences, usually two.
        - weight (the label argument): this can be positive or negative. Positive implies the sentences are closely related; negative implies a negative relationship, which the model will learn not to reproduce.
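To illustrate what "close embeddings" means, here is a small sketch in plain Python that compares embeddings with both mean squared error and cosine similarity (the measure that the CosineSimilarityLoss in the training code optimises). The three vectors are made-up 3-dimensional stand-ins, not real model output, and the function names are my own.

```python
import math


def mean_squared_error(a, b):
    """Lower means the two embeddings are closer."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cosine_similarity(a, b):
    """Ranges from -1 to 1; higher means the embeddings point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Made-up embeddings standing in for model.encode([...])[0]
cat = [1.0, 0.0, 1.0]
persian_cat = [1.0, 0.0, 0.8]
husky = [0.0, 1.0, 0.2]

print(mean_squared_error(cat, persian_cat), mean_squared_error(cat, husky))
print(cosine_similarity(cat, persian_cat), cosine_similarity(cat, husky))
```

With these toy numbers the "persian cat" vector is closer to "cat" than the "husky" vector under both measures, which is the behaviour we want the fine-tuned model to show on real sentences.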
from sentence_transformers import SentenceTransformer, losses
from torch.utils.data import DataLoader
from sentence_transformers.datasets import SentencesDataset
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import math
model = SentenceTransformer('distiluse-base-multilingual-cased')
def testing(text):
    # Encode a single text and return its embedding (a 1-D vector)
    sentence_embedding = model.encode([text])[0]
    return sentence_embedding
input_example_list = [
    InputExample(
        '1234',
        [
            'I am a cat',
            'I am a persian cat'
        ],
        1
    ),
    InputExample(
        '1235',
        [
            'I am a cat',
            'I am a husky'
        ],
        -1
    ),
]
sentence_dataset = SentencesDataset(examples=input_example_list, model=model)
def train(sentence_dataset, batch_size=32, num_epochs=10):
    data_loader = DataLoader(sentence_dataset, shuffle=True, batch_size=batch_size)
    # This reuses the training data for evaluation, which only checks that training runs.
    # The evaluator signature below is the one from the release used in this article.
    evaluator = EmbeddingSimilarityEvaluator(data_loader)
    # Use 10% of the total training steps for warm-up
    warm_up_steps = math.ceil(
        len(sentence_dataset) * num_epochs / batch_size * 0.1)
    train_loss = losses.CosineSimilarityLoss(model=model)
    model.fit(
        train_objectives=[(data_loader, train_loss)],
        evaluator=evaluator,
        epochs=num_epochs,
        warmup_steps=warm_up_steps
    )
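The warm-up formula in train() is easier to follow with numbers plugged in. The sketch below repeats the same arithmetic as a standalone function; warmup_steps and warmup_fraction are names I made up for illustration, and the dataset size of 1000 is an assumed example value.

```python
import math


def warmup_steps(num_examples, batch_size, num_epochs, warmup_fraction=0.1):
    """Same arithmetic as in train(): a fraction of the total optimisation steps."""
    total_steps = num_examples * num_epochs / batch_size
    return math.ceil(total_steps * warmup_fraction)


# 1000 examples, batch size 32, 10 epochs: 312.5 steps in total, 10% of that rounded up
print(warmup_steps(1000, 32, 10))  # 32

# The two-example toy dataset above still gets one warm-up step
print(warmup_steps(2, 32, 10))  # 1
```

Note that the formula divides the number of examples by the batch size without rounding, so it is an approximation of the real step count when the dataset size is not a multiple of the batch size.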
Indexing the embedding for production
Because these embeddings are large, comparing a query against every stored embedding one by one is extremely slow, and it only gets worse as the dataset of items to match grows. I therefore looked for an indexing library and came across facebookresearch/faiss.
This library builds the index, and you can save the result to a file for easy transfer. I suggest running the indexer as a separate, long-lived program that other programs query, instead of loading the file every time you want to search, because reading a previously saved index is very slow.
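import faiss
import numpy as np
# shape[-1] is the embedding size whether the array is 1-D or 2-D
DIMENSION = sentence_embedding.shape[-1]
quantizer = faiss.IndexFlatL2(DIMENSION)
# 10 is the number of clusters; training needs at least that many vectors,
# so the two embeddings below are only placeholders for a real dataset
indexer = faiss.IndexIVFFlat(quantizer, DIMENSION, 10)
# faiss expects float32 arrays of shape (n, DIMENSION) and int64 ids,
# so keep your own mapping from the integer id to a string label
data = np.array([embedding_1, embedding_2], dtype='float32')
ids = np.array([123, 124], dtype='int64')
id_to_label = {123: 'ID123', 124: 'ID124'}
# first step of training
indexer.train(data)
# insert data
indexer.add_with_ids(data, ids)
# search
query = np.array([embedding_3], dtype='float32')
scores, found_ids = indexer.search(query, 20)  # 20 is the number of results
# the results are in the form of scores[], ids[], one row per query (-1 marks an empty slot)
# map an id back with id_to_label, which gives 'ID123' or 'ID124' in our case
# to read use `indexer = faiss.read_index(FILENAME)`
# to save use `faiss.write_index(indexer, FILENAME)`
# as an example (I tried to wrap this in a class so anyone can use it easily; their library is still a bit messy)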
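For comparison, this is roughly what the "one by one" search that the index replaces looks like. It is a minimal brute-force sketch in plain Python with made-up 2-dimensional embeddings; brute_force_search and squared_l2 are hypothetical helper names. A baseline like this can also be handy on a small sample to sanity-check what the index returns.

```python
import heapq


def squared_l2(a, b):
    """Squared Euclidean distance, the same metric IndexFlatL2 reports."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def brute_force_search(query, embeddings, labels, k):
    """Compare the query against every embedding and keep the k closest."""
    scored = ((squared_l2(query, emb), label) for emb, label in zip(embeddings, labels))
    return heapq.nsmallest(k, scored)


# Made-up embeddings and string labels
embeddings = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
labels = ['ID123', 'ID124', 'ID125']

results = brute_force_search([0.9, 1.2], embeddings, labels, k=2)
for score, label in results:
    print(label, score)
```

Every query here touches every stored vector, so the cost grows linearly with the dataset; an IVF index avoids that by only scanning the clusters closest to the query, at the price of approximate results.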
Be careful when importing faiss and pytorch together: there is a bug where the order of the imports matters. I didn't file a bug report because on most computers the bug does not occur.
A smarter search program can be built if you use these tools properly :D. Hope you enjoyed the read!