Prerequisite: Static Typing
Before using this guide, familiarise yourself with Python's typing module.
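As a quick refresher, the sketch below shows the kind of annotations this guide relies on (`List[int]`, `Optional[str]`). The `Student` class and `average` function are hypothetical illustrations. Plain Python only records these hints and does not check them at runtime; that is the gap Pydantic fills.

```python
from typing import List, Optional, get_type_hints


class Student:
    # Class-level annotations: recorded by Python, but not enforced
    name: str
    marks: List[int]
    remarks: Optional[str] = None  # may be a str or None


def average(marks: List[int]) -> float:
    """Return the mean of a list of integer marks."""
    return sum(marks) / len(marks)


# Inspect the recorded annotations
print(get_type_hints(Student))
print(average([50, 75, 70]))  # 65.0
```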
Structured Output
LangChain simplifies the process of obtaining structured output from an LLM.
It builds on Pydantic, a data validation library used by many Python frameworks.
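A minimal sketch of what "data validation" means in practice, assuming Pydantic v2 is installed: Pydantic coerces compatible input to the annotated type and raises `ValidationError` when it cannot. The `Score` model is a hypothetical example.

```python
from pydantic import BaseModel, ValidationError


class Score(BaseModel):
    subject: str
    mark: int


# The string "75" is coerced to the int 75 (Pydantic's default lax mode)
ok = Score(subject="Math", mark="75")
print(ok.mark, type(ok.mark).__name__)  # 75 int

# A value that cannot be read as an int is rejected with ValidationError
rejected = False
try:
    Score(subject="Math", mark="seventy-five")
except ValidationError as err:
    rejected = True
    print(err.error_count(), "validation error")  # 1 validation error
```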
Pydantic Schema
Define a class that inherits from BaseModel and declare the output keys you want as type-annotated attributes.
Use Field to add a description to each key, which tells the LLM what data belongs there.
from pydantic import BaseModel, Field
class Marks(BaseModel):
    # Field(...) marks each key as required; description guides the LLM
    english: int = Field(..., description="Marks for English")
    math: int = Field(..., description="Marks for Math")
    science: int = Field(..., description="Marks for Science")
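The `description` strings are more than comments: they end up in the model's JSON Schema, which is roughly what a structured-output integration hands to the LLM. A quick way to inspect this, assuming Pydantic v2, is `model_json_schema()`:

```python
from pydantic import BaseModel, Field


class Marks(BaseModel):
    english: int = Field(..., description="Marks for English")
    math: int = Field(..., description="Marks for Math")
    science: int = Field(..., description="Marks for Science")


# Generate the JSON Schema for the model
schema = Marks.model_json_schema()
print(schema["required"])  # ['english', 'math', 'science']
print(schema["properties"]["english"])
```

Passing `...` (Ellipsis) as the default is what makes every key appear under `required`.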
Binding Schema to LLM
To make the LLM return output that conforms to this schema, call the .with_structured_output()
method and pass in the Pydantic model.
from langchain_openai import ChatOpenAI
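# temperature=0 reduces randomness in the responses
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Wrap the LLM so that each invocation returns a Marks instance
llm_with_tools = llm.with_structured_output(Marks)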
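Conceptually, `.with_structured_output()` asks the model to reply with data matching the schema (for OpenAI models via tool calling or a JSON-schema response format, depending on the `method` argument and the library version) and then parses that reply back into a `Marks` instance. The sketch below imitates only the parsing half, using a hard-coded, hypothetical reply string; it is not LangChain's internal code and makes no API call.

```python
from pydantic import BaseModel, Field, ValidationError


class Marks(BaseModel):
    english: int = Field(..., description="Marks for English")
    math: int = Field(..., description="Marks for Math")
    science: int = Field(..., description="Marks for Science")


# Hypothetical raw JSON text standing in for the model's reply
raw_reply = '{"english": 50, "math": 75, "science": 70}'
marks = Marks.model_validate_json(raw_reply)
print(marks)  # english=50 math=75 science=70

# A reply that omits a required key fails validation instead of passing silently
incomplete_reply = '{"english": 50, "math": 75}'
try:
    Marks.model_validate_json(incomplete_reply)
    parsed_ok = True
except ValidationError:
    parsed_ok = False
print(parsed_ok)  # False
```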
Testing the LLM
Let's invoke the LLM with a prompt containing different subject marks and see the result:
response = llm_with_tools.invoke("""
George scored 50 marks which shows great improvement from the previous semester.
However, he didn't do as well for Math but managed to procure 75 marks.
The Science marks were moderated now and he scored 62 + 8 marks.
""")
print(response)
print(dict(response)) # Convert to Dictionary
print(response.model_dump_json()) # Convert to JSON (.json() is deprecated in Pydantic v2)
Output:
english=50 math=75 science=70
{'english': 50, 'math': 75, 'science': 70}
{"english":50,"math":75,"science":70}
Hands-On: Perfect Squares
You are given this prompt:
"1,2,3,4,5,6,7,8,9,10,22,33,44,88,101,222,64"
Your task is to separate the numbers into perfect squares and non-perfect squares.
Create a Pydantic model for the task, bind it to the LLM, and invoke the LLM with the prompt above.
Solution
from pydantic import BaseModel, Field
from typing import List
# Define Pydantic Model
class Numbers(BaseModel):
    perfect_squares: List[int] = Field(..., description="All perfect squares should be here")
    non_perfect_squares: List[int] = Field(..., description="All non-perfect squares should be here")
from langchain_openai import ChatOpenAI
# Instantiate LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Bind Output Rules to LLM
llm_with_structured_output = llm.with_structured_output(Numbers)
# Invoke LLM with the prompt
response = llm_with_structured_output.invoke("1,2,3,4,5,6,7,8,9,10,22,33,44,88,101,222,64")
# Show Results
print(response)
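Because the LLM is doing the arithmetic here, its classification is worth checking against a deterministic result. A minimal stdlib sketch, using `math.isqrt` (Python 3.8+) as ground truth; `is_perfect_square`, `split_squares` and `check_classification` are hypothetical helpers, and you would pass in `response.perfect_squares` and `response.non_perfect_squares` from the solution.

```python
import math
from typing import List, Tuple


def is_perfect_square(n: int) -> bool:
    """True if n is a non-negative integer square (0, 1, 4, 9, ...)."""
    if n < 0:
        return False
    root = math.isqrt(n)
    return root * root == n


def split_squares(numbers: List[int]) -> Tuple[List[int], List[int]]:
    """Deterministically split numbers into (perfect squares, non-perfect squares)."""
    squares = [n for n in numbers if is_perfect_square(n)]
    others = [n for n in numbers if not is_perfect_square(n)]
    return squares, others


def check_classification(numbers: List[int], llm_squares: List[int], llm_others: List[int]) -> bool:
    """Compare an LLM's split against the ground truth, ignoring order."""
    squares, others = split_squares(numbers)
    return sorted(llm_squares) == sorted(squares) and sorted(llm_others) == sorted(others)


prompt = "1,2,3,4,5,6,7,8,9,10,22,33,44,88,101,222,64"
numbers = [int(x) for x in prompt.split(",")]
print(split_squares(numbers))
# ([1, 4, 9, 64], [2, 3, 5, 6, 7, 8, 10, 22, 33, 44, 88, 101, 222])
```

Usage with the solution above: `check_classification(numbers, response.perfect_squares, response.non_perfect_squares)`.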