Dependency Injection in Workflows

Dependency injection provides a clean way to access shared resources in your workflows and activities. It automatically handles the creation and management of resources like database connections, API clients, or configuration objects, so you don't have to manually create and pass these around in your code.

How it works

The system supports multiple ways to provide dependencies, allowing you to choose the approach that best fits your resource's requirements:

  1. Synchronous functions - Ideal for simple configuration values or in-memory objects
  2. Asynchronous functions - Perfect for I/O-bound dependencies like database connections
  3. Context managers (sync or async) - Best for resources that require cleanup operations
  4. Generators (sync or async) - Useful for streaming data or complex resource management

When you declare a dependency in your activity or workflow, the system automatically provides that resource when needed, handling all the lifecycle management for you.
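
To make the mechanism concrete, here is a toy resolver in plain Python. It is a sketch of the general pattern only, not the library's implementation: `ToyDepends`, `call_with_dependencies`, `get_token`, and `greet` are hypothetical names invented for this illustration. The resolver scans a function's default values for markers, calls each provider (awaiting it when the provider is async), and passes the result as the argument.

```python
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyDepends:
    """Hypothetical marker: records which provider supplies a parameter."""
    provider: Callable[[], Any]


async def call_with_dependencies(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call an async function, filling every ToyDepends default from its provider."""
    signature = inspect.signature(func)
    bound = signature.bind_partial(*args, **kwargs)
    for name, param in signature.parameters.items():
        if name in bound.arguments or not isinstance(param.default, ToyDepends):
            continue  # the caller supplied it, or it is an ordinary parameter
        value = param.default.provider()
        if inspect.isawaitable(value):  # async provider: wait for its result
            value = await value
        bound.arguments[name] = value
    return await func(*bound.args, **bound.kwargs)


def get_config() -> dict:
    return {"timeout": 30}


async def get_token() -> str:
    await asyncio.sleep(0)  # stands in for real I/O
    return "token-123"


async def greet(
    name: str,
    config: dict = ToyDepends(get_config),
    token: str = ToyDepends(get_token),
) -> str:
    return f"{name}:{config['timeout']}:{token}"


print(asyncio.run(call_with_dependencies(greet, "Alice")))  # Alice:30:token-123
```

Passing an argument explicitly, for example `config={"timeout": 5}`, bypasses the provider in this sketch, which is one reason the default-value style is convenient for testing.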

Dependency Lifecycle and Resource Sharing

Important: All dependencies defined with Depends() are initialized once when the worker starts up and then shared across all activity executions. This means:

  • Single Instance: The same instance is reused for every activity call
  • Resource Efficiency: Reduces connection overhead and resource consumption
  • Connection Pooling: Database connections, API clients, and other resources are maintained and reused

This approach is particularly beneficial for:

  • Database connections: Avoids creating new connections for each query
  • API clients: Maintains persistent connections and authentication
  • Configuration objects: Loads configuration once and shares it across all activities
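
The sharing behaviour described above can be pictured as a cache keyed by provider function. The sketch below is an assumption-laden illustration, not the worker's real internals: `ProviderCache` and `get_client` are hypothetical, and the cache fills lazily for brevity, whereas the text above says the worker initializes dependencies at startup.

```python
from typing import Any, Callable, Dict


class ProviderCache:
    """Hypothetical worker-level cache: each provider runs at most once."""

    def __init__(self) -> None:
        self._instances: Dict[Callable[[], Any], Any] = {}

    def get(self, provider: Callable[[], Any]) -> Any:
        if provider not in self._instances:
            self._instances[provider] = provider()  # first request creates it
        return self._instances[provider]  # later requests reuse it


calls = 0


def get_client() -> dict:
    """Counts how often the provider body actually runs."""
    global calls
    calls += 1
    return {"connection_id": calls}


cache = ProviderCache()
first = cache.get(get_client)
second = cache.get(get_client)
print(first is second, calls)  # True 1
```

One consequence of a single shared instance is that any state an activity mutates on the dependency is visible to every later activity call, so shared objects are best kept stateless or safe for concurrent use.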

Defining Dependencies

Synchronous Function Provider

For simple configuration values or objects that don't require async initialization:

def get_config() -> dict:
    """Provides application configuration"""
    return {
        "timeout": 30,
        "retries": 3,
        "api_url": "https://api.example.com"
    }
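
Because a dependency instance is shared across activity calls, a plain `dict` could be modified by one activity and silently affect the others. One possible safeguard, shown here as a sketch, is to return a read-only view. `get_readonly_config` is a hypothetical variant of the provider above, and whether the workflow system places any constraints on the returned type is not covered here.

```python
from types import MappingProxyType
from typing import Any, Mapping


def get_readonly_config() -> Mapping[str, Any]:
    """Hypothetical variant of get_config that returns a read-only view."""
    return MappingProxyType({
        "timeout": 30,
        "retries": 3,
        "api_url": "https://api.example.com",
    })


config = get_readonly_config()
rejected = False
try:
    config["timeout"] = 60  # writes through the proxy are not allowed
except TypeError:
    rejected = True

print(config["timeout"], rejected)  # 30 True
```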

Asynchronous Function Provider

For dependencies that require async initialization, like database connections:

async def get_db_connection() -> DatabaseConnection:
    """Creates and returns a database connection"""
    # DatabaseConnection stands in for your database driver's connection class
    return await DatabaseConnection.create("postgres://user:pass@localhost/db")

Context Manager Provider

For resources that need proper cleanup, like sessions that require logout:

from contextlib import contextmanager

@contextmanager
def get_logged_in_session():
    """Provides a session with automatic login/logout"""
    session = Session()
    session.login()
    try:
        yield session
    finally:
        session.logout()

Async Context Manager Provider

For async resources that need cleanup, like database connections:

from contextlib import asynccontextmanager

@asynccontextmanager
async def get_db_connection_with_cleanup():
    """Provides a database connection with proper cleanup"""
    conn = await DatabaseConnection.create("postgres://...")
    try:
        yield conn
    finally:
        await conn.close()
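
To illustrate how setup and cleanup pair up for context manager providers, the sketch below uses the standard library's `AsyncExitStack` to enter a sync and an async provider once and unwind both at the end. This is an assumed model of a worker's startup and shutdown, not the library's actual code; `run_worker` and the `events` log are hypothetical.

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from typing import List

events: List[str] = []


@contextmanager
def get_session():
    events.append("login")
    try:
        yield {"logged_in": True}
    finally:
        events.append("logout")


@asynccontextmanager
async def get_connection():
    events.append("connect")
    try:
        yield {"connected": True}
    finally:
        events.append("close")


async def run_worker() -> None:
    async with AsyncExitStack() as stack:
        # "Startup": enter each provider once and keep the yielded resource.
        session = stack.enter_context(get_session())
        conn = await stack.enter_async_context(get_connection())
        assert session["logged_in"] and conn["connected"]
        events.append("activities run")
    # Leaving the block is "shutdown": cleanup runs in reverse order of setup.


asyncio.run(run_worker())
print(events)  # ['login', 'connect', 'activities run', 'close', 'logout']
```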

Generator Provider

For streaming data or complex resource management:

def get_data_stream():
    """Provides a stream of data"""
    with open("data.txt") as f:
        for line in f:
            yield line.strip()
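
The following plain-Python sketch shows how a generator like this behaves on its own: nothing is opened until the first item is requested, and the file is closed once the generator is exhausted. How the workflow system consumes a generator provider is not specified here, so this only demonstrates the language semantics. The `path` parameter and the temporary file are additions for the demo.

```python
import os
import tempfile
from typing import Iterator


def get_data_stream(path: str) -> Iterator[str]:
    """Same shape as the provider above, with the path made a parameter."""
    with open(path) as f:
        for line in f:
            yield line.strip()


# Create a small input file for the demonstration.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("alpha\n beta \ngamma\n")
    path = tmp.name

stream = get_data_stream(path)  # no file is opened yet
first = next(stream)            # the file opens and the first line is read
rest = list(stream)             # exhausting the generator closes the file
again = list(stream)            # an exhausted generator yields nothing more
os.remove(path)

print(first, rest, again)  # alpha ['beta', 'gamma'] []
```

The last line matters when thinking about sharing: a generator object can be consumed only once, unlike a connection or a configuration object.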

Async Generator Provider

For async streaming data sources:

import aiofiles  # third-party package: pip install aiofiles


async def get_async_data_stream():
    """Provides an async stream of data"""
    async with aiofiles.open("data.txt") as f:
        async for line in f:
            yield line.strip()

Using Dependencies in Activities

Activities can declare their dependencies using the Depends() marker. The system will automatically provide these dependencies when the activity executes:

import asyncio
import mistralai.workflows as workflows
from mistralai.workflows import Depends
from contextlib import contextmanager
from pydantic import BaseModel


def get_config() -> dict:
    return {"timeout": 30, "retries": 3, "api_url": "https://api.example.com"}


class FakeDB:
    def __init__(self):
        self.records = []

    def insert(self, name: str):
        self.records.append(name)
        return len(self.records)


async def get_db() -> FakeDB:
    return FakeDB()


@contextmanager
def get_session():
    session = {"logged_in": True, "events": []}
    try:
        yield session
    finally:
        session["logged_in"] = False


@workflows.activity()
async def create_user(
    name: str,
    db: FakeDB = Depends(get_db),
    config: dict = Depends(get_config),
    session: dict = Depends(get_session),
) -> dict:
    record_id = db.insert(name)
    session["events"].append("user_created")
    return {
        "status": "success",
        "record_id": record_id,
        "timeout_used": config["timeout"],
    }


class Input(BaseModel):
    name: str


@workflows.workflow.define(name="di_workflow")
class DIWorkflow:
    @workflows.workflow.entrypoint
    async def run(self, params: Input) -> dict:
        return await create_user(params.name)


async def main():
    result = await workflows.execute_workflow(
        DIWorkflow,
        params=Input(name="Alice"),
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(main())

Common Use Cases

Database Connections

A more complete database connection example with proper session management:

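from contextlib import asynccontextmanager
from typing import AsyncIterator

# Assumes SQLAlchemy 2.0's asyncio extension
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


async def get_db_connection():
    """Provides a session factory backed by a single shared engine"""
    # create_db_engine is a placeholder for your own engine setup
    engine = await create_db_engine("postgres://...")
    SessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)

    @asynccontextmanager
    async def get_session() -> AsyncIterator[AsyncSession]:
        # "async with" closes the session on exit, so no explicit close() is needed
        async with SessionLocal() as session:
            yield session

    return get_session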
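
The provider above returns a factory rather than a session, so the expensive setup is shared while each unit of work gets its own short-lived session. The sketch below mimics that pattern with the standard library only, so it can run without a database. `FakeSession`, `get_session_factory`, and the `log` parameter are hypothetical stand-ins, not part of any real API.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, List


class FakeSession:
    """Hypothetical stand-in for a database session."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    async def execute(self, statement: str) -> None:
        self.log.append(f"execute: {statement}")

    async def close(self) -> None:
        self.log.append("session closed")


async def get_session_factory(log: List[str]):
    """Provider sketch: one-time setup here, one session per use below."""
    log.append("engine created")  # stands in for the one-time engine setup

    @asynccontextmanager
    async def get_session() -> AsyncIterator[FakeSession]:
        session = FakeSession(log)
        try:
            yield session
        finally:
            await session.close()

    return get_session


async def main() -> List[str]:
    log: List[str] = []
    get_session = await get_session_factory(log)  # done once
    for name in ("Alice", "Bob"):
        async with get_session() as session:  # fresh session per unit of work
            await session.execute(f"INSERT {name}")
    return log


log = asyncio.run(main())
print(log)
```

The log shows a single "engine created" entry followed by an execute and close pair for each name, which is the division of labour the factory pattern is meant to provide.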

API Clients

Example of a payment service client with initialization:

async def get_payment_client() -> PaymentServiceClient:
    """Creates and initializes a payment service client"""
    client = PaymentServiceClient(api_key="your_key")
    await client.initialize()
    return client