Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.ai.ondevice

import android.graphics.Bitmap
import com.google.firebase.ai.ondevice.interop.Candidate
import com.google.firebase.ai.ondevice.interop.CountTokensResponse
import com.google.firebase.ai.ondevice.interop.FinishReason
Expand All @@ -24,14 +25,15 @@ import com.google.firebase.ai.ondevice.interop.GenerateContentResponse
import com.google.mlkit.genai.prompt.GenerateContentRequest
import com.google.mlkit.genai.prompt.ImagePart
import com.google.mlkit.genai.prompt.TextPart
import kotlin.math.min

// ====================================
// `Part` converter extension functions
// ====================================
internal fun com.google.firebase.ai.ondevice.interop.TextPart.toMlKit(): TextPart = TextPart(text)

internal fun com.google.firebase.ai.ondevice.interop.ImagePart.toMlKit(): ImagePart =
ImagePart(bitmap)
ImagePart(downsizeBitmapIfNeeded(bitmap))

// ============================================
// `CountTokens*` converter extension functions
Expand Down Expand Up @@ -87,3 +89,24 @@ private fun generateContentRequest(
builder.init()
return builder.build()
}

private fun downsizeBitmapIfNeeded(bitmap: Bitmap): Bitmap {
val IMAGE_SHORTER_DIMENSION_MAX_VALUE: Int = 768
val width = bitmap.width
val height = bitmap.height
val shorterDimension: Int = min(width, height)
if (shorterDimension <= IMAGE_SHORTER_DIMENSION_MAX_VALUE) {
return bitmap
}

val scaleFactor = (IMAGE_SHORTER_DIMENSION_MAX_VALUE.toDouble()) / shorterDimension

val newWidth = (width * scaleFactor).toInt()
val newHeight = (height * scaleFactor).toInt()

val resizedBitmap = Bitmap.createScaledBitmap(bitmap, newWidth, newHeight, /* filter= */ false)
if (resizedBitmap != bitmap) {
bitmap.recycle()
}
return resizedBitmap
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import com.google.firebase.ai.type.GenerateContentResponse
import com.google.firebase.ai.type.GenerateObjectResponse
import com.google.firebase.ai.type.JsonSchema
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.onEach

/**
* A [GenerativeModelProvider] that delegates requests to a `defaultModel` and falls back to a
Expand Down Expand Up @@ -59,7 +62,7 @@ internal class FallbackGenerativeModelProvider(
}

override fun generateContentStream(prompt: List<Content>): Flow<GenerateContentResponse> {
return withFallback("generateContentStream") { generateContentStream(prompt) }
return withFlowFallback("generateContentStream") { generateContentStream(prompt) }
}

override suspend fun <T : Any> generateObject(
Expand Down Expand Up @@ -104,6 +107,49 @@ internal class FallbackGenerativeModelProvider(
}
}

// Flow-based fallback differs significantly from regular call fallback, primarily due to:
//
// 1. Exception Timing: Exceptions are thrown during *flow collection*, not at flow creation.
// Therefore, a *wrapper flow* is utilized to manage emitting either default or fallback
// elements.
// 2. Partial Collection Rule: If a flow collection has *successfully started* before an exception
// occurs, *no fallback* should be triggered. This prevents inconsistent or mixed responses,
// where partial data from one source could be combined with data from another.
private inline fun <T> withFlowFallback(
methodName: String,
crossinline block: GenerativeModelProvider.() -> Flow<T>
): Flow<T> {
if (!precondition()) {
Log.w(
TAG,
"Precondition was not met, switching to fallback model `${fallbackModel.javaClass.simpleName}`"
)
return fallbackModel.block()
}
return flow {
var hasEmitted = false
val defaultFlow = defaultModel.block().onEach { hasEmitted = true }
try {
emitAll(defaultFlow)
} catch (e: Exception) {
if (
!hasEmitted &&
shouldFallbackInException &&
(e is FirebaseAIException || e is FirebaseAIOnDeviceException)
) {
Log.w(
TAG,
"Error running `$methodName` on `${defaultModel.javaClass.simpleName}`. Falling back to `${fallbackModel.javaClass.simpleName}`",
e
)
emitAll(fallbackModel.block())
} else {
throw e
}
}
}
}

companion object {
private val TAG = FallbackGenerativeModelProvider::class.java.simpleName
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ public class FunctionDeclaration(
internal data class Internal(
val name: String,
val description: String,
val parameters: Schema.InternalOpenAPI?,
val parametersJsonSchema: Schema.InternalJson?,
val responseJsonSchema: Schema.InternalJson?,
val parameters: Schema.InternalOpenAPI? = null,
val parametersJsonSchema: Schema.InternalJson? = null,
val responseJsonSchema: Schema.InternalJson? = null,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,24 @@ package com.google.firebase.ai.generativemodel
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.CountTokensResponse
import com.google.firebase.ai.type.FirebaseAIException
import com.google.firebase.ai.type.FirebaseAIOnDeviceInvalidRequestException
import com.google.firebase.ai.type.GenerateContentResponse
import com.google.firebase.ai.type.GenerateObjectResponse
import com.google.firebase.ai.type.JsonSchema
import com.google.firebase.ai.type.PromptBlockedException
import com.google.firebase.ai.type.PublicPreviewAPI
import io.kotest.assertions.throwables.shouldNotThrowAnyUnit
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.assertions.throwables.shouldThrowUnit
import io.kotest.matchers.shouldBe
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.runBlocking
import org.junit.Before
Expand All @@ -51,10 +59,12 @@ internal class FallbackGenerativeModelProviderTests {
coEvery { defaultModel.generateContent(prompt) } returns expectedResponse

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
val response = provider.generateContent(prompt)
shouldNotThrowAnyUnit {
val response = provider.generateContent(prompt)

response shouldBe expectedResponse
coVerify(exactly = 0) { fallbackModel.generateContent(any()) }
response shouldBe expectedResponse
coVerify(exactly = 0) { fallbackModel.generateContent(any()) }
}
}

@Test
Expand All @@ -64,27 +74,65 @@ internal class FallbackGenerativeModelProviderTests {

val provider =
FallbackGenerativeModelProvider(defaultModel, fallbackModel, precondition = { false })
val response = provider.generateContent(prompt)
shouldNotThrowAnyUnit {
val response = provider.generateContent(prompt)

response shouldBe expectedResponse
coVerify(exactly = 0) { defaultModel.generateContent(any()) }
coVerify { fallbackModel.generateContent(prompt) }
response shouldBe expectedResponse
coVerify(exactly = 0) { defaultModel.generateContent(any()) }
coVerify { fallbackModel.generateContent(prompt) }
}
}

@Test
fun `generateContent falls back when default model throws FirebaseAIException`() = runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
val exception = mockk<FirebaseAIException>()
// Test using an exception that extends FirebaseAIException
val exception = mockk<PromptBlockedException>()
coEvery { defaultModel.generateContent(prompt) } throws exception
coEvery { fallbackModel.generateContent(prompt) } returns expectedResponse

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
val response = provider.generateContent(prompt)
shouldNotThrowAnyUnit {
val response = provider.generateContent(prompt)

response shouldBe expectedResponse
coVerify { fallbackModel.generateContent(prompt) }
response shouldBe expectedResponse
coVerify { fallbackModel.generateContent(prompt) }
}
}

@OptIn(PublicPreviewAPI::class)
@Test
fun `generateContent shouldn't falls back when default model throws unrelated exception`(): Unit =
runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
// Test using an exception that extends FirebaseAIOnDeviceException
val exception = mockk<ArithmeticException>()
coEvery { defaultModel.generateContent(prompt) } throws exception
coEvery { fallbackModel.generateContent(prompt) } returns expectedResponse

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
shouldThrowUnit<ArithmeticException> { provider.generateContent(prompt) }
}

@OptIn(PublicPreviewAPI::class)
@Test
fun `generateContent falls back when default model throws FirebaseAIOnDeviceException`() =
runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
// Test using an exception that extends FirebaseAIOnDeviceException
val exception = mockk<FirebaseAIOnDeviceInvalidRequestException>()
coEvery { defaultModel.generateContent(prompt) } throws exception
coEvery { fallbackModel.generateContent(prompt) } returns expectedResponse

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
shouldNotThrowAnyUnit {
val response = provider.generateContent(prompt)

response shouldBe expectedResponse
coVerify { fallbackModel.generateContent(prompt) }
}
}

@Test
fun `generateContent rethrows FirebaseAIException when fallback is disabled`() = runBlocking {
val exception = mockk<FirebaseAIException>()
Expand Down Expand Up @@ -120,41 +168,108 @@ internal class FallbackGenerativeModelProviderTests {
coEvery { fallbackModel.countTokens(prompt) } returns expectedResponse

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
val response = provider.countTokens(prompt)
shouldNotThrowAnyUnit {
val response = provider.countTokens(prompt)

response shouldBe expectedResponse
coVerify { fallbackModel.countTokens(prompt) }
response shouldBe expectedResponse
coVerify { fallbackModel.countTokens(prompt) }
}
}

@Test
fun `generateContentStream falls back when default model throws FirebaseAIException`() =
runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
val fallbackFlow = flowOf(expectedResponse)
val exception = mockk<FirebaseAIException>()
every { defaultModel.generateContentStream(prompt) } throws exception
// Test using an exception that extends FirebaseAIOnException
val exception = mockk<PromptBlockedException>()
// throw the exception during the flow collection
every { defaultModel.generateContentStream(prompt) } returns flow { throw exception }
every { fallbackModel.generateContentStream(prompt) } returns fallbackFlow

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
shouldNotThrowAnyUnit {
val responseFlow = provider.generateContentStream(prompt)

responseFlow.first() shouldBe expectedResponse
verify { fallbackModel.generateContentStream(prompt) }
}
}

@OptIn(PublicPreviewAPI::class)
@Test
fun `generateContentStream falls back when default model throws FirebaseAIOnDeviceException`() =
runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
val fallbackFlow = flowOf(expectedResponse)
// Test using an exception that extends FirebaseAIOnDeviceException
val exception = mockk<FirebaseAIOnDeviceInvalidRequestException>()
// throw the exception during the flow collection
every { defaultModel.generateContentStream(prompt) } returns flow { throw exception }
every { fallbackModel.generateContentStream(prompt) } returns fallbackFlow

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
val responseFlow = provider.generateContentStream(prompt)
shouldNotThrowAnyUnit {
val responseFlow = provider.generateContentStream(prompt)

responseFlow shouldBe fallbackFlow
verify { fallbackModel.generateContentStream(prompt) }
responseFlow.first() shouldBe expectedResponse
verify { fallbackModel.generateContentStream(prompt) }
}
}

@Test
fun `generateContentStream rethrows non-FirebaseAIException`() = runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
val fallbackFlow = flowOf(expectedResponse)
val exception = mockk<ArithmeticException>()
// throw the exception during the flow collection
every { defaultModel.generateContentStream(prompt) } returns flow { throw exception }
every { fallbackModel.generateContentStream(prompt) } returns fallbackFlow

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
shouldThrow<ArithmeticException> { provider.generateContentStream(prompt).first() }

verify(exactly = 0) { fallbackModel.generateContentStream(prompt) }
}

@Test
fun `generateContentStream rethrows exception if a value was already emitted`() = runBlocking {
val expectedResponse: GenerateContentResponse = mockk()
val fallbackFlow = flowOf(expectedResponse)
val exception = mockk<PromptBlockedException>()
// throw the exception during the flow collection
every { defaultModel.generateContentStream(prompt) } returns
flow {
emit(expectedResponse)
throw exception
}
every { fallbackModel.generateContentStream(prompt) } returns fallbackFlow

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
// Even though it's an exception we can fall back from, we don't since a value has been
// generated
// already.
shouldThrow<PromptBlockedException> { provider.generateContentStream(prompt).collect() }

verify(exactly = 0) { fallbackModel.generateContentStream(prompt) }
}

@Test
fun `generateObject falls back when default model throws FirebaseAIException`() = runBlocking {
val schema: JsonSchema<Any> = mockk()
val expectedResponse: GenerateObjectResponse<Any> = mockk()
// Test using an exception that extends
val exception = mockk<FirebaseAIException>()
coEvery { defaultModel.generateObject(schema, prompt) } throws exception
coEvery { fallbackModel.generateObject(schema, prompt) } returns expectedResponse

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)
val response = provider.generateObject(schema, prompt)
shouldNotThrowAnyUnit {
val response = provider.generateObject(schema, prompt)

response shouldBe expectedResponse
coVerify { fallbackModel.generateObject(schema, prompt) }
response shouldBe expectedResponse
coVerify { fallbackModel.generateObject(schema, prompt) }
}
}

@Test
Expand All @@ -172,12 +287,12 @@ internal class FallbackGenerativeModelProviderTests {
fun `generateContentStream rethrows CancellationException and does not fall back`() =
runBlocking {
val exception = kotlinx.coroutines.CancellationException("cancelled")
every { defaultModel.generateContentStream(prompt) } throws exception
every { defaultModel.generateContentStream(prompt) } returns flow { throw exception }

val provider = FallbackGenerativeModelProvider(defaultModel, fallbackModel)

shouldThrow<kotlinx.coroutines.CancellationException> {
provider.generateContentStream(prompt)
provider.generateContentStream(prompt).first()
}
verify(exactly = 0) { fallbackModel.generateContentStream(any()) }
}
Expand Down
Loading
Loading