icichat.py
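"""ici-bot: a command-line chatbot built on a LangChain ReAct-style agent.

The agent has two tools -- an internet search backed by SerpAPI and the custom
HALightControlTool imported from tools.ha -- and selects which tools to expose
to the model for each query via a FAISS vector store built over the tool
descriptions.
"""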

import os
import re
from typing import Callable, List, Union

from dotenv import dotenv_values, load_dotenv
from langchain import LLMChain, OpenAI, SerpAPIWrapper
from langchain.agents import (
    AgentExecutor,
    AgentOutputParser,
    LLMSingleActionAgent,
    Tool,
    initialize_agent,
)
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import StringPromptTemplate
from langchain.schema import AgentAction, AgentFinish, Document
from langchain.vectorstores import FAISS

from tools.ha import HALightControlTool
from utils.prompt import Prompt
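
# Load API keys such as OPENAI_API_KEY and SERPAPI_API_KEY from a local .env
# file, then set up the CLI prompt helper and the SerpAPI search wrapper.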
load_dotenv()
cli_prompt = Prompt('ici-bot')
search = SerpAPIWrapper()
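
# All tools the agent may be given: internet search via SerpAPI and the
# custom HALightControlTool from tools.ha.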
ALL_TOOLS = [
    Tool(
        name="Search",
        func=search.run,
        description=(
            "Useful for when you need to answer questions about current events "
            "on the internet. Use it only when explicitly asked to search on "
            "the internet."
        ),
    ),
    HALightControlTool(),
]
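
# ReAct-style prompt. The model iterates Thought / Action / Action Input /
# Observation until it produces a Final Answer; {tools} and {tool_names} are
# filled in per query by CustomPromptTemplate below, and {agent_scratchpad}
# carries the steps taken so far. There is no {chat_history} slot, so the
# conversation memory configured later never reaches the model through this
# template.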
template = """Answer the following questions as best you can. You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: {input}
{agent_scratchpad}"""
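
# Prompt template that assembles the final prompt for each model call: it
# rebuilds the agent scratchpad from the (action, observation) pairs taken so
# far and lists only the tools returned by `tools_getter` for this input.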
class CustomPromptTemplate(StringPromptTemplate):
    template: str
    tools_getter: Callable

    def format(self, **kwargs) -> str:
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        kwargs["agent_scratchpad"] = thoughts
        tools = self.tools_getter(kwargs["input"])
        kwargs["tools"] = "\n".join(
            f"{tool.name}: {tool.description}" for tool in tools
        )
        kwargs["tool_names"] = ", ".join(tool.name for tool in tools)
        return self.template.format(**kwargs)
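
# Output parser for the agent's raw completions: a "Final Answer:" marker ends
# the run, otherwise the Action / Action Input pair is extracted with a regex
# and returned as the next tool call.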
class CustomOutputParser(AgentOutputParser):
    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        if "Final Answer:" in llm_output:
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )
        regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
        match = re.search(regex, llm_output, re.DOTALL)
        if not match:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1).strip()
        action_input = match.group(2)
        return AgentAction(
            tool=action,
            tool_input=action_input.strip(" ").strip('"'),
            log=llm_output,
        )
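
# Chatbot wrapper: builds the LLM chain and the tool retriever once, then for
# each query constructs a single-action agent and an AgentExecutor that runs
# the tool loop.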
class ChatBot(object):
    def __init__(self,
                 tools: List[Tool],
                 model_name: str = "gpt-3.5-turbo",
                 model_temperature: float = 0,
                 model_max_tokens: int = 500,
                 verbose: bool = False):
        # System message for the chat model; only used by the commented-out
        # initialize_agent() variant in _get_agent below.
        self.sys_msg = """Assistant is a large language model trained by OpenAI.
Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics.
Overall, Assistant is a powerful system that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist.
"""
        self.tools = tools
        self.verbose = verbose
        # Index the tool descriptions in a FAISS vector store so the most
        # relevant tools can be retrieved for each query.
        docs = [
            Document(page_content=t.description, metadata={"index": i})
            for i, t in enumerate(tools)
        ]
        vector_store = FAISS.from_documents(docs, OpenAIEmbeddings())
        self.retriever = vector_store.as_retriever()
        self.prompt = CustomPromptTemplate(
            template=template,
            tools_getter=self._get_tools,
            input_variables=["input", "intermediate_steps"],
        )
        self.output_parser = CustomOutputParser()
        self.llm = ChatOpenAI(
            temperature=model_temperature,
            max_tokens=model_max_tokens,
            model_name=model_name,
        )
        self.llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
        # Conversation memory is attached to the AgentExecutor in run();
        # LLMSingleActionAgent itself does not take a memory argument.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=10,
            return_messages=True,
        )
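
    # Retrieve the tools whose descriptions best match the query, using the
    # FAISS index built in __init__.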
    def _get_tools(self, query: str):
        docs = self.retriever.get_relevant_documents(query)
        return [self.tools[d.metadata["index"]] for d in docs]
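
    # Build a single-action ReAct agent whose allowed tools are the ones
    # retrieved for this particular input.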
    def _get_agent(self, query: str):
        tools = self._get_tools(query)
        tool_names = [tool.name for tool in tools]
        # agent = initialize_agent(
        #     agent='chat-conversational-react-description',
        #     system_message=self.sys_msg,
        #     tools=self.tools,
        #     llm=self.llm,
        #     verbose=True,
        #     max_iterations=3,
        #     early_stopping_method='generate',
        #     memory=self.memory,
        #     output_parser=self.output_parser,
        #     stop=["\nObservation:"],
        #     allowed_tools=tool_names,
        #     llm_chain=self.llm_chain,
        # )
        # Memory and max_iterations are handled by the AgentExecutor in run(),
        # not by the agent itself.
        agent = LLMSingleActionAgent(
            llm_chain=self.llm_chain,
            output_parser=self.output_parser,
            stop=["\nObservation:"],
            allowed_tools=tool_names,
        )
        return agent
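
    # Run one user query end to end: build the agent, wrap it in an
    # AgentExecutor together with all tools and the conversation memory, and
    # execute the Thought/Action/Observation loop.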
    def run(self, query: str):
        agent_executor = AgentExecutor.from_agent_and_tools(
            agent=self._get_agent(query),
            tools=self.tools,
            memory=self.memory,
            max_iterations=3,
            verbose=self.verbose,
        )
        return agent_executor.run(query)
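
# Simple REPL: read a line from the CLI prompt, run it through the bot, and
# print any error instead of crashing; type 'exit' to quit.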
if __name__ == '__main__':
    bot = ChatBot(ALL_TOOLS, verbose=True)
    while True:
        user_input = cli_prompt.get().strip()
        if user_input == 'exit':
            break
        try:
            bot.run(user_input)
        except Exception as e:
            print("\n\nSystem:>{}\n".format(e))