| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 | 
							- from langchain_groq import ChatGroq
 
- import os
 
- import yfinance as yf
 
- import pandas as pd
 
- from langchain_core.tools import tool
 
- from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
 
- from datetime import date
 
- import pandas as pd
 
- import plotly.graph_objects as go
 
# NOTE: LangChain's `tool` decorator uses the function docstring as the tool
# description handed to the model, so the docstring below (including the key
# list) is functional prompt text and is kept verbatim.
@tool
def get_stock_info(symbol, key):
    '''Return the correct stock info value given the appropriate symbol and key. Infer valid key from the user prompt; it must be one of the following:
    address1, city, state, zip, country, phone, website, industry, industryKey, industryDisp, sector, sectorKey, sectorDisp, longBusinessSummary, fullTimeEmployees, companyOfficers, auditRisk, boardRisk, compensationRisk, shareHolderRightsRisk, overallRisk, governanceEpochDate, compensationAsOfEpochDate, maxAge, priceHint, previousClose, open, dayLow, dayHigh, regularMarketPreviousClose, regularMarketOpen, regularMarketDayLow, regularMarketDayHigh, dividendRate, dividendYield, exDividendDate, beta, trailingPE, forwardPE, volume, regularMarketVolume, averageVolume, averageVolume10days, averageDailyVolume10Day, bid, ask, bidSize, askSize, marketCap, fiftyTwoWeekLow, fiftyTwoWeekHigh, priceToSalesTrailing12Months, fiftyDayAverage, twoHundredDayAverage, currency, enterpriseValue, profitMargins, floatShares, sharesOutstanding, sharesShort, sharesShortPriorMonth, sharesShortPreviousMonthDate, dateShortInterest, sharesPercentSharesOut, heldPercentInsiders, heldPercentInstitutions, shortRatio, shortPercentOfFloat, impliedSharesOutstanding, bookValue, priceToBook, lastFiscalYearEnd, nextFiscalYearEnd, mostRecentQuarter, earningsQuarterlyGrowth, netIncomeToCommon, trailingEps, forwardEps, pegRatio, enterpriseToRevenue, enterpriseToEbitda, 52WeekChange, SandP52WeekChange, lastDividendValue, lastDividendDate, exchange, quoteType, symbol, underlyingSymbol, shortName, longName, firstTradeDateEpochUtc, timeZoneFullName, timeZoneShortName, uuid, messageBoardId, gmtOffSetMilliseconds, currentPrice, targetHighPrice, targetLowPrice, targetMeanPrice, targetMedianPrice, recommendationMean, recommendationKey, numberOfAnalystOpinions, totalCash, totalCashPerShare, ebitda, totalDebt, quickRatio, currentRatio, totalRevenue, debtToEquity, revenuePerShare, returnOnAssets, returnOnEquity, freeCashflow, operatingCashflow, earningsGrowth, revenueGrowth, grossMargins, ebitdaMargins, operatingMargins, financialCurrency, trailingPegRatio
    If asked generically for 'stock price', use currentPrice
    '''
    # Look the symbol up through yfinance and read its `info` mapping.
    stock_info = yf.Ticker(symbol).info
    try:
        return stock_info[key]
    except KeyError:
        # The requested field is not present for this symbol (or the model
        # picked a key outside the list). Report that as the tool result so the
        # model can relay it, instead of letting a KeyError end the chat loop.
        return "No '{key}' value is available for symbol '{symbol}'.".format(key=key, symbol=symbol)
 
# NOTE: the docstring is consumed by LangChain's `tool` decorator as the tool
# description shown to the model, so its text is reproduced unchanged.
@tool
def get_historical_price(symbol, start_date, end_date):
    """
    Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.
    - symbol (str): Stock ticker symbol.
    - end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date
    - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date
    """
    # Pull the price history for the requested window and turn the index into
    # an ordinary column so 'Date' can be selected by name below.
    ticker = yf.Ticker(symbol)
    price_history = ticker.history(start=start_date, end=end_date).reset_index()

    # Publish the closing price under a column named after the ticker; the
    # caller reads that column name back to learn which symbol the frame holds.
    price_history[symbol] = price_history['Close']
    wanted_columns = ['Date', symbol]
    return price_history[wanted_columns]
 
def plot_price_over_time(historical_price_dfs, output_path="plot.png"):
    '''
    Plots the given price DataFrames on one line chart and saves it as an image file.

    Parameters:
    historical_price_dfs (list): List of DataFrames containing historical stock prices.
        Each must have a 'Date' column; every other column becomes one trace named
        after that column.
    output_path (str): File the chart is written to. Defaults to "plot.png", the
        location that used to be hard-coded, so existing callers are unaffected.

    Returns:
    None. The only result is the image file written to output_path.
    '''
    # Start from an empty frame that only has the join key, then fold every
    # input frame in. The outer join keeps each date that occurs in any frame.
    full_df = pd.DataFrame(columns=['Date'])
    for df in historical_price_dfs:
        full_df = full_df.merge(df, on='Date', how='outer')

    # Everything after the leading 'Date' column is a price series.
    series_names = full_df.columns.tolist()[1:]

    fig = go.Figure()
    for column in series_names:
        fig.add_trace(go.Scatter(x=full_df['Date'], y=full_df[column], mode='lines+markers', name=column))

    # Titles, axis formatting and colours.
    fig.update_layout(
        title='Stock Price Over Time: ' + ', '.join(series_names),
        xaxis_title='Date',
        yaxis_title='Stock Price (USD)',
        yaxis_tickprefix='$',
        yaxis_tickformat=',.2f',
        xaxis=dict(
            tickangle=-45,
            nticks=20,
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            showgrid=True,   # Enable y-axis grid lines
            gridcolor='lightgrey',  # Grid line colour
        ),
        legend_title_text='Stock Symbol',
        plot_bgcolor='gray',  # Plot area background
        paper_bgcolor='gray',  # Overall figure background
        legend=dict(
            bgcolor='gray',  # Legend background
            bordercolor='black'
        )
    )

    # Write the chart to disk; nothing is displayed interactively.
    # NOTE(review): plotly's static image export relies on a separately
    # installed export engine (kaleido) - confirm it is part of the environment.
    fig.write_image(output_path)
 
def call_functions(llm_with_tools, user_prompt):
    '''
    Answer a user prompt with a single round of tool calling.

    The prompt is sent to the model together with a system prompt carrying today's
    date. Every tool call the model requests is executed, its result is appended to
    the conversation as a ToolMessage, and the model is invoked a second time to
    produce the final answer. If any historical price data was requested, a chart of
    all requested symbols is written via plot_price_over_time.

    Parameters:
    llm_with_tools (ChatGroq): ChatGroq object containing the tools for interaction.
    user_prompt (str): User input prompt.

    Returns:
    str: Content of the model's final response.
    '''
    system_prompt = 'You are a helpful finance assistant that analyzes stocks and stock prices. Today is {today}'.format(today=date.today())
    messages = [SystemMessage(system_prompt), HumanMessage(user_prompt)]

    ai_msg = llm_with_tools.invoke(messages)
    messages.append(ai_msg)

    available_tools = {"get_stock_info": get_stock_info, "get_historical_price": get_historical_price}
    historical_price_dfs = []
    symbols = []
    chart_call_ids = []  # ids of the tool calls that the chart message must answer

    for tool_call in ai_msg.tool_calls:
        # Normalise the name once so the lookup and the branch below agree even
        # if the model changes the capitalisation of the tool name.
        tool_name = tool_call["name"].lower()
        selected_tool = available_tools.get(tool_name)
        if selected_tool is None:
            # The model asked for a tool that does not exist. Report that as
            # the call's result rather than failing with a KeyError.
            messages.append(ToolMessage(
                "Unknown tool '{name}'.".format(name=tool_call["name"]),
                tool_call_id=tool_call["id"],
            ))
            continue

        tool_output = selected_tool.invoke(tool_call["args"])
        if tool_name == 'get_historical_price':
            historical_price_dfs.append(tool_output)
            # The second column of the returned frame is named after the ticker.
            symbols.append(tool_output.columns[1])
            chart_call_ids.append(tool_call["id"])
        else:
            # Info values may be numbers, lists or dicts; message content is text.
            messages.append(ToolMessage(str(tool_output), tool_call_id=tool_call["id"]))

    if historical_price_dfs:
        plot_price_over_time(historical_price_dfs)
        chart_note = 'Tell the user that a historical stock price chart for {symbols} has been generated.'.format(symbols=' and '.join(symbols))
        # Answer each price tool call under its own id so every call the model
        # made has a matching tool result in the conversation.
        for call_id in chart_call_ids:
            messages.append(ToolMessage(chart_note, tool_call_id=call_id))

    return llm_with_tools.invoke(messages).content
 
# Chat model used by the assistant; the API key is read from the GROQ_API_KEY
# environment variable.
llm = ChatGroq(groq_api_key=os.getenv('GROQ_API_KEY'), model='llama3-70b-8192')

# Tools the model is allowed to call.
tools = [get_stock_info, get_historical_price]

llm_with_tools = llm.bind_tools(tools)


def main():
    '''Run the console chat loop until the input stream ends or the user interrupts it.'''
    while True:
        try:
            # Get user input from the console
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: finish the line and leave quietly
            # instead of ending with a traceback.
            print()
            break

        # Nothing to ask the model about; prompt again.
        if not user_input.strip():
            continue

        response = call_functions(llm_with_tools, user_input)
        print("Assistant:", response)


# Only start the interactive loop when executed as a script, so the functions
# above can be imported without blocking on console input.
if __name__ == "__main__":
    main()
 
 
  |