TL;DR
Following Jindřich's answer, I implemented a context-aware nearest neighbor searcher. The full code is available in my GitHub gist.
It requires a BERT-like model (I use bert-embedding) and a corpus of sentences (I took a small one from here). It processes each sentence and stores the contextual token embeddings in an efficiently searchable data structure (I use KDTree, but feel free to choose FAISS, HNSW, or whatever).
Examples
The searcher is constructed as follows:
# preparing the model
storage = ContextNeighborStorage(sentences=all_sentences, model=bert)
storage.process_sentences()
storage.build_search_index()
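Here, bert can be any callable that maps a list of sentences to per-sentence (tokens, embeddings) pairs. A minimal sketch of obtaining one with the bert-embedding package (assuming its default usage) might look like this:

# pip install bert-embedding
from bert_embedding import BertEmbedding

# a BertEmbedding instance is callable: bert(sentences) yields one
# (tokens, token_embeddings) pair per input sentence, which is exactly
# the interface ContextNeighborStorage expects
bert = BertEmbedding()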
Then it can be queried for the contextually most similar words, like this:
# querying the model
distances, neighbors, contexts = storage.query(
    query_sent='It is a power bank.', query_word='bank', k=5)
In this example, the nearest neighbor would be the word "bank" in the sentence "Finally, there’s a second version of the Duo that incorporates a 2000mAH power bank, the Flip Power World.".
If, however, we look for the same word in another context, like this:
distances, neighbors, contexts = storage.query(
    query_sent='It is an investment bank.', query_word='bank', k=5)
then the nearest neighbor will be in the sentence "The bank also was awarded a 5-star, Superior Bauer rating for Dec. 31, 2017, financial data."
If we don't want to retrieve the word "bank" or its derivatives, we can filter them out:
distances, neighbors, contexts = storage.query(
    query_sent='It is an investment bank.', query_word='bank', k=5, filter_same_word=True)
and then the nearest neighbor will be the word "finance" in the sentence "Cahal is Vice Chairman of Deloitte UK and Chairman of the Advisory Corporate Finance business from 2014 (previously led the business from 2005).".
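The query returns three parallel lists, so the results can be inspected with a plain loop, e.g.:

# print each neighbor with its distance and the sentence it occurred in
for dist, word, context in zip(distances, neighbors, contexts):
    print('{:.3f}  {:<10}  {}'.format(dist, word, context))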
Application in NER
One of the cool applications of this approach is interpretable named entity recognition. We can fill the search index with IOB-labeled examples and then use the retrieved examples to infer the right label for the query word; a sketch of this inference step follows the examples below.
For example, the nearest neighbor of "Amazon" in "Bezos announced that its two-day delivery service, Amazon Prime, had surpassed 100 million subscribers worldwide." is found in "Expanded third-party integration including Amazon Alexa, Google Assistant, and IFTTT.".
But for "Amazon" in "The Atlantic has sufficient wave and tidal energy to carry most of the Amazon's sediments out to sea, thus the river does not form a true delta", the nearest neighbor is in "And, this year our stories are the work of traveling from Brazil’s Iguassu Falls to a chicken farm in Atlanta".
So if these neighbors were labeled, we could infer that in the first context "Amazon" is an ORGanization, but in the second one it is a LOCation.
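A minimal sketch of that inference step, assuming a hypothetical get_label lookup that returns the gold IOB tag of a word in a labeled sentence (it is not part of the class below), is a majority vote over the retrieved neighbors:

from collections import Counter

def infer_label(storage, get_label, query_sent, query_word, k=10):
    # get_label(sentence, word) -> IOB tag is a hypothetical lookup into
    # the labeled corpus; it is not provided by ContextNeighborStorage
    _, neighbors, contexts = storage.query(query_sent, query_word, k=k)
    votes = Counter(get_label(ctx, word) for word, ctx in zip(neighbors, contexts))
    # the label shared by most of the k nearest labeled neighbors wins
    return votes.most_common(1)[0][0]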
The code
Here is the class that does this work:
import numpy as np
from sklearn.neighbors import KDTree
from tqdm.auto import tqdm


class ContextNeighborStorage:
    def __init__(self, sentences, model):
        self.sentences = sentences
        self.model = model

    def process_sentences(self):
        """Embed every token of every sentence and remember where it came from."""
        result = self.model(self.sentences)

        self.sentence_ids = []
        self.token_ids = []
        self.all_tokens = []
        all_embeddings = []
        for i, (toks, embs) in enumerate(tqdm(result)):
            for j, (tok, emb) in enumerate(zip(toks, embs)):
                self.sentence_ids.append(i)
                self.token_ids.append(j)
                self.all_tokens.append(tok)
                all_embeddings.append(emb)
        all_embeddings = np.stack(all_embeddings)
        # we normalize embeddings, so that Euclidean distance is equivalent to cosine distance
        self.normed_embeddings = (all_embeddings.T / (all_embeddings**2).sum(axis=1) ** 0.5).T

    def build_search_index(self):
        # this takes some time
        self.indexer = KDTree(self.normed_embeddings)

    def query(self, query_sent, query_word, k=10, filter_same_word=False):
        """Return distances, neighbor tokens, and their sentences for the query word."""
        toks, embs = self.model([query_sent])[0]

        # locate the embedding of the query word within the query sentence
        found = False
        for tok, emb in zip(toks, embs):
            if tok == query_word:
                found = True
                break
        if not found:
            raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))
        emb = emb / sum(emb**2)**0.5

        # over-fetch neighbors if some of them will be filtered out below
        if filter_same_word:
            initial_k = max(k, 100)
        else:
            initial_k = k
        di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
        distances = []
        neighbors = []
        contexts = []
        for i, index in enumerate(idx.ravel()):
            token = self.all_tokens[index]
            # optionally skip neighbors that are the query word itself or its derivatives
            if filter_same_word and (query_word in token or token in query_word):
                continue
            distances.append(di.ravel()[i])
            neighbors.append(token)
            contexts.append(self.sentences[self.sentence_ids[index]])
            if len(distances) == k:
                break
        return distances, neighbors, contexts
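If the corpus outgrows KDTree, the index can be swapped out as mentioned above. A minimal sketch with FAISS (assuming the faiss-cpu package; since the embeddings are unit-normalized, inner product ranks neighbors exactly like cosine similarity):

import faiss  # pip install faiss-cpu
import numpy as np

def build_faiss_index(normed_embeddings):
    # exact inner-product index; on unit vectors this is cosine similarity
    index = faiss.IndexFlatIP(normed_embeddings.shape[1])
    index.add(normed_embeddings.astype(np.float32))
    return index

# usage, mirroring KDTree.query (note: FAISS returns similarities, not distances):
# sims, idx = index.search(query_emb.reshape(1, -1).astype(np.float32), k)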