ExampleLlamaRemoteInference.kt 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package com.example.llamaandroiddemo
  import android.content.ContentResolver
  import android.content.Context
  import android.database.Cursor
  import android.net.Uri
  import android.provider.MediaStore
  import android.util.Base64
  import android.util.Log
  import androidx.core.net.toUri
  import okhttp3.MediaType.Companion.toMediaType
  import okhttp3.OkHttpClient
  import okhttp3.Request
  import okhttp3.RequestBody
  import okhttp3.RequestBody.Companion.toRequestBody
  import okhttp3.Response
  import okio.IOException
  import org.json.JSONArray
  import org.json.JSONObject
  import java.io.File
  import java.net.URLConnection
  import java.util.concurrent.CompletableFuture
  import java.util.concurrent.TimeUnit
  /**
   * Receiver for incremental output of a streaming inference request.
   *
   * [ExampleLlamaRemoteInference] obtains this callback by casting the `Context` it is handed,
   * and invokes it from a background worker thread. Implementations that update UI must switch
   * to the main thread themselves.
   */
  interface InferenceStreamingCallback {

      /** Called once per streamed text fragment, in the order the fragments arrive. */
      fun onStreamReceived(message: String)

      /**
       * Called with a throughput statistic. Presumably tokens per second (TODO confirm);
       * nothing in this file invokes it.
       */
      fun onStatStreamReceived(tps: Float)
  }
  /**
   * Minimal client for an OpenAI-compatible chat-completions server.
   *
   * Each request is POSTed to `<remoteURL>/v1/chat/completions` with `"stream": true`. Every text
   * delta found in the streamed response is forwarded to the [InferenceStreamingCallback] that the
   * caller's `Context` implements.
   *
   * @property remoteURL base URL of the server; a trailing `/` is tolerated.
   */
  class ExampleLlamaRemoteInference(var remoteURL: String) {

      /**
       * Prepares a streaming chat completion and dispatches it on a background thread.
       *
       * Blocks the calling thread until the request payload (including any encoded images) has been
       * built and the network thread has been started. The generated text is NOT returned here; it is
       * delivered incrementally through [InferenceStreamingCallback.onStreamReceived].
       *
       * @param modelName value sent as the request's `model` field.
       * @param temperature sampling temperature sent with the request.
       * @param prompt the conversation so far: sent user text/images and assistant replies.
       * @param userProvidedSystemPrompt content of the leading `system` message.
       * @param ctx used to read images through its `ContentResolver`; it must also implement
       *   [InferenceStreamingCallback], otherwise nothing is sent.
       * @return always `""`, including when preparing the request failed (the failure is logged).
       */
      fun inferenceStartWithoutAgent(
          modelName: String,
          temperature: Double,
          prompt: ArrayList<Message>,
          userProvidedSystemPrompt: String,
          ctx: Context
      ): String {
          val future = CompletableFuture<String>()
          Thread {
              try {
                  future.complete(
                      inferenceCallWithoutAgent(modelName, temperature, prompt, userProvidedSystemPrompt, ctx, true)
                  )
              } catch (e: Exception) {
                  Log.e(TAG, "Failed to start inference", e)
              } finally {
                  // Always release the caller blocked in future.get(), even when preparation failed.
                  // complete() is a no-op if the future was already completed in the try block.
                  future.complete("")
              }
          }.start()
          return future.get()
      }

      /**
       * Executes a blocking POST to [url] that asks for a `text/event-stream` response.
       *
       * Sends `Authorization: Bearer <AppUtils.API_KEY>`. An empty key is reported through
       * `AppLogging`, but the request is still sent. The returned [Response] is open; the caller
       * must close it (e.g. with `use`).
       */
      fun makeStreamingPostRequest(url: String, requestBody: RequestBody): Response {
          if (AppUtils.API_KEY.isNullOrEmpty()) {
              AppLogging.getInstance().log("API key not set, configure it in AppUtils")
          }
          val request = Request.Builder()
              .url(url)
              .addHeader("Authorization", "Bearer " + AppUtils.API_KEY)
              .addHeader("Accept", "text/event-stream")
              .post(requestBody)
              .build()
          return streamingClient.newCall(request).execute()
      }

      /**
       * Builds the chat-completions payload on the calling thread, then streams the response on a
       * new background thread, forwarding each text delta to the callback implemented by [ctx].
       */
      private fun llamaChatCompletion(
          ctx: Context,
          modelName: String,
          conversationHistory: ArrayList<Message>,
          userProvidedSystemPrompt: String,
          temperature: Double
      ) {
          val callback = ctx as? InferenceStreamingCallback ?: run {
              Log.e(TAG, "Context does not implement InferenceStreamingCallback; request not sent")
              return
          }

          // Built with org.json so every user-supplied string is escaped properly.
          val messages = JSONArray()
              .put(JSONObject().put("role", "system").put("content", userProvidedSystemPrompt))
          constructMessageForAPICall(conversationHistory, ctx).forEach { messages.put(it) }
          val payload = JSONObject()
              .put("messages", messages)
              .put("model", modelName)
              .put("repetition_penalty", 1)
              .put("temperature", temperature)
              .put("top_p", 0.9)
              .put("max_completion_tokens", 2048)
              .put("stream", true)
          Log.d(TAG, "Sending ${messages.length()} messages to model $modelName")

          Thread {
              try {
                  val requestBody = payload.toString().toRequestBody("application/json".toMediaType())
                  val url = remoteURL.trimEnd('/') + CHAT_COMPLETIONS_PATH
                  makeStreamingPostRequest(url, requestBody).use { res ->
                      if (!res.isSuccessful) throw IOException("Unexpected code $res")
                      res.body?.source()?.let { source ->
                          while (!source.exhausted()) {
                              val line = source.readUtf8Line() ?: break
                              val data = streamPayload(line) ?: continue
                              if (data == SSE_DONE) break
                              extractStreamedText(JSONObject(data))?.let(callback::onStreamReceived)
                          }
                      }
                  }
              } catch (e: Exception) {
                  Log.e(TAG, e.message.toString(), e)
              }
          }.start()
      }

      /**
       * Example of running plain inference without an agent workflow. It starts a streaming chat
       * completion and returns `""` immediately. [streaming] is unused: the request is always streamed.
       */
      @Suppress("UNUSED_PARAMETER")
      private fun inferenceCallWithoutAgent(
          modelName: String,
          temperature: Double,
          conversationHistory: ArrayList<Message>,
          userProvidedSystemPrompt: String,
          ctx: Context,
          streaming: Boolean
      ): String {
          llamaChatCompletion(ctx, modelName, conversationHistory, userProvidedSystemPrompt, temperature)
          return ""
      }

      /**
       * Returns the JSON payload carried by one line of the streamed response.
       *
       * Returns null for lines without a payload: blank event separators, `:` comments, and
       * `event:`/`id:` fields. Bare JSON lines (starting with `{`) are accepted as well.
       */
      private fun streamPayload(line: String): String? {
          val trimmed = line.trim()
          val data = when {
              trimmed.startsWith(SSE_DATA_PREFIX) -> trimmed.removePrefix(SSE_DATA_PREFIX).trim()
              trimmed.startsWith("{") -> trimmed
              else -> return null
          }
          return data.takeIf { it.isNotEmpty() }
      }

      /**
       * Extracts the generated text from one streamed chunk.
       *
       * Reads `choices[0].delta.content` (chat format) or `choices[0].text` (completion format).
       * Returns null when the chunk carries no text, including an explicit JSON `null`.
       */
      private fun extractStreamedText(chunk: JSONObject): String? {
          val choice = chunk.optJSONArray("choices")?.optJSONObject(0) ?: return null
          val delta = choice.optJSONObject("delta")
          return when {
              delta != null -> if (delta.isNull("content")) null else delta.getString("content")
              !choice.isNull("text") -> choice.getString("text")
              else -> null
          }
      }

      /** Reads the file at [filePath] and returns it as a base64 `data:` URL labelled with its guessed MIME type. */
      private fun encodeImageToDataUrl(filePath: String): String {
          val mimeType = URLConnection.guessContentTypeFromName(filePath) ?: DEFAULT_IMAGE_MIME_TYPE
          return toDataUrl(mimeType, File(filePath).readBytes())
      }

      /** Formats [bytes] as a base64 `data:` URL with the given [mimeType]. */
      private fun toDataUrl(mimeType: String, bytes: ByteArray): String =
          "data:$mimeType;base64," + Base64.encodeToString(bytes, Base64.NO_WRAP)

      /** Looks up the MediaStore `DATA` file path for [uri], or null when the provider does not expose one. */
      private fun getFilePathFromUri(contentResolver: ContentResolver, uri: Uri): String? {
          val projection = arrayOf(MediaStore.Images.Media.DATA)
          return try {
              contentResolver.query(uri, projection, null, null, null)?.use { cursor ->
                  if (cursor.moveToFirst()) {
                      cursor.getString(cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DATA))
                  } else {
                      null
                  }
              }
          } catch (e: IllegalArgumentException) {
              // getColumnIndexOrThrow throws this when the cursor has no DATA column.
              null
          }
      }

      /**
       * Encodes the image referenced by [imagePath] (a URI string) as a `data:` URL.
       * Returns null, after logging, when the image cannot be read.
       */
      private fun imageDataUrl(contentResolver: ContentResolver, imagePath: String): String? {
          val uri = imagePath.toUri()
          return try {
              val filePath = getFilePathFromUri(contentResolver, uri)
              if (filePath != null && File(filePath).canRead()) {
                  encodeImageToDataUrl(filePath)
              } else {
                  // No directly readable file (e.g. a non-MediaStore provider): read the bytes
                  // through the ContentResolver instead.
                  contentResolver.openInputStream(uri)?.use { stream ->
                      toDataUrl(contentResolver.getType(uri) ?: DEFAULT_IMAGE_MIME_TYPE, stream.readBytes())
                  }
              }
          } catch (e: IOException) {
              Log.w(TAG, "Skipping unreadable image $imagePath", e)
              null
          } catch (e: SecurityException) {
              Log.w(TAG, "Skipping image without read permission $imagePath", e)
              null
          }
      }

      /** Builds an `image_url` content part for [imagePath], or null if the image cannot be encoded. */
      private fun imageContentPart(contentResolver: ContentResolver, imagePath: String?): JSONObject? {
          if (imagePath.isNullOrEmpty()) return null
          val dataUrl = imageDataUrl(contentResolver, imagePath) ?: return null
          Log.d(TAG, "Attached image $imagePath as a ${dataUrl.length}-char data URL")
          return JSONObject()
              .put("type", "image_url")
              .put("image_url", JSONObject().put("url", dataUrl))
      }

      /** User message content: plain [text], or a multimodal array of [images] followed by the text part. */
      private fun userContent(text: String, images: List<JSONObject>): Any =
          if (images.isEmpty()) {
              text
          } else {
              JSONArray().apply {
                  images.forEach { put(it) }
                  put(JSONObject().put("type", "text").put("text", text))
              }
          }

      /**
       * Converts the chat history into chat-completions message objects.
       *
       * - Sent IMAGE messages are buffered and attached to the next sent TEXT message. Each image is
       *   sent once; images never followed by a user text message are dropped.
       * - Sent TEXT messages become `user` messages.
       * - Messages not sent by the user become text-only `assistant` messages.
       * - Sent messages of any other type are ignored.
       */
      private fun constructMessageForAPICall(
          conversationHistory: ArrayList<Message>,
          ctx: Context
      ): List<JSONObject> {
          val messages = ArrayList<JSONObject>()
          val pendingImages = ArrayList<JSONObject>()
          for (chat in conversationHistory) {
              when {
                  !chat.isSent -> {
                      messages += JSONObject().put("role", "assistant").put("content", chat.text.orEmpty())
                  }
                  chat.messageType == MessageType.IMAGE -> {
                      imageContentPart(ctx.contentResolver, chat.imagePath)?.let { pendingImages += it }
                  }
                  chat.messageType == MessageType.TEXT -> {
                      val content = userContent(chat.text.orEmpty(), pendingImages)
                      messages += JSONObject().put("role", "user").put("content", content)
                      // The images now belong to this message; don't repeat them on later turns.
                      pendingImages.clear()
                  }
              }
          }
          return messages
      }

      private companion object {
          const val TAG = "LlamaRemoteInference"
          const val CHAT_COMPLETIONS_PATH = "/v1/chat/completions"
          const val SSE_DATA_PREFIX = "data:"
          const val SSE_DONE = "[DONE]"
          const val DEFAULT_IMAGE_MIME_TYPE = "image/jpeg"

          /**
           * Single client shared by all requests, since each OkHttpClient owns its own connection and
           * thread pools. The read timeout is disabled because a streamed completion may stay
           * silent for a long time between chunks.
           */
          val streamingClient: OkHttpClient by lazy {
              OkHttpClient.Builder()
                  .readTimeout(0, TimeUnit.MILLISECONDS)
                  .build()
          }
      }
  }