nearest.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #!/usr/bin/env python
  2. #
  3. # Copyright 2016 Google Inc. All Rights Reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Simple tool for inspecting nearest neighbors and analogies."""
  17. import re
  18. import sys
  19. from getopt import GetoptError, getopt
  20. from vecs import Vecs
  21. try:
  22. opts, args = getopt(sys.argv[1:], 'v:e:', ['vocab=', 'embeddings='])
  23. except GetoptError, e:
  24. print >> sys.stderr, e
  25. sys.exit(2)
  26. opt_vocab = 'vocab.txt'
  27. opt_embeddings = None
  28. for o, a in opts:
  29. if o in ('-v', '--vocab'):
  30. opt_vocab = a
  31. if o in ('-e', '--embeddings'):
  32. opt_embeddings = a
  33. vecs = Vecs(opt_vocab, opt_embeddings)
  34. while True:
  35. sys.stdout.write('query> ')
  36. sys.stdout.flush()
  37. query = sys.stdin.readline().strip()
  38. if not query:
  39. break
  40. parts = re.split(r'\s+', query)
  41. if len(parts) == 1:
  42. res = vecs.neighbors(parts[0])
  43. elif len(parts) == 3:
  44. vs = [vecs.lookup(w) for w in parts]
  45. if any(v is None for v in vs):
  46. print 'not in vocabulary: %s' % (
  47. ', '.join(tok for tok, v in zip(parts, vs) if v is None))
  48. continue
  49. res = vecs.neighbors(vs[2] - vs[0] + vs[1])
  50. else:
  51. print 'use a single word to query neighbors, or three words for analogy'
  52. continue
  53. if not res:
  54. continue
  55. for word, sim in res[:20]:
  56. print '%0.4f: %s' % (sim, word)
  57. print