embedding.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import torch
  2. from transformers import AutoTokenizer, AutoModel
  3. from llama_index.core.base.embeddings.base import BaseEmbedding
  4. device = "cuda" if torch.cuda.is_available() else "cpu"
  5. # Load tokenizer and model
  6. model_id = "jinaai/jina-embeddings-v2-base-en" #"jinaai/jina-embeddings-v3"
  7. tokenizer = AutoTokenizer.from_pretrained(model_id)
  8. model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device)
  9. # Define function to generate embeddings
  10. def get_embedding(text):
  11. inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
  12. with torch.no_grad():
  13. outputs = model(**inputs)
  14. return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy() #.to(torch.float32)
  15. class LocalJinaEmbedding(BaseEmbedding):
  16. def __init__(self):
  17. super().__init__()
  18. def _get_text_embedding(self, text):
  19. return get_embedding(text).tolist() # Ensure compatibility with LlamaIndex
  20. def _get_query_embedding(self, query):
  21. return get_embedding(query).tolist()
  22. async def _aget_query_embedding(self, query: str) -> list:
  23. return get_embedding(query).tolist()
  24. def test(): #this did not produce reasonable results for some reason
  25. #!pip install llama-index-embeddings-huggingface
  26. from llama_index.embeddings.huggingface import HuggingFaceEmbedding
  27. embed_model = HuggingFaceEmbedding(model_name=model_id)
  28. if __name__=="__main__":
  29. emb = get_embedding("hi there")
  30. print(emb.shape)