123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- from langchain_groq import ChatGroq
- import os
- import yfinance as yf
- import pandas as pd
- from langchain_core.tools import tool
- from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
- from datetime import date
- import pandas as pd
- import plotly.graph_objects as go
# NOTE: @tool uses this function's docstring as the tool description given to
# the model, so the docstring text is part of runtime behavior; document the
# implementation with comments instead of editing it.
@tool
def get_stock_info(symbol: str, key: str):
    '''Return the correct stock info value given the appropriate symbol and key. Infer valid key from the user prompt; it must be one of the following:
    address1, city, state, zip, country, phone, website, industry, industryKey, industryDisp, sector, sectorKey, sectorDisp, longBusinessSummary, fullTimeEmployees, companyOfficers, auditRisk, boardRisk, compensationRisk, shareHolderRightsRisk, overallRisk, governanceEpochDate, compensationAsOfEpochDate, maxAge, priceHint, previousClose, open, dayLow, dayHigh, regularMarketPreviousClose, regularMarketOpen, regularMarketDayLow, regularMarketDayHigh, dividendRate, dividendYield, exDividendDate, beta, trailingPE, forwardPE, volume, regularMarketVolume, averageVolume, averageVolume10days, averageDailyVolume10Day, bid, ask, bidSize, askSize, marketCap, fiftyTwoWeekLow, fiftyTwoWeekHigh, priceToSalesTrailing12Months, fiftyDayAverage, twoHundredDayAverage, currency, enterpriseValue, profitMargins, floatShares, sharesOutstanding, sharesShort, sharesShortPriorMonth, sharesShortPreviousMonthDate, dateShortInterest, sharesPercentSharesOut, heldPercentInsiders, heldPercentInstitutions, shortRatio, shortPercentOfFloat, impliedSharesOutstanding, bookValue, priceToBook, lastFiscalYearEnd, nextFiscalYearEnd, mostRecentQuarter, earningsQuarterlyGrowth, netIncomeToCommon, trailingEps, forwardEps, pegRatio, enterpriseToRevenue, enterpriseToEbitda, 52WeekChange, SandP52WeekChange, lastDividendValue, lastDividendDate, exchange, quoteType, symbol, underlyingSymbol, shortName, longName, firstTradeDateEpochUtc, timeZoneFullName, timeZoneShortName, uuid, messageBoardId, gmtOffSetMilliseconds, currentPrice, targetHighPrice, targetLowPrice, targetMeanPrice, targetMedianPrice, recommendationMean, recommendationKey, numberOfAnalystOpinions, totalCash, totalCashPerShare, ebitda, totalDebt, quickRatio, currentRatio, totalRevenue, debtToEquity, revenuePerShare, returnOnAssets, returnOnEquity, freeCashflow, operatingCashflow, earningsGrowth, revenueGrowth, grossMargins, ebitdaMargins, operatingMargins, financialCurrency, trailingPegRatio
    If asked generically for 'stock price', use currentPrice
    '''
    # Look up the single field `key` in the ticker's info mapping.
    stock_info = yf.Ticker(symbol).info
    # The key is chosen by the model and may be misspelled, invented, or simply
    # not reported for this symbol. Indexing directly raised KeyError and ended
    # the whole chat session; report the problem back to the model instead.
    if key not in stock_info:
        return f"'{key}' is not available in the stock info for {symbol}."
    return stock_info[key]
# NOTE: @tool uses this function's docstring as the tool description given to
# the model, so the docstring text is part of runtime behavior.
@tool
def get_historical_price(symbol: str, start_date: str, end_date: str):
    """
    Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.
    - symbol (str): Stock ticker symbol.
    - end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date
    - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date
    """
    # Returns a two-column DataFrame: 'Date' and the closing price in a column
    # named after `symbol`, so several results can be outer-merged on 'Date'.
    hist = yf.Ticker(symbol).history(start=start_date, end=end_date).reset_index()
    # The dates come back timezone-aware in the listing exchange's own zone.
    # Symbols from different exchanges would then hold different instants for
    # the same trading day and fail to line up when merged on 'Date'. Dropping
    # the zone keeps each exchange's local calendar date, which does align.
    if isinstance(hist['Date'].dtype, pd.DatetimeTZDtype):
        hist['Date'] = hist['Date'].dt.tz_localize(None)
    hist[symbol] = hist['Close']
    return hist[['Date', symbol]]
def plot_price_over_time(historical_price_dfs, filename="plot.png"):
    '''
    Plot the historical stock prices of one or more symbols on a single chart
    and save it as an image.

    Parameters:
    historical_price_dfs (list): DataFrames each holding a 'Date' column
        followed by one price column named after its ticker symbol (the shape
        produced by get_historical_price).
    filename (str): Path the chart image is written to via fig.write_image.
        Defaults to "plot.png".
    '''
    # Outer-join on 'Date' so every symbol's dates are kept; a date on which
    # only some symbols have a price gets NaN for the others.
    full_df = pd.DataFrame(columns=['Date'])
    for df in historical_price_dfs:
        full_df = full_df.merge(df, on='Date', how='outer')

    # Every column after 'Date' holds one symbol's prices.
    symbol_columns = full_df.columns.tolist()[1:]

    fig = go.Figure()
    for column in symbol_columns:
        fig.add_trace(go.Scatter(
            x=full_df['Date'],
            y=full_df[column],
            mode='lines+markers',
            name=column,
            # Bridge the NaN holes the outer join leaves (e.g. a holiday on
            # only one exchange) rather than breaking the line at each one.
            connectgaps=True,
        ))

    # Titles, axis formatting and colors.
    # NOTE(review): the '$' prefix and "(USD)" label assume USD-priced symbols.
    fig.update_layout(
        title='Stock Price Over Time: ' + ', '.join(symbol_columns),
        xaxis_title='Date',
        yaxis_title='Stock Price (USD)',
        yaxis_tickprefix='$',
        yaxis_tickformat=',.2f',
        xaxis=dict(
            tickangle=-45,
            nticks=20,
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            showgrid=True,  # draw horizontal grid lines
            gridcolor='lightgrey',
        ),
        legend_title_text='Stock Symbol',
        plot_bgcolor='gray',  # background of the plotting area
        paper_bgcolor='gray',  # background of the whole figure
        legend=dict(
            bgcolor='gray',
            bordercolor='black',
        ),
    )

    # Save the chart to disk (nothing is displayed interactively).
    fig.write_image(filename)
def call_functions(llm_with_tools, user_prompt):
    '''
    Answer user_prompt using one round of LLM tool calling.

    The model is first invoked with a system prompt (which includes today's
    date) and the user's prompt. If it requests no tools, its reply is returned
    directly. Otherwise every requested tool is executed:
    get_stock_info results are sent back to the model as tool messages, while
    get_historical_price results are drawn together on one chart by
    plot_price_over_time and each of those calls is answered with a note that
    the chart exists. The model is then invoked again for the final answer.

    Parameters:
    llm_with_tools (ChatGroq): Chat model with the stock tools bound to it.
    user_prompt (str): User input prompt.

    Returns:
    str: Content of the model's final reply.
    '''
    # Built once rather than on every loop iteration.
    tools_by_name = {
        "get_stock_info": get_stock_info,
        "get_historical_price": get_historical_price,
    }

    system_prompt = f'You are a helpful finance assistant that analyzes stocks and stock prices. Today is {date.today()}'
    messages = [SystemMessage(system_prompt), HumanMessage(user_prompt)]
    ai_msg = llm_with_tools.invoke(messages)
    messages.append(ai_msg)

    if not ai_msg.tool_calls:
        # Nothing to execute: the first reply already is the answer, so avoid
        # re-invoking the model on a conversation that ends with its own turn.
        return ai_msg.content

    historical_price_dfs = []
    symbols = []
    chart_call_ids = []
    for tool_call in ai_msg.tool_calls:
        # Normalize once so dispatch and the branch below always agree.
        name = tool_call["name"].lower()
        selected_tool = tools_by_name.get(name)
        if selected_tool is None:
            # Answer the call instead of crashing on an unknown tool name.
            messages.append(ToolMessage(f"Unknown tool: {tool_call['name']}", tool_call_id=tool_call["id"]))
            continue
        tool_output = selected_tool.invoke(tool_call["args"])
        if name == "get_historical_price":
            # Price frames are charted together after the loop; remember which
            # call ids still need a reply.
            historical_price_dfs.append(tool_output)
            symbols.append(tool_output.columns[1])
            chart_call_ids.append(tool_call["id"])
        else:
            # str() so non-text values (numbers, lists of dicts) are sent as
            # plain text rather than interpreted as structured content.
            messages.append(ToolMessage(str(tool_output), tool_call_id=tool_call["id"]))

    if historical_price_dfs:
        plot_price_over_time(historical_price_dfs)
        chart_note = 'Tell the user that a historical stock price chart for {symbols} has been generated.'.format(symbols=' and '.join(symbols))
        # Every tool call in the assistant turn must be answered by a tool
        # message carrying that call's id (the old code sent one message with
        # the made-up id 0 and left the real calls unanswered).
        for call_id in chart_call_ids:
            messages.append(ToolMessage(chart_note, tool_call_id=call_id))

    return llm_with_tools.invoke(messages).content
# Chat model with both stock tools bound so it can request tool calls.
llm = ChatGroq(groq_api_key=os.getenv('GROQ_API_KEY'), model='llama3-70b-8192')
tools = [get_stock_info, get_historical_price]
llm_with_tools = llm.bind_tools(tools)

if __name__ == "__main__":
    # Console chat loop. Guarded so importing this module does not block on
    # input(); end the session with Ctrl+D or Ctrl+C at the prompt.
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not user_input.strip():
            # Don't spend a model call on an empty line.
            continue
        try:
            response = call_functions(llm_with_tools, user_input)
        except Exception as exc:
            # Top-level boundary: one failed request (network error, bad
            # symbol, tool failure, ...) is reported without ending the session.
            print(f"Error: {type(exc).__name__}: {exc}")
            continue
        print("Assistant:", response)
|