Pre-Requisite: Static Typing
Before using this guide, familiarise yourself with the typing module in Python.
Structured Output
LangChain simplifies obtaining structured output from an LLM.
It builds on Pydantic, a popular data validation library.
1. Define Pydantic Schema:
Create a class that inherits from BaseModel and define your desired output structure using type annotations.
The Field function lets you attach a description to each field, which gives the LLM context about what data should go into it.
from pydantic import BaseModel, Field

class Marks(BaseModel):
    english: int = Field(..., description="Marks for English")
    math: int = Field(..., description="Marks for Math")
    science: int = Field(..., description="Marks for Science")
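To see what context the descriptions actually provide, it can help to inspect the JSON schema that Pydantic generates from the class. The sketch below assumes Pydantic v2 (consistent with the model_dump() calls used later in this guide) and redefines Marks so that it runs on its own. How LangChain passes this schema to the provider may vary by model and version.

```python
from pydantic import BaseModel, Field, ValidationError


class Marks(BaseModel):
    english: int = Field(..., description="Marks for English")
    math: int = Field(..., description="Marks for Math")
    science: int = Field(..., description="Marks for Science")


# The generated JSON schema carries each field's type and description.
schema = Marks.model_json_schema()
for name, spec in schema["properties"].items():
    print(name, spec["type"], "-", spec["description"])

# Pydantic also validates data against the type annotations.
try:
    Marks(english="fifty", math=75, science=70)
except ValidationError as exc:
    print(f"Rejected: {exc.error_count()} validation error(s)")
```

The same validation applies to whatever the LLM returns, so a malformed answer is rejected rather than silently accepted.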
2. Bind Schema to Chat Model:
Use the .with_structured_output() method so that the LLM returns data in your specified format:
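from langchain_openai import ChatOpenAI

# ChatOpenAI reads the API key from the OPENAI_API_KEY environment variable
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Wrap the model so that responses are parsed into Marks objects
llm_with_tools = llm.with_structured_output(Marks)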
Testing the LLM
Let's invoke the LLM with a prompt containing different subject marks and see the result:
response = llm_with_tools.invoke("""
George scored 50 marks which shows great improvement from the previous semester.
However, he didn't do as well for Math but managed to procure 75 marks.
The Science marks were moderated now and he scored 62 + 8 marks.
""")
print(response)
print(response.model_dump()) # Convert to Dictionary
print(response.model_dump_json()) # Convert to JSON
Output:
english=50 math=75 science=70
{'english': 50, 'math': 75, 'science': 70}
{"english":50,"math":75,"science":70}
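The conversion methods can be tried without calling the LLM. A minimal sketch, assuming Pydantic v2 and a hand-built Marks instance standing in for the model's response:

```python
from pydantic import BaseModel, Field


class Marks(BaseModel):
    english: int = Field(..., description="Marks for English")
    math: int = Field(..., description="Marks for Math")
    science: int = Field(..., description="Marks for Science")


# Hand-built stand-in for the object returned by the LLM (no API call).
response = Marks(english=50, math=75, science=70)

as_dict = response.model_dump()       # plain Python dictionary
as_json = response.model_dump_json()  # JSON-encoded string

# Fields are ordinary attributes, so they can be used directly.
total = response.english + response.math + response.science
print(total)

# A JSON string can be parsed back into a validated Marks object.
restored = Marks.model_validate_json(as_json)
print(restored == response)
```

Because the response is a typed object rather than free text, downstream code can rely on each field being an int.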
Hands On: Perfect Squares
You are given this prompt:
"1,2,3,4,5,6,7,8,9,10,22,33,44,88,101,222,64"
Your task is to separate the perfect squares from the non-perfect squares.
Create a Pydantic model for the task, bind it to the LLM, and invoke the LLM with the prompt above.
Solution
from pydantic import BaseModel, Field
from typing import List

# Define Pydantic Model
class Numbers(BaseModel):
    perfect_squares: List[int] = Field(..., description="All perfect squares should be here")
    non_perfect_squares: List[int] = Field(..., description="All non-perfect squares should be here")

from langchain_openai import ChatOpenAI
# Instantiate LLM
llm = ChatOpenAI(model = 'gpt-4o-mini', temperature = 0)
# Bind Output Rules to LLM
llm_with_structured_output = llm.with_structured_output(Numbers)
# Invoke LLM with the prompt
response = llm_with_structured_output.invoke("1,2,3,4,5,6,7,8,9,10,22,33,44,88,101,222,64")
# Show Results
print(response)
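Because the LLM's answer is not guaranteed to be correct, the result is worth checking locally. The sketch below is not part of the original solution: it uses a hypothetical helper, is_perfect_square, built on math.isqrt from the standard library, to compute the expected split for the same input.

```python
import math


def is_perfect_square(n: int) -> bool:
    """Return True if n is the square of a non-negative integer."""
    return n >= 0 and math.isqrt(n) ** 2 == n


prompt = "1,2,3,4,5,6,7,8,9,10,22,33,44,88,101,222,64"
numbers = [int(part) for part in prompt.split(",")]

# Split the input deterministically, preserving the original order.
expected_squares = [n for n in numbers if is_perfect_square(n)]
expected_non_squares = [n for n in numbers if not is_perfect_square(n)]

print(expected_squares)      # [1, 4, 9, 64]
print(expected_non_squares)
```

Comparing sorted(response.perfect_squares) with sorted(expected_squares) then shows whether the model classified every number correctly.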