main.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from langchain_groq import ChatGroq
  2. import os
  3. import yfinance as yf
  4. import pandas as pd
  5. from langchain_core.tools import tool
  6. from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
  7. from datetime import date
  8. import pandas as pd
  9. import plotly.graph_objects as go
  10. @tool
  11. def get_stock_info(symbol, key):
  12. '''Return the correct stock info value given the appropriate symbol and key. Infer valid key from the user prompt; it must be one of the following:
  13. address1, city, state, zip, country, phone, website, industry, industryKey, industryDisp, sector, sectorKey, sectorDisp, longBusinessSummary, fullTimeEmployees, companyOfficers, auditRisk, boardRisk, compensationRisk, shareHolderRightsRisk, overallRisk, governanceEpochDate, compensationAsOfEpochDate, maxAge, priceHint, previousClose, open, dayLow, dayHigh, regularMarketPreviousClose, regularMarketOpen, regularMarketDayLow, regularMarketDayHigh, dividendRate, dividendYield, exDividendDate, beta, trailingPE, forwardPE, volume, regularMarketVolume, averageVolume, averageVolume10days, averageDailyVolume10Day, bid, ask, bidSize, askSize, marketCap, fiftyTwoWeekLow, fiftyTwoWeekHigh, priceToSalesTrailing12Months, fiftyDayAverage, twoHundredDayAverage, currency, enterpriseValue, profitMargins, floatShares, sharesOutstanding, sharesShort, sharesShortPriorMonth, sharesShortPreviousMonthDate, dateShortInterest, sharesPercentSharesOut, heldPercentInsiders, heldPercentInstitutions, shortRatio, shortPercentOfFloat, impliedSharesOutstanding, bookValue, priceToBook, lastFiscalYearEnd, nextFiscalYearEnd, mostRecentQuarter, earningsQuarterlyGrowth, netIncomeToCommon, trailingEps, forwardEps, pegRatio, enterpriseToRevenue, enterpriseToEbitda, 52WeekChange, SandP52WeekChange, lastDividendValue, lastDividendDate, exchange, quoteType, symbol, underlyingSymbol, shortName, longName, firstTradeDateEpochUtc, timeZoneFullName, timeZoneShortName, uuid, messageBoardId, gmtOffSetMilliseconds, currentPrice, targetHighPrice, targetLowPrice, targetMeanPrice, targetMedianPrice, recommendationMean, recommendationKey, numberOfAnalystOpinions, totalCash, totalCashPerShare, ebitda, totalDebt, quickRatio, currentRatio, totalRevenue, debtToEquity, revenuePerShare, returnOnAssets, returnOnEquity, freeCashflow, operatingCashflow, earningsGrowth, revenueGrowth, grossMargins, ebitdaMargins, operatingMargins, financialCurrency, trailingPegRatio
  14. If asked generically for 'stock price', use currentPrice
  15. '''
  16. data = yf.Ticker(symbol)
  17. stock_info = data.info
  18. return stock_info[key]
  19. @tool
  20. def get_historical_price(symbol, start_date, end_date):
  21. """
  22. Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.
  23. - symbol (str): Stock ticker symbol.
  24. - end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date
  25. - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date
  26. """
  27. data = yf.Ticker(symbol)
  28. hist = data.history(start=start_date, end=end_date)
  29. hist = hist.reset_index()
  30. hist[symbol] = hist['Close']
  31. return hist[['Date', symbol]]
  32. def plot_price_over_time(historical_price_dfs):
  33. '''
  34. Plots the historical stock prices over time for the given DataFrames.
  35. Parameters:
  36. historical_price_dfs (list): List of DataFrames containing historical stock prices.
  37. '''
  38. full_df = pd.DataFrame(columns=['Date'])
  39. for df in historical_price_dfs:
  40. full_df = full_df.merge(df, on='Date', how='outer')
  41. # Create a Plotly figure
  42. fig = go.Figure()
  43. # Dynamically add a trace for each stock symbol in the DataFrame
  44. for column in full_df.columns[1:]: # Skip the first column since it's the date
  45. fig.add_trace(go.Scatter(x=full_df['Date'], y=full_df[column], mode='lines+markers', name=column))
  46. # Update the layout to add titles and format axis labels
  47. fig.update_layout(
  48. title='Stock Price Over Time: ' + ', '.join(full_df.columns.tolist()[1:]),
  49. xaxis_title='Date',
  50. yaxis_title='Stock Price (USD)',
  51. yaxis_tickprefix='$',
  52. yaxis_tickformat=',.2f',
  53. xaxis=dict(
  54. tickangle=-45,
  55. nticks=20,
  56. tickfont=dict(size=10),
  57. ),
  58. yaxis=dict(
  59. showgrid=True, # Enable y-axis grid lines
  60. gridcolor='lightgrey', # Set grid line color
  61. ),
  62. legend_title_text='Stock Symbol',
  63. plot_bgcolor='gray', # Set plot background to white
  64. paper_bgcolor='gray', # Set overall figure background to white
  65. legend=dict(
  66. bgcolor='gray', # Optional: Set legend background to white
  67. bordercolor='black'
  68. )
  69. )
  70. # Show the figure
  71. fig.write_image("plot.png")
  72. def call_functions(llm_with_tools, user_prompt):
  73. '''
  74. Call the functions to interact with the llm_with_tools using the given user_prompt.
  75. This function processes the user input, invokes tools based on the input, performs necessary operations,
  76. generates responses or messages, and plots historical stock prices over time.
  77. Parameters:
  78. llm_with_tools (ChatGroq): ChatGroq object containing the tools for interaction.
  79. user_prompt (str): User input prompt.
  80. Returns:
  81. str: Contents of the invoked messages through llm_with_tools interaction.
  82. '''
  83. system_prompt = 'You are a helpful finance assistant that analyzes stocks and stock prices. Today is {today}'.format(today=date.today())
  84. messages = [SystemMessage(system_prompt), HumanMessage(user_prompt)]
  85. ai_msg = llm_with_tools.invoke(messages)
  86. messages.append(ai_msg)
  87. historical_price_dfs = []
  88. symbols = []
  89. for tool_call in ai_msg.tool_calls:
  90. selected_tool = {"get_stock_info": get_stock_info, "get_historical_price": get_historical_price}[tool_call["name"].lower()]
  91. tool_output = selected_tool.invoke(tool_call["args"])
  92. if tool_call['name'] == 'get_historical_price':
  93. historical_price_dfs.append(tool_output)
  94. symbols.append(tool_output.columns[1])
  95. else:
  96. messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))
  97. if len(historical_price_dfs) > 0:
  98. plot_price_over_time(historical_price_dfs)
  99. symbols = ' and '.join(symbols)
  100. messages.append(ToolMessage('Tell the user that a historical stock price chart for {symbols} been generated.'.format(symbols=symbols), tool_call_id=0))
  101. return llm_with_tools.invoke(messages).content
  102. llm = ChatGroq(groq_api_key = os.getenv('GROQ_API_KEY'),model = 'llama3-70b-8192')
  103. tools = [get_stock_info, get_historical_price]
  104. llm_with_tools = llm.bind_tools(tools)
  105. while True:
  106. # Get user input from the console
  107. user_input = input("You: ")
  108. response = call_functions(llm_with_tools, user_input)
  109. print("Assistant:", response)