@@ -448,7 +448,6 @@ def allow_extend(self):
448448 """
449449 return self ._allow_extend
450450
451-
452451 @property
453452 def unknown_lookup (self ):
454453 """Vector lookup for unknown tokens.
@@ -997,6 +996,8 @@ class Word2Vec(TokenEmbedding):
997996 ----------
998997 source : str, default 'GoogleNews-vectors-negative300'
999998 The name of the pre-trained token embedding file.
999+ A binary pre-trained file outside of the source list can be used with this
1000+ constructor by passing a path that ends with the .bin file extension.
10001001 embedding_root : str, default '$MXNET_HOME/embedding'
10011002 The root directory for storing embedding-related files.
10021003 MXNET_HOME defaults to '~/.mxnet'.
@@ -1019,10 +1020,105 @@ class Word2Vec(TokenEmbedding):
10191020 source_file_hash = C .WORD2VEC_NPZ_SHA1
10201021
10211022 def __init__ (self , source = 'GoogleNews-vectors-negative300' ,
1022- embedding_root = os .path .join (get_home_dir (), 'embedding' ), ** kwargs ):
1023- self ._check_source (self .source_file_hash , source )
1024-
1023+ embedding_root = os .path .join (get_home_dir (), 'embedding' ), encoding = 'utf8' ,
1024+ ** kwargs ):
10251025 super (Word2Vec , self ).__init__ (** kwargs )
1026- pretrained_file_path = self ._get_file_path (self .source_file_hash , embedding_root , source )
1026+ if source .endswith ('.bin' ):
1027+ pretrained_file_path = os .path .expanduser (source )
1028+ self ._load_w2v_binary (pretrained_file_path , encoding = encoding )
1029+ else :
1030+ self ._check_source (self .source_file_hash , source )
1031+ pretrained_file_path = self ._get_file_path (self .source_file_hash ,
1032+ embedding_root , source )
1033+ self ._load_embedding (pretrained_file_path , elem_delim = ' ' )
10271034
1028- self ._load_embedding (pretrained_file_path , elem_delim = ' ' )
1035+ def _load_w2v_binary (self , pretrained_file_path , encoding = 'utf8' ):
1036+ """Load embedding vectors from a binary pre-trained token embedding file.
1037+
1038+ Parameters
1039+ ----------
1040+ pretrained_file_path: str
1041+ The path to a binary pre-trained token embedding file end with .bin as file extension
1042+ name.
1043+ encoding: str
1044+ The encoding type of the file.
1045+ """
1046+ self ._idx_to_token = [self .unknown_token ] if self .unknown_token else []
1047+ if self .unknown_token :
1048+ self ._token_to_idx = DefaultLookupDict (C .UNK_IDX )
1049+ else :
1050+ self ._token_to_idx = {}
1051+ self ._token_to_idx .update ((token , idx ) for idx , token in enumerate (self ._idx_to_token ))
1052+ self ._idx_to_vec = None
1053+ all_elems = []
1054+ tokens = set ()
1055+ loaded_unknown_vec = None
1056+ pretrained_file_path = os .path .expanduser (pretrained_file_path )
1057+ with io .open (pretrained_file_path , 'rb' ) as f :
1058+ header = f .readline ().decode (encoding = encoding )
1059+ vocab_size , vec_len = (int (x ) for x in header .split ())
1060+ if self .unknown_token :
1061+ # Reserve a vector slot for the unknown token at the very beggining
1062+ # because the unknown token index is 0.
1063+ all_elems .extend ([0 ] * vec_len )
1064+ binary_len = np .dtype (np .float32 ).itemsize * vec_len
1065+ for line_num in range (vocab_size ):
1066+ token = []
1067+ while True :
1068+ ch = f .read (1 )
1069+ if ch == b' ' :
1070+ break
1071+ if ch == b'' :
1072+ raise EOFError ('unexpected end of input; is count incorrect or file '
1073+ 'otherwise damaged?' )
1074+ if ch != b'\n ' : # ignore newlines in front of words (some binary files have)
1075+ token .append (ch )
1076+ try :
1077+ token = b'' .join (token ).decode (encoding = encoding )
1078+ except ValueError :
1079+ warnings .warn ('line {} in {}: failed to decode. Skipping.'
1080+ .format (line_num , pretrained_file_path ))
1081+ continue
1082+ elems = np .fromstring (f .read (binary_len ), dtype = np .float32 )
1083+
1084+ assert len (elems ) > 1 , 'line {} in {}: unexpected data format.' .format (
1085+ line_num , pretrained_file_path )
1086+
1087+ if token == self .unknown_token and loaded_unknown_vec is None :
1088+ loaded_unknown_vec = elems
1089+ tokens .add (self .unknown_token )
1090+ elif token in tokens :
1091+ warnings .warn ('line {} in {}: duplicate embedding found for '
1092+ 'token "{}". Skipped.' .format (line_num , pretrained_file_path ,
1093+ token ))
1094+ else :
1095+ assert len (elems ) == vec_len , \
1096+ 'line {} in {}: found vector of inconsistent dimension for token ' \
1097+ '"{}". expected dim: {}, found: {}' .format (line_num ,
1098+ pretrained_file_path ,
1099+ token , vec_len , len (elems ))
1100+ all_elems .extend (elems )
1101+ self ._idx_to_token .append (token )
1102+ self ._token_to_idx [token ] = len (self ._idx_to_token ) - 1
1103+ tokens .add (token )
1104+ self ._idx_to_vec = nd .array (all_elems ).reshape ((- 1 , vec_len ))
1105+
1106+ if self .unknown_token :
1107+ if loaded_unknown_vec is None :
1108+ self ._idx_to_vec [C .UNK_IDX ] = self ._init_unknown_vec (shape = vec_len )
1109+ else :
1110+ self ._idx_to_vec [C .UNK_IDX ] = nd .array (loaded_unknown_vec )
1111+
1112+ @classmethod
1113+ def from_w2v_binary (cls , pretrained_file_path , encoding = 'utf8' ):
1114+ """Load embedding vectors from a binary pre-trained token embedding file.
1115+
1116+ Parameters
1117+ ----------
1118+ pretrained_file_path: str
1119+ The path to a binary pre-trained token embedding file end with .bin as file extension
1120+ name.
1121+ encoding: str
1122+ The encoding type of the file.
1123+ """
1124+ return cls (source = pretrained_file_path , encoding = encoding )
0 commit comments