State of the Art Semantic Textual Similarity
With Google's newly released BERT, there are lots of model extensions available online (e.g. huggingface/transformers and UKPLab/sentence-transformers). This article focuses on plug-and-play code snippets: I will provide sample code that you can modify to your needs. But keep in mind that these models are supposed to be fine-tuned for your specific use case.
After that we will set up a facebookresearch/faiss index so the deployment can scale.
I attempted to use the default BERT model from huggingface/transformers, but there is a problem: the model does not return good sentence embeddings out of the box (read the return type of https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel for more). Fine-tuning is absolutely required, but usually there is either no great dataset, or the process requires a deep understanding and large computing resources, which I do not possess. Besides, because I am in Hong Kong, most of my tasks involve both Chinese and English and therefore need a multilingual model, so I chose bert-base-multilingual-cased from https://huggingface.co/transformers/pretrained_models.html.
Installation
pip install transformers torch hanziconv
Traditional to Simplified Chinese
In Hong Kong most applications use traditional Chinese, so I used the tokenizer to test whether the traditional and simplified forms of a character encode to the same ids:
tokenizer.encode('爱')
tokenizer.encode('愛')
> [101, 5383, 102]
> [101, 3910, 102]
The two forms encode differently. Since my dataset is a mix of traditional and simplified Chinese, I decided to use the package berniey/hanziconv for the conversion.
from hanziconv import HanziConv as t2mc
t2mc.toSimplified('愛')
> '爱'
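Putting the two together, here is a small sketch of my own (assuming the multilingual tokenizer loaded in the next section): normalize every text to simplified Chinese before tokenizing, so both forms map to the same ids.
def encode_normalized(text):
    # convert any traditional characters to simplified before tokenizing
    return tokenizer.encode(t2mc.toSimplified(text))

# both forms now produce identical token ids
assert encode_normalized('愛') == encode_normalized('爱')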
Using bert-base-multilingual-cased
Here is how to load bert-base-multilingual-cased from huggingface/transformers and encode some text:
from transformers import BertModel, BertTokenizer
import torch
bert_model_shortcut = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert_model_shortcut)
model = BertModel.from_pretrained(bert_model_shortcut)
encoded_tokens = tokenizer.encode("Here is some text to encode", add_special_tokens=True)
input_ids = torch.tensor([encoded_tokens])
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # first element of the output tuple
last_hidden_states.shape  # (batch_size, sequence_length, hidden_size)
It turns out this model was not trained with semantic textual similarity in mind.
This becomes a problem for us, because we want something that can measure similarity between texts.
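For illustration, here is a minimal sketch of my own (not part of the original code) of the naive approach: mean-pool the last hidden states into one vector per sentence and compare with cosine similarity. It runs, but the scores it produces are not very meaningful, which is exactly the problem.
def naive_sentence_embedding(text):
    # mean-pool the token embeddings from the last hidden layer
    input_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
    with torch.no_grad():
        last_hidden_states = model(input_ids)[0]
    return last_hidden_states.mean(dim=1).squeeze()

a = naive_sentence_embedding('I am a cat')
b = naive_sentence_embedding('I am a husky')
# cosine similarity of the two pooled vectors
print(torch.nn.functional.cosine_similarity(a, b, dim=0))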
Therefore after some searching I found UKPLab/sentence-transformers.
Using UKPLab/sentence-transformers’s distiluse-base-multilingual-cased
How does this STS library work?
- First, start with a model: you can choose one for transfer learning from the sentence-transformers available models.
- Using the model:
  - Apply it on a text to get an embedding (a vector): model.encode([text])[0]
  - Two texts are closely related if their embeddings are close together (the training below uses cosine similarity), which means they are semantically near-identical.
- Training procedure:
  - To call model.fit(...) we need a dataset in the format of sentence_transformers.datasets.SentencesDataset. Its constructor requires a list with one entry per data point, which you construct with sentence_transformers.readers.InputExample. Each InputExample takes three items:
    - a unique id, which you can generate yourself (for example with Python's random library)
    - an array of sentences, usually two
    - a label, which can be positive or negative: positive implies the sentences are closely related, negative implies a relationship the model will try not to replicate.
from sentence_transformers import SentenceTransformer, losses
from torch.utils.data import DataLoader
from sentence_transformers.datasets import SentencesDataset
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import math
model = SentenceTransformer('distiluse-base-multilingual-cased')
def testing(text):
    # returns a 1-D numpy array (512 dimensions for this model)
    sentence_embedding = model.encode([text])[0]
    return sentence_embedding
input_example_list = [
    InputExample(
        '1234',  # unique id
        [
            'I am a cat',
            'I am a persian cat'
        ],
        1  # positive label: closely related
    ),
    InputExample(
        '1235',  # unique id
        [
            'I am a cat',
            'I am a husky'
        ],
        -1  # negative label: not related
    ),
]
sentence_dataset = SentencesDataset(examples=input_example_list, model=model)
def train(sentence_dataset, batch_size=32, num_epochs=10):
    data_loader = DataLoader(sentence_dataset, shuffle=True, batch_size=batch_size)
    evaluator = EmbeddingSimilarityEvaluator(data_loader)
    # 10% of train data for warm-up
    warm_up_steps = math.ceil(len(sentence_dataset) * num_epochs / batch_size * 0.1)
    train_loss = losses.CosineSimilarityLoss(model=model)
    model.fit(
        train_objectives=[(data_loader, train_loss)],
        evaluator=evaluator,
        epochs=num_epochs,
        warmup_steps=warm_up_steps
    )
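A quick usage sketch (my own addition, using the toy dataset above, which is far too small for meaningful training but shows the call shape): fine-tune, then compare two sentences by the cosine similarity of their embeddings, computed here by hand with numpy.
import numpy as np

train(sentence_dataset)

embedding_a = testing('I am a cat')
embedding_b = testing('I am a persian cat')
# cosine similarity of the two embeddings
print(np.dot(embedding_a, embedding_b) / (np.linalg.norm(embedding_a) * np.linalg.norm(embedding_b)))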
Indexing the embeddings for production
Because these embeddings are large, comparing them one by one is extremely slow, and with a larger dataset of candidate items to match we will face problems. Therefore I sought out an indexer library and came across facebookresearch/faiss.
This library builds an index which you can save to a file for easy transfer. I suggest running a separate long-lived indexer program that your application queries, instead of loading the file every time you want to do a search, because reading a previously built index from disk is very slow.
import faiss
import numpy as np

DIMENSION = sentence_embedding.shape[0]  # 512 for distiluse-base-multilingual-cased
quantizer = faiss.IndexFlatL2(DIMENSION)
indexer = faiss.IndexIVFFlat(quantizer, DIMENSION, 10)  # 10 partitions

# faiss expects float32 matrices and int64 ids
data = np.array([embedding_1, embedding_2], dtype='float32')
labels = np.array([123, 124], dtype='int64')  # map these back to 'ID123'/'ID124' yourself

# first step of training (learns the partitioning from the data)
indexer.train(data)
# insert data together with the ids
indexer.add_with_ids(data, labels)
# search: the second argument is the number of results
scores, result_labels = indexer.search(np.array([embedding_3], dtype='float32'), 20)
# the results are in the form scores[][], labels[][] <- labels are 123 or 124 in our case
# to read use `indexer = faiss.read_index(FILENAME)`
# to save use `faiss.write_index(indexer, FILENAME)`
# as an example I tried to wrap this in a class so anyone can use it easily; the raw library is still a bit messy
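To make the long-lived indexer idea above concrete, here is a minimal sketch of my own: load the index once at startup, then answer many searches from memory. The file name sentences.faiss and the query helper are hypothetical; expose it behind whatever IPC you prefer (HTTP, gRPC, a message queue) so the slow read_index call happens only once.
import faiss
import numpy as np

# load the prebuilt index once at startup (slow), then serve many queries (fast)
indexer = faiss.read_index('sentences.faiss')  # hypothetical file name

def query(embedding, k=20):
    # embedding: a 1-D float32 vector from model.encode(...)
    scores, labels = indexer.search(np.array([embedding], dtype='float32'), k)
    return list(zip(labels[0], scores[0]))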
Be careful when importing faiss and pytorch: there is a bug for which the order of the imports matters. I didn't file a bug report because on most computers it does not occur.
A smarter search program can be built if you utilize these tools properly :D. Hope you enjoyed the read!
Eugene Low