This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit e1910c5

Gary Lai authored and eric-haibin-lin committed
[FEATURE] offer load_w2v_binary method to load w2v binary file (#620)
* ✨ (w2v) offer load_w2v_binary method to load w2v binary file
* 🐛 (w2v) expand user path
* 🚧 (w2v) add preload flag; it decides whether to load the default model or not
* 🐛 (w2v) preserve 0 for unknown token
* ✅ (w2v) add test for load_w2v_binary
* ♻️ (w2v) if source is None, construct an empty Word2Vec class object
* ♻️ (w2v) add classmethod from_binary as a constructor
* 🎨 (w2v) rename from_binary -> from_w2v_binary
* ♻️ (w2v) check source in w2v constructor
* ♻️ (w2v) use file name extension to identify binary file
* 🎨 (w2v) pylint: fix unused argument 'encoding' (unused-argument)
* 📝 (w2v) update doc
* 📝 (w2v) update description in w2v constructor
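For context, a minimal usage sketch of the API this commit introduces. The .bin paths below are hypothetical local files used for illustration; only the source names registered in WORD2VEC_NPZ_SHA1 ship with the library.

```python
import gluonnlp as nlp

# Unchanged behavior: load a registered pre-trained source by name.
w2v = nlp.embedding.Word2Vec(source='GoogleNews-vectors-negative300')

# New in this commit: a source ending in '.bin' is treated as a local
# word2vec binary file and parsed by _load_w2v_binary.
w2v_local = nlp.embedding.Word2Vec(source='~/embeddings/my_vectors.bin')

# Equivalent classmethod constructor added by this commit.
w2v_local = nlp.embedding.Word2Vec.from_w2v_binary('~/embeddings/my_vectors.bin')

# TokenEmbedding lookup works as usual afterwards.
vector = w2v_local['king']
```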
1 parent: 6ec0c84

4 files changed: 554 additions and 6 deletions


src/gluonnlp/embedding/token_embedding.py

Lines changed: 102 additions & 6 deletions
@@ -448,7 +448,6 @@ def allow_extend(self):
         """
         return self._allow_extend
 
-
     @property
     def unknown_lookup(self):
         """Vector lookup for unknown tokens.
@@ -997,6 +996,8 @@ class Word2Vec(TokenEmbedding):
     ----------
     source : str, default 'GoogleNews-vectors-negative300'
         The name of the pre-trained token embedding file.
+        A binary pre-trained file outside the source list can be used with this constructor
+        by passing a path that ends with the .bin file extension.
     embedding_root : str, default '$MXNET_HOME/embedding'
         The root directory for storing embedding-related files.
         MXNET_HOME defaults to '~/.mxnet'.
@@ -1019,10 +1020,105 @@ class Word2Vec(TokenEmbedding):
     source_file_hash = C.WORD2VEC_NPZ_SHA1
 
     def __init__(self, source='GoogleNews-vectors-negative300',
-                 embedding_root=os.path.join(get_home_dir(), 'embedding'), **kwargs):
-        self._check_source(self.source_file_hash, source)
-
+                 embedding_root=os.path.join(get_home_dir(), 'embedding'), encoding='utf8',
+                 **kwargs):
         super(Word2Vec, self).__init__(**kwargs)
-        pretrained_file_path = self._get_file_path(self.source_file_hash, embedding_root, source)
+        if source.endswith('.bin'):
+            pretrained_file_path = os.path.expanduser(source)
+            self._load_w2v_binary(pretrained_file_path, encoding=encoding)
+        else:
+            self._check_source(self.source_file_hash, source)
+            pretrained_file_path = self._get_file_path(self.source_file_hash,
+                                                       embedding_root, source)
+            self._load_embedding(pretrained_file_path, elem_delim=' ')
 
-        self._load_embedding(pretrained_file_path, elem_delim=' ')
+    def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'):
+        """Load embedding vectors from a binary pre-trained token embedding file.
+
+        Parameters
+        ----------
+        pretrained_file_path: str
+            The path to a binary pre-trained token embedding file ending with the .bin
+            file extension.
+        encoding: str
+            The encoding type of the file.
+        """
+        self._idx_to_token = [self.unknown_token] if self.unknown_token else []
+        if self.unknown_token:
+            self._token_to_idx = DefaultLookupDict(C.UNK_IDX)
+        else:
+            self._token_to_idx = {}
+        self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token))
+        self._idx_to_vec = None
+        all_elems = []
+        tokens = set()
+        loaded_unknown_vec = None
+        pretrained_file_path = os.path.expanduser(pretrained_file_path)
+        with io.open(pretrained_file_path, 'rb') as f:
+            header = f.readline().decode(encoding=encoding)
+            vocab_size, vec_len = (int(x) for x in header.split())
+            if self.unknown_token:
+                # Reserve a vector slot for the unknown token at the very beginning
+                # because the unknown token index is 0.
+                all_elems.extend([0] * vec_len)
+            binary_len = np.dtype(np.float32).itemsize * vec_len
+            for line_num in range(vocab_size):
+                token = []
+                while True:
+                    ch = f.read(1)
+                    if ch == b' ':
+                        break
+                    if ch == b'':
+                        raise EOFError('unexpected end of input; is count incorrect or file '
+                                       'otherwise damaged?')
+                    if ch != b'\n':  # ignore newlines in front of words (some binary files have them)
+                        token.append(ch)
+                try:
+                    token = b''.join(token).decode(encoding=encoding)
+                except ValueError:
+                    warnings.warn('line {} in {}: failed to decode. Skipping.'
+                                  .format(line_num, pretrained_file_path))
+                    continue
+                elems = np.fromstring(f.read(binary_len), dtype=np.float32)
+
+                assert len(elems) > 1, 'line {} in {}: unexpected data format.'.format(
+                    line_num, pretrained_file_path)
+
+                if token == self.unknown_token and loaded_unknown_vec is None:
+                    loaded_unknown_vec = elems
+                    tokens.add(self.unknown_token)
+                elif token in tokens:
+                    warnings.warn('line {} in {}: duplicate embedding found for '
+                                  'token "{}". Skipped.'.format(line_num, pretrained_file_path,
+                                                                token))
+                else:
+                    assert len(elems) == vec_len, \
+                        'line {} in {}: found vector of inconsistent dimension for token ' \
+                        '"{}". expected dim: {}, found: {}'.format(line_num,
+                                                                   pretrained_file_path,
+                                                                   token, vec_len, len(elems))
+                    all_elems.extend(elems)
+                    self._idx_to_token.append(token)
+                    self._token_to_idx[token] = len(self._idx_to_token) - 1
+                    tokens.add(token)
+        self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len))
+
+        if self.unknown_token:
+            if loaded_unknown_vec is None:
+                self._idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
+            else:
+                self._idx_to_vec[C.UNK_IDX] = nd.array(loaded_unknown_vec)
+
+    @classmethod
+    def from_w2v_binary(cls, pretrained_file_path, encoding='utf8'):
+        """Load embedding vectors from a binary pre-trained token embedding file.
+
+        Parameters
+        ----------
+        pretrained_file_path: str
+            The path to a binary pre-trained token embedding file ending with the .bin
+            file extension.
+        encoding: str
+            The encoding type of the file.
+        """
+        return cls(source=pretrained_file_path, encoding=encoding)
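The loader above parses the classic word2vec binary layout: an ASCII header line "vocab_size vec_len", then, for each token, the token's bytes terminated by a single space followed by vec_len raw float32 values. As a sketch of that layout (the file name and vectors below are made up, and vec_len must be at least 2 to satisfy the loader's sanity assert), the following writes a tiny file that Word2Vec.from_w2v_binary should read back:

```python
import numpy as np

# Hypothetical two-token vocabulary with 3-dimensional float32 vectors.
vocab = [('hello', np.array([0.1, 0.2, 0.3], dtype=np.float32)),
         ('world', np.array([0.4, 0.5, 0.6], dtype=np.float32))]

with open('toy_w2v.bin', 'wb') as f:
    # Header line: "<vocab_size> <vec_len>\n" in plain text.
    f.write('{} {}\n'.format(len(vocab), 3).encode('utf8'))
    for token, vec in vocab:
        # Token bytes, one space separator, then the raw float32 payload.
        f.write(token.encode('utf8') + b' ')
        f.write(vec.tobytes())
```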

tests/unittest/train/test_embedding.py

Lines changed: 16 additions & 0 deletions
@@ -89,3 +89,19 @@ def test_fasttext_embedding_load_binary_compare_vec():
         np.isclose(a=token_embedding_vec.idx_to_vec.asnumpy(),
                    b=idx_to_vec.asnumpy(), atol=0.001))
     assert all(token in model for token in token_embedding_vec.idx_to_token)
+
+
+def test_word2vec_embedding_load_binary_format():
+    test_dir = os.path.dirname(os.path.realpath(__file__))
+    word2vec_vec = nlp.embedding.Word2Vec.from_file(
+        os.path.join(str(test_dir), 'test_embedding', 'lorem_ipsum_w2v.vec'),
+        elem_delim=' '
+    )
+    word2vec_bin = nlp.embedding.Word2Vec.from_w2v_binary(
+        os.path.join(str(test_dir), 'test_embedding', 'lorem_ipsum_w2v.bin')
+    )
+    idx_to_vec = word2vec_bin[word2vec_vec.idx_to_token]
+    assert np.all(
+        np.isclose(a=word2vec_vec.idx_to_vec.asnumpy(),
+                   b=idx_to_vec.asnumpy(), atol=0.001))
+    assert all(token in word2vec_bin for token in word2vec_vec.idx_to_token)
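The test compares the same vectors loaded from a text (.vec) and a binary (.bin) fixture. As an assumption about how such a fixture pair could be produced (not necessarily how this one was), gensim can emit both formats from one model; a sketch assuming gensim 4.x and a made-up corpus:

```python
from gensim.models import Word2Vec

# Toy corpus standing in for the lorem-ipsum text behind the fixture.
sentences = [['lorem', 'ipsum', 'dolor', 'sit', 'amet'],
             ['consectetur', 'adipiscing', 'elit']]

model = Word2Vec(sentences, vector_size=5, min_count=1)

# The same vectors in word2vec text and binary formats.
model.wv.save_word2vec_format('lorem_ipsum_w2v.vec', binary=False)
model.wv.save_word2vec_format('lorem_ipsum_w2v.bin', binary=True)
```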
Binary file not shown (173 KB).

0 commit comments