{ "cells": [ { "cell_type": "markdown", "id": "4a797514-61dc-4914-9d12-ce5c1a5287d9", "metadata": {}, "source": [ "# Function Calling with Llama 3 and LangChain" ] }, { "cell_type": "markdown", "id": "4718a993-d052-4289-8b79-bb89f7b99023", "metadata": {}, "source": [ "The tech world is abuzz with the release of [Meta's Llama 3](https://llama.meta.com/llama3/), and Groq is excited to serve this powerful model at industry-leading speeds! Llama 3 [excels at function calling](https://twitter.com/RickLamers/status/1781444639079145722), making it a strong choice for function calling applications. This cookbook will guide you through using Llama 3 with [Groq's LangChain integration](https://python.langchain.com/docs/integrations/chat/groq/) for real-time stock market analysis, pulling data from Yahoo Finance via the open-source [yfinance](https://pypi.org/project/yfinance/) library. We'll demonstrate how to write functions that call yfinance based on a user prompt, enabling the LLM to answer a range of user questions with relevant, real-time information on the stock market." ] }, { "cell_type": "markdown", "id": "48234f2c-e41a-4c6c-a268-351d0d944682", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "6e0fb2e5-db54-402d-a4fd-78239495f406", "metadata": {}, "outputs": [], "source": [ "from langchain_groq import ChatGroq\n", "import os\n", "import yfinance as yf\n", "import pandas as pd" ] }, { "cell_type": "markdown", "id": "d9394c9a-0a8a-4882-81bd-76ea5f657f82", "metadata": {}, "source": [ "As mentioned in the introduction, we will be using Meta's Llama 3 for function calling in this notebook, specifically the 70B model (```llama3-70b-8192``` on Groq). We are also using LangChain's ```ChatGroq``` class to define our LLM and integrate it with additional LangChain tooling. Note that you will need a Groq API key to proceed; you can create an account [here](https://console.groq.com/) to generate one for free. 
The code below reads the key from the ```GROQ_API_KEY``` environment variable."
] }, { "cell_type": "code", "execution_count": 2, "id": "c25157c7-2162-44f7-be8c-4f03d630f580", "metadata": {}, "outputs": [], "source": [ "llm = ChatGroq(groq_api_key=os.getenv('GROQ_API_KEY'), model='llama3-70b-8192')" ] }, { "cell_type": "markdown", "id": "0c91d792-cc1c-4956-9c18-3d44d1089e62", "metadata": {}, "source": [ "### Defining Tools" ] }, { "cell_type": "markdown", "id": "ecba8a3e-6590-415f-a9ea-d14ac77f9e99", "metadata": {}, "source": [ "Now we will define two [LangChain tools](https://python.langchain.com/docs/modules/tools/) that use yfinance to answer user queries. Our goal is to enable the LLM to provide accurate and timely information on any stock, just like you'd get on [Yahoo Finance](https://finance.yahoo.com/quote/META/). We'll focus on two types of information: current data (such as price, volume, and beta) and historical prices. To achieve this, we'll create two tools: ```get_stock_info``` for current information and ```get_historical_price``` for historical prices." ] }, { "cell_type": "markdown", "id": "d0ec46bf-1374-4bf8-94e9-b450e2422e2d", "metadata": {}, "source": [ "Each tool includes a detailed docstring, which the ```@tool``` decorator uses as the tool description, to help the LLM decide which tool to call and which arguments to pass. In ```get_stock_info```, we list all the keys available in ```data.info``` so that Llama 3 selects a valid key verbatim. In ```get_historical_price```, we explicitly explain the purpose of ```start_date``` and ```end_date``` and provide guidance on how to fill them in. For both functions, we've found that Llama 3 can identify the correct stock symbol from a company name without additional prompting." ] }, { "cell_type": "code", "execution_count": 3, "id": "7e5f4446-df38-4938-8ad0-1a54374c158a", "metadata": {}, "outputs": [], "source": [ "from langchain_core.tools import tool\n", "\n", "@tool\n", "def get_stock_info(symbol, key):\n", "    '''Return the correct stock info value given the appropriate symbol and key. 
Infer valid key from the user prompt; it must be one of the following:\n", "\n", " address1, city, state, zip, country, phone, website, industry, industryKey, industryDisp, sector, sectorKey, sectorDisp, longBusinessSummary, fullTimeEmployees, companyOfficers, auditRisk, boardRisk, compensationRisk, shareHolderRightsRisk, overallRisk, governanceEpochDate, compensationAsOfEpochDate, maxAge, priceHint, previousClose, open, dayLow, dayHigh, regularMarketPreviousClose, regularMarketOpen, regularMarketDayLow, regularMarketDayHigh, dividendRate, dividendYield, exDividendDate, beta, trailingPE, forwardPE, volume, regularMarketVolume, averageVolume, averageVolume10days, averageDailyVolume10Day, bid, ask, bidSize, askSize, marketCap, fiftyTwoWeekLow, fiftyTwoWeekHigh, priceToSalesTrailing12Months, fiftyDayAverage, twoHundredDayAverage, currency, enterpriseValue, profitMargins, floatShares, sharesOutstanding, sharesShort, sharesShortPriorMonth, sharesShortPreviousMonthDate, dateShortInterest, sharesPercentSharesOut, heldPercentInsiders, heldPercentInstitutions, shortRatio, shortPercentOfFloat, impliedSharesOutstanding, bookValue, priceToBook, lastFiscalYearEnd, nextFiscalYearEnd, mostRecentQuarter, earningsQuarterlyGrowth, netIncomeToCommon, trailingEps, forwardEps, pegRatio, enterpriseToRevenue, enterpriseToEbitda, 52WeekChange, SandP52WeekChange, lastDividendValue, lastDividendDate, exchange, quoteType, symbol, underlyingSymbol, shortName, longName, firstTradeDateEpochUtc, timeZoneFullName, timeZoneShortName, uuid, messageBoardId, gmtOffSetMilliseconds, currentPrice, targetHighPrice, targetLowPrice, targetMeanPrice, targetMedianPrice, recommendationMean, recommendationKey, numberOfAnalystOpinions, totalCash, totalCashPerShare, ebitda, totalDebt, quickRatio, currentRatio, totalRevenue, debtToEquity, revenuePerShare, returnOnAssets, returnOnEquity, freeCashflow, operatingCashflow, earningsGrowth, revenueGrowth, grossMargins, ebitdaMargins, operatingMargins, 
financialCurrency, trailingPegRatio\n", "    \n", "    If asked generically for 'stock price', use currentPrice.\n", "    '''\n", "    data = yf.Ticker(symbol)\n", "    stock_info = data.info  # dict of the ticker's current info fields\n", "    return stock_info[key]\n", "\n", "\n", "@tool\n", "def get_historical_price(symbol, start_date, end_date):\n", "    \"\"\"\n", "    Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.\n", "    - symbol (str): Stock ticker symbol.\n", "    - end_date (date): Typically today unless a specific end date is provided. End date MUST be later than start date.\n", "    - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months, so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if asked vaguely for historical prices. Start date must always be before the current date.\n", "    \"\"\"\n", "\n", "    data = yf.Ticker(symbol)\n", "    # Note: yfinance treats the end date as exclusive\n", "    hist = data.history(start=start_date, end=end_date)\n", "    hist = hist.reset_index()\n", "    # Store closing prices in a column named after the symbol so that several stocks can later be merged on 'Date'\n", "    hist[symbol] = hist['Close']\n", "    return hist[['Date', symbol]]\n" ] }, { "cell_type": "markdown", "id": "077ba232-8485-4172-9291-ac4d592aec32", "metadata": {}, "source": [ "### Using our Tools" ] }, { "cell_type": "markdown", "id": "1b95ae71-cae9-49fb-8a0b-66385b79d9de", "metadata": {}, "source": [ "Now we will collect our tools in a list and bind them to our LLM with ```bind_tools``` so that the model can call them:" ] }, { "cell_type": "code", "execution_count": 4, "id": "7f062ce7-e23d-42b3-b9d6-493adb0c7bf1", "metadata": {}, "outputs": [], "source": [ "tools = [get_stock_info, get_historical_price]\n", "llm_with_tools = llm.bind_tools(tools)" ] }, { "cell_type": "markdown", "id": "53d8c970-cea9-42ec-8b52-50462dd48218", "metadata": {}, "source": [ "Let's test our function calling with a few simple prompts. 
Here we only print the tool calls the LLM requests; the tools themselves are not executed yet:" ] }, { "cell_type": "code", "execution_count": 5, "id": "8b5d16d2-6b5f-489d-b0d0-f62d65aee0b0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "[{'name': 'get_stock_info', 'args': {'symbol': 'META', 'key': 'marketCap'}, 'id': 'call_3xm9'}]\n", "[{'name': 'get_stock_info', 'args': {'symbol': 'AAPL', 'key': 'volume'}, 'id': 'call_2p2z'}, {'name': 'get_stock_info', 'args': {'symbol': 'MSFT', 'key': 'volume'}, 'id': 'call_hvp4'}]\n" ] } ], "source": [ "query1 = 'What is the market cap of Meta?'\n", "query2 = 'How does the volume of Apple compare to that of Microsoft?'\n", "\n", "print(llm_with_tools.invoke(query1).tool_calls)\n", "print(llm_with_tools.invoke(query2).tool_calls)" ] }, { "cell_type": "markdown", "id": "58a44315-b9d0-435b-a13d-369cef12aa4b", "metadata": {}, "source": [ "As you can see, for our first query the LLM chose ```get_stock_info``` with the arguments **META** and **marketCap**, which are a valid stock symbol and key, respectively. For our second query, it correctly requested ```get_stock_info``` twice, once for Apple (**AAPL**) and once for Microsoft (**MSFT**)." ] }, { "cell_type": "code", "execution_count": 6, "id": "be4d79d9-6a4e-4b7e-9cca-1f3a21b5355e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'name': 'get_historical_price', 'args': {'symbol': '^GSPC', 'start_date': '2021-04-23', 'end_date': '2024-04-23'}, 'id': 'call_k06n'}]\n", "[{'name': 'get_historical_price', 'args': {'symbol': 'GOOGL', 'start_date': '2023-01-01', 'end_date': '2023-12-31'}, 'id': 'call_ca9y'}, {'name': 'get_historical_price', 'args': {'symbol': 'AMZN', 'start_date': '2023-01-01', 'end_date': '2023-12-31'}, 'id': 'call_h6q6'}]\n" ] } ], "source": [ "query1 = 'Show the historical price of the S&P 500 over the past 3 years? 
(Today is 4/23/2024)'\n", "query2 = 'Compare the price of Google and Amazon throughout 2023'\n", "\n", "print(llm_with_tools.invoke(query1).tool_calls)\n", "print(llm_with_tools.invoke(query2).tool_calls)" ] }, { "cell_type": "markdown", "id": "ead6b037-6dad-4781-bb38-8214fa831233", "metadata": {}, "source": [ "Our tool-calling LLM also correctly chose ```get_historical_price``` for the historical price questions, and requested it twice when asked to compare Google and Amazon. Note that to perform any kind of lookback analysis, you'll need to provide the current date, since the model has no way of knowing today's date on its own; in the first query above, we included it in the prompt." ] }, { "cell_type": "markdown", "id": "104f013a-0279-4ead-bcec-db7cd23eecd3", "metadata": {}, "source": [ "### Putting it all together" ] }, { "cell_type": "markdown", "id": "6e7fb035-4332-4be2-a333-91cca36e8696", "metadata": {}, "source": [ "The ```plot_price_over_time``` function below is not exposed to the LLM as a tool; instead, our code calls it to plot historical prices whenever the LLM calls ```get_historical_price```:" ] }, { "cell_type": "code", "execution_count": 11, "id": "ba207220-8e67-4799-bd32-576cd783132b", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import plotly.graph_objects as go\n", "\n", "def plot_price_over_time(historical_price_dfs):\n", "\n", "    # Outer-merge the per-symbol DataFrames on 'Date' into a single wide DataFrame\n", "    full_df = pd.DataFrame(columns=['Date'])\n", "    for df in historical_price_dfs:\n", "        full_df = full_df.merge(df, on='Date', how='outer')\n", "\n", "    # Create a Plotly figure\n", "    fig = go.Figure()\n", "    \n", "    # Dynamically add a trace for each stock symbol in the DataFrame\n", "    for column in full_df.columns[1:]:  # Skip the first column since it's the date\n", "        fig.add_trace(go.Scatter(x=full_df['Date'], y=full_df[column], mode='lines+markers', name=column))\n", "    \n", "    \n", "    # Update the layout to add titles and format axis labels\n", "    fig.update_layout(\n", "        title='Stock Price Over Time: ' + ', '.join(full_df.columns.tolist()[1:]),\n", "        xaxis_title='Date',\n", "        yaxis_title='Stock Price (USD)',\n", "        yaxis_tickprefix='$',\n", "        yaxis_tickformat=',.2f',\n", 
"        xaxis=dict(\n", "            tickangle=-45,\n", "            nticks=20,\n", "            tickfont=dict(size=10),\n", "        ),\n", "        yaxis=dict(\n", "            showgrid=True,  # Enable y-axis grid lines\n", "            gridcolor='lightgrey',  # Set grid line color\n", "        ),\n", "        legend_title_text='Stock Symbol',\n", "        plot_bgcolor='white',  # Set plot background to white\n", "        paper_bgcolor='white',  # Set overall figure background to white\n", "        legend=dict(\n", "            bgcolor='white',  # Optional: Set legend background to white\n", "            bordercolor='black'\n", "        )\n", "    )\n", "    \n", "    # Show the figure. Dynamic charts are not supported in the GitHub preview, so this renders a static PNG\n", "    # (requires the kaleido package). If running locally, you can use fig.show(renderer='iframe') to output an interactive Plotly chart\n", "    fig.show('png')\n" ] }, { "cell_type": "markdown", "id": "dd06265b-ceaf-4dca-b392-ceb3e10634f3", "metadata": {}, "source": [ "Finally, we will use LangChain to tie everything together. Our system prompt provides the current date for context, and our function executes each tool call the LLM requests. For ```get_stock_info``` calls, it sends the tool output back to the LLM so that it can answer the user prompt with relevant information; for ```get_historical_price``` calls, it plots the historical prices and tells the LLM that a chart has been generated:" ] }, { "cell_type": "code", "execution_count": 8, "id": "f0312977-514a-4ecb-b1ed-29923c546704", "metadata": {}, "outputs": [], "source": [ "from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage\n", "from datetime import date\n", "\n", "def call_functions(llm_with_tools, user_prompt):\n", "    system_prompt = 'You are a helpful finance assistant that analyzes stocks and stock prices. 
Today is {today}'.format(today=date.today())\n", "    \n", "    messages = [SystemMessage(system_prompt), HumanMessage(user_prompt)]\n", "    ai_msg = llm_with_tools.invoke(messages)\n", "    messages.append(ai_msg)\n", "    historical_price_dfs = []\n", "    symbols = []\n", "    historical_call_ids = []\n", "    for tool_call in ai_msg.tool_calls:\n", "        # Look up the tool the LLM selected and run it with the arguments it generated\n", "        selected_tool = {\"get_stock_info\": get_stock_info, \"get_historical_price\": get_historical_price}[tool_call[\"name\"].lower()]\n", "        tool_output = selected_tool.invoke(tool_call[\"args\"])\n", "        if tool_call['name'] == 'get_historical_price':\n", "            # Keep the price DataFrame for plotting instead of sending it back to the LLM\n", "            historical_price_dfs.append(tool_output)\n", "            symbols.append(tool_output.columns[1])\n", "            historical_call_ids.append(tool_call[\"id\"])\n", "        else:\n", "            messages.append(ToolMessage(str(tool_output), tool_call_id=tool_call[\"id\"]))\n", "    \n", "    if len(historical_price_dfs) > 0:\n", "        plot_price_over_time(historical_price_dfs)\n", "        symbols = ' and '.join(symbols)\n", "        # Answer every historical price tool call so that each tool_call_id gets a matching ToolMessage\n", "        for call_id in historical_call_ids:\n", "            messages.append(ToolMessage('Tell the user that a historical stock price chart for {symbols} has been generated.'.format(symbols=symbols), tool_call_id=call_id))\n", "\n", "    return llm_with_tools.invoke(messages).content\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "e35c1d41-72c0-4ec7-b25a-e68eec646861", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'The beta for Meta stock is 1.184.'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "user_prompt = 'What is the beta for meta stock?'\n", "call_functions(llm_with_tools, user_prompt)" ] }, { "cell_type": "code", "execution_count": 10, "id": "fff68338-24fa-454c-9a5e-8307c4341208", "metadata": {}, "outputs": [ { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'A historical stock price chart for GOOGL and AAPL and META has been generated.'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "user_prompt = \"Compare the stock price of Google, Apple and Meta over the past 6 months\"\n", 
"call_functions(llm_with_tools, user_prompt)" ] }, { "cell_type": "markdown", "id": "d21b4491-7be4-46d9-af10-8dab353312bc", "metadata": {}, "source": [ "### Conclusion" ] }, { "cell_type": "markdown", "id": "73573bc0-281a-419a-8465-9a1e179f7497", "metadata": {}, "source": [ "In this notebook, we've demonstrated how to use Llama 3's function calling on the Groq API together with LangChain's Groq integration. Llama 3 is an impressive new model, and its capabilities are amplified when combined with Groq's exceptional LPU speed! To explore the interactive app that accompanies this notebook, please visit: https://llama3-function-calling.streamlit.app/" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }