{ "cells": [ { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser\n", "from langchain.prompts import StringPromptTemplate\n", "from langchain import OpenAI, SerpAPIWrapper, LLMChain\n", "from typing import List, Union\n", "from langchain.schema import AgentAction, AgentFinish\n", "import re" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Define which tools the agent can use to answer user queries\n", "search = SerpAPIWrapper()\n", "search_tool = Tool(\n", " name = \"Search\",\n", " func=search.run,\n", " description=\"useful for when you need to answer questions about current events\"\n", " )\n", "def fake_func(inp: str) -> str:\n", " return \"foo\"\n", "fake_tools = [\n", " Tool(\n", " name=f\"foo-{i}\", \n", " func=fake_func, \n", " description=f\"a silly function that you can use to get more information about the number {i}\"\n", " ) \n", " for i in range(99)\n", "]\n", "ALL_TOOLS = [search_tool] + fake_tools" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from langchain.vectorstores import FAISS\n", "from langchain.embeddings import OpenAIEmbeddings\n", "from langchain.schema import Document" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "docs = [Document(page_content=t.description, metadata={\"index\": i}) for i, t in enumerate(ALL_TOOLS)]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "vector_store = FAISS.from_documents(docs, OpenAIEmbeddings())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "retriever = vector_store.as_retriever()\n", "\n", "def get_tools(query):\n", " docs = retriever.get_relevant_documents(query)\n", " return [ALL_TOOLS[d.metadata[\"index\"]] for d in docs]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Tool(name='Search', description='useful for when you need to answer questions about current events', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, params={'engine': 'google', 'google_domain': 'google.com', 'gl': 'us', 'hl': 'en'}, serpapi_api_key='5bbefda35f0c52c3696d6dac8640d164fe0288d2def45b18f13ed2e364dc82a3', aiosession=None)>, coroutine=None),\n", " Tool(name='foo-95', description='a silly function that you can use to get more information about the number 95', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None),\n", " Tool(name='foo-12', description='a silly function that you can use to get more information about the number 12', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None),\n", " Tool(name='foo-85', description='a silly function that you can use to get more information about the number 85', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_tools(\"whats the weather?\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Tool(name='foo-13', description='a silly function that you can use to get more information about the number 13', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None),\n", " Tool(name='foo-12', description='a silly function that you can use to get more information about the number 12', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None),\n", " Tool(name='foo-14', description='a silly function that you can use to get more information about the number 14', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None),\n", " Tool(name='foo-11', description='a silly function that you can use to get more information about the number 11', args_schema=None, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_tools(\"whats the number 13?\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Set up the base template\n", "template = \"\"\"Answer the following questions as best you can. You have access to the following tools:\n", "\n", "{tools}\n", "\n", "Use the following format:\n", "\n", "Question: the input question you must answer\n", "Thought: you should always think about what to do\n", "Action: the action to take, should be one of [{tool_names}]\n", "Action Input: the input to the action\n", "Observation: the result of the action\n", "... (this Thought/Action/Action Input/Observation can repeat N times)\n", "Thought: I now know the final answer\n", "Final Answer: the final answer to the original input question\n", "\n", "Begin!\n", "\n", "Question: {input}\n", "{agent_scratchpad}\"\"\"" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from typing import Callable\n", "# Set up a prompt template\n", "class CustomPromptTemplate(StringPromptTemplate):\n", " # The template to use\n", " template: str\n", " ############## NEW ######################\n", " # The list of tools available\n", " tools_getter: Callable\n", " \n", " def format(self, **kwargs) -> str:\n", " # Get the intermediate steps (AgentAction, Observation tuples)\n", " # Format them in a particular way\n", " intermediate_steps = kwargs.pop(\"intermediate_steps\")\n", " thoughts = \"\"\n", " for action, observation in intermediate_steps:\n", " thoughts += action.log\n", " thoughts += f\"\\nObservation: {observation}\\nThought: \"\n", " # Set the agent_scratchpad variable to that value\n", " kwargs[\"agent_scratchpad\"] = thoughts\n", " ############## NEW ######################\n", " tools = self.tools_getter(kwargs[\"input\"])\n", " # Create a tools variable from the list of tools provided\n", " kwargs[\"tools\"] = \"\\n\".join([f\"{tool.name}: {tool.description}\" for tool in tools])\n", " # Create a list of tool names for the tools provided\n", " kwargs[\"tool_names\"] = \", \".join([tool.name for tool in tools])\n", " return self.template.format(**kwargs)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "prompt = CustomPromptTemplate(\n", " template=template,\n", " tools_getter=get_tools,\n", " # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n", " # This includes the `intermediate_steps` variable because that is needed\n", " input_variables=[\"input\", \"intermediate_steps\"]\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "class CustomOutputParser(AgentOutputParser):\n", " \n", " def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n", " # Check if agent should finish\n", " if \"Final Answer:\" in llm_output:\n", " return AgentFinish(\n", " # Return values is generally always a dictionary with a single `output` key\n", " # It is not recommended to try anything else at the moment :)\n", " return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n", " log=llm_output,\n", " )\n", " # Parse out the action and action input\n", " regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n", " match = re.search(regex, llm_output, re.DOTALL)\n", " if not match:\n", " raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n", " action = match.group(1).strip()\n", " action_input = match.group(2)\n", " # Return the action and action input\n", " return AgentAction(tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "output_parser = CustomOutputParser()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "llm = OpenAI(temperature=0)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# LLM chain consisting of the LLM and a prompt\n", "llm_chain = LLMChain(llm=llm, prompt=prompt)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "tools = get_tools(\"whats the weather?\")\n", "tool_names = [tool.name for tool in tools]\n", "agent = LLMSingleActionAgent(\n", " llm_chain=llm_chain, \n", " output_parser=output_parser,\n", " stop=[\"\\nObservation:\"], \n", " allowed_tools=tool_names\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3mThought: I need to find out the current weather in SF\n", "Action: Search\n", "Action Input: \"weather in SF\"\u001b[0m\n", "\n", "Observation:\u001b[36;1m\u001b[1;3m10 Day Weather-San Francisco, CA ; Thu 04 · 57°. 16%. WSW 17 mph. Overcast. Slight chance of a rain shower. High 57F. Winds WSW at 10 to 20 mph. ; Fri 05 · 59°. 24%.\u001b[0m\n", "\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: The current weather in SF is 57° with a 16% chance of rain and winds WSW at 10 to 20 mph.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ "'The current weather in SF is 57° with a 16% chance of rain and winds WSW at 10 to 20 mph.'" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent_executor.run(\"What's the weather in SF?\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "agent_executor.run(\"What's the number 13?\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "openai", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }