vecs.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import mmap
  15. import numpy as np
  16. import os
  17. import struct
  18. class Vecs(object):
  19. def __init__(self, vocab_filename, rows_filename, cols_filename=None):
  20. """Initializes the vectors from a text vocabulary and binary data."""
  21. with open(vocab_filename, 'r') as lines:
  22. self.vocab = [line.split()[0] for line in lines]
  23. self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
  24. n = len(self.vocab)
  25. with open(rows_filename, 'r') as rows_fh:
  26. rows_fh.seek(0, os.SEEK_END)
  27. size = rows_fh.tell()
  28. # Make sure that the file size seems reasonable.
  29. if size % (4 * n) != 0:
  30. raise IOError(
  31. 'unexpected file size for binary vector file %s' % rows_filename)
  32. # Memory map the rows.
  33. dim = size / (4 * n)
  34. rows_mm = mmap.mmap(rows_fh.fileno(), 0, prot=mmap.PROT_READ)
  35. rows = np.matrix(
  36. np.frombuffer(rows_mm, dtype=np.float32).reshape(n, dim))
  37. # If column vectors were specified, then open them and add them to the row
  38. # vectors.
  39. if cols_filename:
  40. with open(cols_filename, 'r') as cols_fh:
  41. cols_mm = mmap.mmap(cols_fh.fileno(), 0, prot=mmap.PROT_READ)
  42. cols_fh.seek(0, os.SEEK_END)
  43. if cols_fh.tell() != size:
  44. raise IOError('row and column vector files have different sizes')
  45. cols = np.matrix(
  46. np.frombuffer(cols_mm, dtype=np.float32).reshape(n, dim))
  47. rows += cols
  48. cols_mm.close()
  49. # Normalize so that dot products are just cosine similarity.
  50. self.vecs = rows / np.linalg.norm(rows, axis=1).reshape(n, 1)
  51. rows_mm.close()
  52. def similarity(self, word1, word2):
  53. """Computes the similarity of two tokens."""
  54. idx1 = self.word_to_idx.get(word1)
  55. idx2 = self.word_to_idx.get(word2)
  56. if not idx1 or not idx2:
  57. return None
  58. return float(self.vecs[idx1] * self.vecs[idx2].transpose())
  59. def neighbors(self, query):
  60. """Returns the nearest neighbors to the query (a word or vector)."""
  61. if isinstance(query, basestring):
  62. idx = self.word_to_idx.get(query)
  63. if idx is None:
  64. return None
  65. query = self.vecs[idx]
  66. neighbors = self.vecs * query.transpose()
  67. return sorted(
  68. zip(self.vocab, neighbors.flat),
  69. key=lambda kv: kv[1], reverse=True)
  70. def lookup(self, word):
  71. """Returns the embedding for a token, or None if no embedding exists."""
  72. idx = self.word_to_idx.get(word)
  73. return None if idx is None else self.vecs[idx]