rsa_dspy

An implementation of Recursive Self-Aggregation (RSA) built on DSPy modules, based on the paper at https://rsa-llm.github.io/

source

RSACandidate


def RSACandidate(
    id:str, loop_id:int, task_prompt:str, signature:type=None, candidates_str:str=None, response:str=None,
    parent_ids:list=None
):

A candidate response in the RSA algorithm


source

RSA


def RSA(
    task_prompt:str, # The main task/question to solve
    solver:NoneType=None, # task signature
    aggregator:NoneType=None, # aggregator signature
    N:int=4, # Population size (candidates per loop)
    K:int=3, # Number of candidates to aggregate
    loops:int=2, # Number of aggregation loops
    history:list=None, # History of all candidates
):

Recursive Self-Aggregation algorithm using DSPy

a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?', solver=TaskSolver, aggregator=AggregateResponses)
print(a)
c1 = RSACandidate(id='c1', loop_id=0, task_prompt='test', response='Answer A')
c2 = RSACandidate(id='c2', loop_id=0, task_prompt='test', response='Answer B')

print(a._mk_candidate_str([c1, c2]))

source

RSA.get_prompts


def get_prompts(
    loop_id, cands:NoneType=None
):

Generate candidate prompts for a given loop: N initial solver prompts for loop 0, or, for later loops, N aggregation prompts that each combine K prior candidates (chosen from the C(n,K) possible combinations)

# Test loop 0
cands = a.get_prompts(loop_id=0)
test_eq(len(cands), a.N)
test_eq(cands[0].signature, a.solver)
# Test loop 1+ (with prior candidates)
prior = L(RSACandidate(id=str(uuid.uuid4()), loop_id=0, task_prompt='test', response=f'Answer {i}') for i in range(8))
cands = a.get_prompts(loop_id=1, cands=prior)
test_eq(len(cands), a.N)
print(cands[0].task_prompt)

Configuration

RSA-DSPy uses DSPy for LLM calls. Configure your LM globally with `dspy.configure`:

dspy.configure(lm=dspy.LM('openrouter/google/gemini-3-flash-preview', temperature=1.0))

See DSPy’s LM documentation for supported providers.

dspy.configure(lm=dspy.LM('openrouter/google/gemini-3-flash-preview', temperature=1.0, cache=False))

cands = a._run_loop(loop_id=0)
test_eq(len(cands), a.N)
assert all(c.response is not None for c in cands)
assert cands[0].response != cands[1].response
cands[0].response

source

RSA.run


def run(
    
):

Run the full RSA algorithm for the configured number of loops and return the final candidate pool

a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?', solver=TaskSolver, aggregator=AggregateResponses, loops=2)
result = a.run()
print(f"Final pool: {len(result)}, History: {len(a.history)}")

source

RSA.final_aggregate


def final_aggregate(
    method:str='llm', signature:NoneType=None
):

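Final aggregation of all final-loop candidates into a single `dspy.Prediction`: by default (`method='llm'`) one LLM call aggregates them; `method='random'` is also supported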

# Test 'llm' aggregation
result = a.final_aggregate(method='llm')
assert isinstance(result, dspy.Prediction)
assert len(result.response) > 0
print(result)

# Test 'random' aggregation
result = a.final_aggregate(method='random')
assert isinstance(result, dspy.Prediction)
assert len(result.response) > 0