SAMMO tutorial

SAMMO is a flexible, easy-to-use library for running and optimizing prompts for Large Language Models (LLMs). In this example, we will walk through how to optimize prompts for an LLM of our choosing on a BIG-bench task. This process can be repeated for as many LLMs as we want to route between.

SAMMO treats prompt optimization as a graph search problem. Each node in the graph represents a possible prompt, and edges are defined by mutators: operators that transform one prompt into another. Given a set of initial prompt candidates and a set of mutators, SAMMO searches the graph for the best possible prompt as evaluated by a specified metric.

Installation

pip install sammo notdiamond

Initialization

First, we'll create a new file and bring in the following import statements:

import sammo
from sammo.runners import OpenAIChat, BaseRunner, LLMResult
from sammo.utils import serialize_json
from sammo.base import Template, EvaluationScore
from sammo.components import Output, GenerateText, ForEach, Union
from sammo.extractors import ExtractRegex
from sammo.data import DataTable
from sammo.instructions import MetaPrompt, Section, Paragraph, InputData
from sammo.dataformatters import PlainFormatter
from sammo.search_op import one_of
from sammo.mutators import BagOfMutators, InduceInstructions, Paraphrase
from sammo.search import BeamSearch
from typing import Optional
from notdiamond import NotDiamond, LLMConfig
import json
import requests
import os

Next, we'll define the LLM we want to optimize our prompt for. In SAMMO, each LLM is referred to as a Runner.

_ = sammo.setup_logger("WARNING")  # we're only interested in warnings for now

runner = OpenAIChat(
    model_id="gpt-3.5-turbo",
    api_config={"api_key": os.environ['OPENAI_API_KEY']},
    timeout=30,
)
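
To confirm the runner is wired up correctly, you can send a single prompt through it. This is a minimal sketch modeled on SAMMO's hello-world example; the prompt text is arbitrary:

print(Output(GenerateText(Template("Say 'ok'."))).run(runner))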

SAMMO supports OpenAI and Azure out of the box, but if the LLM you are interested in is not available, you can create a custom runner. In fact, using the notdiamond package, you can create a custom runner for any LLM that Not Diamond supports:

class CustomRunner(BaseRunner):
    def __init__(
        self, llm_config: LLMConfig, api_key: Optional[str] = None, **kwargs
    ):
        super().__init__(**kwargs)
        self.llm = llm_config
        # the Not Diamond client falls back to the NOTDIAMOND_API_KEY
        # environment variable when no api_key is passed
        self.client = NotDiamond(api_key=api_key)

    async def generate_text(
        self,
        prompt: str,
        priority: int = 0,
        **kwargs,
    ) -> LLMResult:
        request = dict(
            messages=[{"role": "user", "content": prompt}],
        )
        # SAMMO uses this fingerprint to cache and deduplicate requests
        fingerprint = serialize_json(
            {"generative_model_id": self._model_id, **request}
        )
        return await self._execute_request(request, fingerprint, priority)

    async def _call_backend(self, request: dict) -> dict:
        # invoke the configured model through the Not Diamond client
        result, _, _ = self.client.chat.completions.create(
            messages=request["messages"], model=[self.llm]
        )
        return {"response": result.content}

    def _to_llm_result(
        self, request: dict, json_data: dict, fingerprint: str | bytes
    ):
        return LLMResult(json_data["response"])

llm = LLMConfig(
    provider="anthropic",
    model="claude-3-haiku-20240307",
    api_key="YOUR_ANTHROPIC_API_KEY"
)

runner = CustomRunner(
    model_id=f"{llm.provider}/{llm.model}",
    llm_config=llm,
    api_config={"api_key": llm.api_key},
    timeout=30,
)
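
As a quick sanity check, the same hello-world style call now exercises the custom path end to end, with _call_backend routing the request through Not Diamond to Claude 3 Haiku. This is a hedged sketch; it assumes NOTDIAMOND_API_KEY is set in your environment, since we did not pass api_key to CustomRunner:

print(Output(GenerateText(Template("Say 'ok'."))).run(runner))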

Next, we will use the implicatures task from BIG-bench as our dataset for demonstrating prompt optimization, and define accuracy as the evaluation metric we want to optimize.

def load_data(
    url="https://github.com/google/BIG-bench/raw/main/bigbench/benchmark_tasks/implicatures/task.json",
):
    task = json.loads(requests.get(url).content)
    # pick the highest-scoring target as the single output label
    for x in task["examples"]:
        x["output"] = max(x["target_scores"], key=x["target_scores"].get)

    return DataTable.from_records(
        task["examples"],
        input_fields="input",
    )


def accuracy(y_true: DataTable, y_pred: DataTable) -> EvaluationScore:
    y_true = y_true.outputs.normalized_values()
    y_pred = y_pred.outputs.normalized_values()
    n_correct = sum([y_p == y_t for y_p, y_t in zip(y_pred, y_true)])

    return EvaluationScore(n_correct / len(y_true))
  
  
mydata = load_data()
d_train = mydata.sample(10, seed=42)
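
For reference, each record in the implicatures task pairs a short two-speaker exchange with a yes/no label. After the label conversion above, a record looks roughly like this (an illustrative example, not a verbatim dataset entry):

{
    "input": "Speaker 1: 'Are you coming to the party?' Speaker 2: 'I wouldn't miss it for the world.'",
    "output": "yes",
}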

Defining the set of initial candidates

This is the initial prompt template:

"""
Instructions:
Does Speaker 2's answer mean yes or no?

Output labels: yes, no
{input}
Output:
"""

However, we want to optimize this prompt to maximize performance. We will use SAMMO's MetaPrompt module, which allows us to deconstruct the prompt into components and optimize each of them individually. Here we primarily use Paragraph to define each line of the prompt.

class InitialCandidates:
    def __init__(self, dtrain):
        self.dtrain = dtrain

    def __call__(self):
        labels = self.dtrain.outputs.unique()
        example_formatter = PlainFormatter(
            all_labels=labels, orient="item"
        )

        instructions = MetaPrompt(
            [
                Paragraph("Instructions: "),
                Paragraph(
                    one_of(
                        [
                            "Does Speaker 2's answer mean yes or no?",
                            "Find the best output label given the input.",
                        ]
                    ),
                    id="instructions",
                ),
                Paragraph("\n"),
                Paragraph(
                    f"Output labels: {', '.join(labels)}\n" if len(labels) <= 10 else ""
                ),
                Paragraph(InputData()),
                Paragraph("Output: "),
            ],
            render_as="raw",
            data_formatter=example_formatter,
        )

        return Output(
            instructions.with_extractor("raise"),
            minibatch_size=1,
            on_error="empty_result",
        )
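
InitialCandidates(d_train)() returns the search space of starting prompts; the optimizer below consumes it through the mutators. If you want a baseline score for these candidates before optimizing, here is a hedged sketch, assuming your SAMMO version exposes an EnumerativeSearch in sammo.search that enumerates one_of choices instead of mutating them:

from sammo.search import EnumerativeSearch  # assumption: check your SAMMO version

baseline = EnumerativeSearch(runner, InitialCandidates(d_train), accuracy)
baseline.fit(d_train)
baseline.show_report()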

The key part of the code above that allows us to specify the instruction as a tunable parameter is:

Paragraph(
  one_of(
    [
      "Does Speaker 2's answer mean yes or no?",
      "Find the best output label given the input.",
    ]
  ),
  id="instructions",
),

This snippet defines two possible initial instructions and assigns them id="instructions". Mutators can then reference this ID to target the part of the prompt template they should modify. The one_of operator means the search algorithm will use one of the two options as the starting point of the search.

Define a set of mutation operators

Next, we'll define our mutators: operators that each modify a specific part of the prompt. This is where we use the id to tell each mutator where to edit:

mutation_operators = BagOfMutators(
    InitialCandidates(d_train),
    InduceInstructions({"id": "instructions"}, d_train),
    Paraphrase({"id": "instructions"}),
    sample_for_init_candidates=False,
)

Run optimization

🚧

Running prompt optimization will incur inference costs

Prompt optimization makes many LLM calls, so inference costs scale with the number of candidates, mutations, and search depth.

Finally, we'll use Beam Search to iterate through mutations and optimize our prompts. If we'd like to leverage other strategies, SAMMO also provides additional search methods.

In each step of the beam search, SAMMO will sample a set of mutation operators and apply them to the current set of active candidates:

prompt_optimizer = BeamSearch(
    runner,
    mutation_operators,
    accuracy,
    maximize=True,
    depth=3,
    mutations_per_beam=2,
    n_initial_candidates=4,
    beam_width=4,
    add_previous=True,
)
prompt_optimizer.fit(d_train)
prompt_optimizer.show_report()

print("Best prompt")
print(prompt_optimizer.best_prompt)
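
Since we optimized on only 10 training examples, it's worth checking the winning prompt on held-out data. A hedged sketch, assuming best_prompt is a runnable Output component (as in SAMMO's own tutorials); the split below is hypothetical:

d_test = mydata.sample(100, seed=43)  # hypothetical held-out sample
y_pred = prompt_optimizer.best_prompt.run(runner, d_test)
print(accuracy(d_test, y_pred))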

Wrapping up

In this example, we showed how to optimize prompts for GPT-3.5 on the BIG-bench implicatures task. We can run the same process for any other LLM simply by swapping in a different runner. With optimized prompts and evaluation scores for each LLM, we can then train a custom router that always calls the best (prompt, model) combination.
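
To make that concrete, here is a hedged sketch of that loop using the CustomRunner defined earlier, collecting each model's best prompt for downstream routing. The candidate list, default BeamSearch settings, and empty api_config are illustrative assumptions:

candidates = [
    LLMConfig(provider="anthropic", model="claude-3-haiku-20240307"),
    LLMConfig(provider="openai", model="gpt-3.5-turbo"),
]

best_prompts = {}
for llm in candidates:
    runner = CustomRunner(
        model_id=f"{llm.provider}/{llm.model}",
        llm_config=llm,
        api_config={},  # illustrative; supply provider keys as needed
        timeout=30,
    )
    optimizer = BeamSearch(runner, mutation_operators, accuracy, maximize=True)
    optimizer.fit(d_train)
    best_prompts[f"{llm.provider}/{llm.model}"] = optimizer.best_prompt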