* Added a `Guidance` method to `LLamaTokenDataArray` which applies classifier free guidance
* Factored out a safer `llama_sample_apply_guidance` method based on spans
* Created a guided sampling demo using the batched executor
* Fixed comment, "classifier free" not "context free"
* Rebased onto master and fixed breakage due to changes in `BaseSamplingPipeline`
* Asking user for guidance weight
* Progress bar in batched fork demo
* Improved fork example (using tree display)
* Added proper disposal of resources in batched examples
* Added some more comments in BatchedExecutorGuidance
1 parent 364259a · commit 7d84625 · Showing 6 changed files with 250 additions and 34 deletions.
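The guidance applied in the new example is classifier-free guidance: at each step the logits of the positive ("guided") sequence are blended with the logits of the negative ("guidance") sequence. As a minimal sketch of that blend, assuming the standard classifier-free guidance formula (an illustration only, not the library's native implementation, which may also normalise both sets of logits, e.g. with a log-softmax, before blending):

    // Sketch of the classifier-free guidance blend (illustrative, not the native code).
    // weight == 1 leaves the guided logits unchanged; weight > 1 pushes the
    // distribution further away from whatever the negative prompt made likely.
    static void ApplyGuidanceSketch(Span<float> logits, ReadOnlySpan<float> guidanceLogits, float weight)
    {
        for (var i = 0; i < logits.Length; i++)
            logits[i] = guidanceLogits[i] + weight * (logits[i] - guidanceLogits[i]);
    }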
@@ -0,0 +1,128 @@ (new file: the BatchedExecutorGuidance example shown below)
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        using var guided = executor.Prompt(positivePrompt);
        using var guidance = executor.Prompt(negativePrompt);

        // Run inference to evaluate prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        using var unguided = guided.Fork();

        // Run inference loop
        var unguidedSampler = new GuidedSampler(null, weight);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
                    // The prompts were already evaluated above, so skip inference on the first iteration
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt,
                    // without any guidance. This serves as a comparison to show the effect of guidance.
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer generation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt)
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    // Update progress bar
                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
    }

    private class GuidedSampler(Conversation? guidance, float weight)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
        }

        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            // With no guidance conversation this pipeline is a plain sampler, so leave the logits unchanged
            if (guidance == null)
                return logits;

            // Copy the logits, so applying guidance does not modify the span passed in
            var logitsCopy = logits.ToArray();

            // Get the logits generated by the guidance sequence
            var guidanceLogits = guidance.Sample();

            // Use those logits to guide this sequence
            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight);

            return logitsCopy;
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);

            return candidates.SampleToken(ctx);
        }
    }
}
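The commit message also mentions factoring out a safer, span-based `llama_sample_apply_guidance`, which the `ProcessLogits` override above calls through `NativeApi`. A minimal sketch of what such a span-based wrapper might look like, assuming a raw entry point that takes float pointers (the name `llama_sample_apply_guidance_native` and its exact signature are illustrative assumptions, not the actual LLamaSharp declarations):

    // Hypothetical sketch: wrap a raw pointer-based native call in a span-based API.
    // Spans let the wrapper validate lengths up front and pin memory only for the
    // duration of the call, which is what makes this "safer" than raw pointers.
    public static unsafe void llama_sample_apply_guidance(
        SafeLLamaContextHandle ctx,
        Span<float> logits,
        ReadOnlySpan<float> guidanceLogits,
        float weight)
    {
        // Both spans must hold one logit per vocabulary token
        if (logits.Length != guidanceLogits.Length)
            throw new ArgumentException("logits and guidanceLogits must be the same length");

        fixed (float* logitsPtr = logits)
        fixed (float* guidancePtr = guidanceLogits)
        {
            // Assumed raw P/Invoke: llama_sample_apply_guidance(ctx, float*, float*, float)
            llama_sample_apply_guidance_native(ctx, logitsPtr, guidancePtr, weight);
        }
    }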