chat_model.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import pandas as pd
  2. from openai import OpenAI
  3. import requests
  4. import json
  5. import re
  6. from pic_read import pic_see
  7. from config import config_get,key_get
  8. import os
  9. # 全局变量在这里改写:
  10. max_length = 20
  11. max_pic_see = 1
  12. prompt_pic = config_get("prompt_pic")
  13. prompt_lora = config_get("prompt_lora")
  14. pro_good = "solo,cat ears,black_hair,black_eye,"
  15. pro_bad = config_get("pro_bad")
  16. openrouter_key = key_get("openrouter")
  17. groq_key = key_get("groq")
  18. siliflow_key = key_get("siliflow")
  19. def AI_chat(group_id ,word ,prompt):
  20. """
  21. 猫猫人格,输入群号获取聊天记录,输入问题获取对话对象,返回字符串
  22. """
  23. # 加载聊天记录
  24. folder_path = '群聊记录'
  25. file_path = os.path.join(folder_path, str(group_id) + '.csv')
  26. df = pd.read_csv(file_path)
  27. recent = str("")
  28. recent_list = []
  29. for i in range(len(df)):
  30. if i >= len(df ) -max_length:
  31. recent_list.append(str(df.loc[i ,'user_name']))
  32. recent_list.append(str(df.loc[i, 'message']))
  33. pic_num = 0
  34. list_num = len(recent_list)
  35. for i in reversed(recent_list):
  36. list_num = list_num -1
  37. if pic_num >= max_pic_see:
  38. break
  39. if str(i).startswith("图片内容:"):
  40. url = i[5:]
  41. recent_list[list_num] = "发送了一张图片," + str(pic_see(url)).replace("\n", "")
  42. pic_num +=1
  43. for i in range(len(recent_list)):
  44. if i% 2 == 0:
  45. recent += recent_list[i] + ":"
  46. else:
  47. recent += recent_list[i] + "\n"
  48. print("-------------recent---------------")
  49. print(recent)
  50. messages = [{'role': 'system', 'content': prompt + recent}]
  51. messages.append({'role': 'user', 'content': word})
  52. try:
  53. ans = groq_chat(messages)
  54. except Exception as e:
  55. ans = chat_deepseek_r1(messages)
  56. ans = ans.lstrip()
  57. ans = ans.lstrip("猫猫:")
  58. while True:
  59. if ans.endswith("\n") or ans.endswith(" "):
  60. ans = ans[:-1]
  61. else:
  62. break
  63. return ans
  64. def chat_deepseek_r1(messages):
  65. """
  66. 备用对话通道,https://openrouter.ai/api/v1
  67. """
  68. client = OpenAI(
  69. base_url="https://openrouter.ai/api/v1",
  70. api_key=openrouter_key,
  71. )
  72. model2 = "deepseek/deepseek-r1-distill-llama-70b:free"
  73. model = "deepseek/deepseek-r1:free"
  74. completion = client.chat.completions.create(
  75. extra_headers={
  76. "HTTP-Referer": "<YOUR_SITE_URL>", # Optional. Site URL for rankings on openrouter.ai.
  77. "X-Title": "<YOUR_SITE_NAME>", # Optional. Site title for rankings on openrouter.ai.
  78. },
  79. model=model2,
  80. stream=False,
  81. messages=messages
  82. )
  83. ans = completion.choices[0].message.content
  84. return ans
  85. def groq_chat(messages):
  86. """
  87. 主对话通道,https://api.gxx12138.space/groq/v1
  88. """
  89. client = OpenAI(
  90. base_url="https://api.gxx12138.space/groq/v1",
  91. api_key=groq_key
  92. )
  93. model = "deepseek-r1-distill-llama-70b"
  94. completion = client.chat.completions.create(
  95. model=model,
  96. stream=False,
  97. messages=messages
  98. )
  99. ans = completion.choices[0].message.content
  100. return ans
  101. def siliflow_chat(messages):
  102. """
  103. 备用对话通道,https://api.siliconflow.cn/v1'
  104. """
  105. client = OpenAI(
  106. base_url='https://api.siliconflow.cn/v1',
  107. api_key=siliflow_key
  108. )
  109. try:
  110. print("硅基模型正常运作")
  111. # 发送AI请求
  112. # THUDM/glm-4-9b-chat google/gemma-2-9b-it
  113. response = client.chat.completions.create(
  114. model="THUDM/glm-4-9b-chat",
  115. messages=messages,
  116. stream=False,
  117. temperature=0.8,
  118. )
  119. ans = response.choices[0].message.content
  120. except Exception as e:
  121. print("硅基模型爆了")
  122. ans = ""
  123. return ans
  124. # 根据群聊内近4条消息获取绘画提示词
  125. def AI_get_picprompt(group_id):
  126. """
  127. 群聊id自动获取群聊记录来获取绘画提示词
  128. """
  129. global pro_good
  130. # 加载聊天记录
  131. folder_path = '群聊记录'
  132. file_path = os.path.join(folder_path, str(group_id) + '.csv')
  133. df = pd.read_csv(file_path)
  134. recent = ""
  135. for i in range(len(df) - 4, len(df)):
  136. recent += str(df.loc[i, 'user_name']) + ":" + str(df.loc[i, 'message']) + "\n"
  137. messages = [{'role': 'system', 'content': prompt_pic}]
  138. messages.append({'role': 'user', 'content': recent})
  139. ans = groq_chat(messages)
  140. return pro_good + ans.strip().replace("\n", "") + ","
  141. def AI_lora_getpic_prompt(word):
  142. """
  143. 群聊id自动获取群聊记录来获取绘画提示词
  144. 版本2
  145. """
  146. messages = [{'role': 'system', 'content': prompt_lora}]
  147. messages.append({'role': 'user', 'content': word})
  148. ans = siliflow_chat(messages).replace(", ", ",")
  149. return ans.strip().replace("\n", "") + ","