Guided Choice: Get only the answer you want with no fluff

In my last blog post we refactored the text sampling code in deliverance. I did that to prepare the code to take on something clever. You know how LLMs sometimes give you the answer inside a sea of text and never stop talking (like me)? One solution to that is "structured outputs", which lets you control what the inference engine will produce.

There are several forms of structured outputs. A simple one is guided choice: you give the inference engine a prompt and a list of choices, and it should answer with exactly one of them.

Something like:

prompt = "What is the best month for vacation?"
choices = ["January", "February"]

So let's code it up! As always, here is the code: Guided Choice commit

One thing to think about ahead of time (call it a "trick"): LLMs don't answer in words, they answer in tokens. If you ask an LLM a question like "Who is the best NFL team?", "Giants" might be in its vocabulary, or it might not. Even if "Giants" is in its pretrained vocabulary, the model may, for one reason or another, answer like this:

token[0] "Gian" 
token[1] "ts"

Also, as you iterate over the vocabulary, many candidates may match:

/*
candidate found token 12662 Gi Gi
candidate found token 65060 Gian Gian
candidate found token 235319 G G
return Gian
candidate found token 617 ts Giants
return ts
Giants
*/
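To make the token juggling concrete, here is a small standalone sketch (not code from deliverance) that lists which entries of a toy vocabulary could extend the text generated so far toward one of the choices. The class and method names are hypothetical, and apart from the four ids copied from the log above, the vocabulary is made up.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of deliverance: finds vocabulary entries
// that keep the generated text on the way to one of the choices.
class CandidateFilter {

    // True if `text` is still the beginning of at least one allowed choice.
    static boolean isViablePrefix(String text, List<String> choices) {
        return choices.stream().anyMatch(c -> c.startsWith(text));
    }

    // True once `text` spells out one of the choices exactly.
    static boolean isComplete(String text, List<String> choices) {
        return choices.contains(text);
    }

    // Ids of all non-empty tokens that keep `current` a viable prefix.
    static List<Integer> candidates(Map<Integer, String> vocab, String current, List<String> choices) {
        List<Integer> result = new ArrayList<>();
        for (Map.Entry<Integer, String> e : vocab.entrySet()) {
            String piece = e.getValue();
            if (!piece.isEmpty() && isViablePrefix(current + piece, choices)) {
                result.add(e.getKey());
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Toy vocabulary (TreeMap keeps ids sorted); 4000 and 5000 are made up.
        Map<Integer, String> vocab = new TreeMap<>();
        vocab.put(617, "ts");
        vocab.put(4000, "Jets");
        vocab.put(5000, "Sea");
        vocab.put(12662, "Gi");
        vocab.put(65060, "Gian");
        vocab.put(235319, "G");
        List<String> choices = List.of("Giants", "Jets", "Seahawks");

        System.out.println(candidates(vocab, "", choices));
        System.out.println(candidates(vocab, "Gian", choices));
        System.out.println(isComplete("Gian" + "ts", choices));
    }
}
```

With an empty buffer it reports "Jets", "Sea", "Gi", "Gian" and "G" as candidates; once the buffer holds "Gian", only "ts" survives, and "Gian" + "ts" spells a complete choice.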


This doesn't make things too difficult, at least in the simple cases, but the inference engine has to juggle what the sampler is returning and when it should stop (commonly called the STOP_REASON).

Let's look at some code. The first part is the existing sampler, where we do our dot products and logit scaling:

public int sample() {
    long start = System.nanoTime();
    try (AbstractTensor embedding = layerNorm.forward(output)) {
        long afterForward = System.nanoTime();
        forward1.update(Math.abs(afterForward - start));

        // DO DOT PRODUCT FOR EVERY CHUNK: one logit per vocabulary entry, computed in parallel chunks
        VectorMath.pchunk(0, abstractModel.config.vocabularySize, (chunkStart, chunkSize) -> {
            abstractModel.configurableTensorProvider.get()
                    .dotProductChunk(logits, embedding, abstractModel.sampleOutput.getOutputLogitsWeights(), 0,
                            abstractModel.config.embeddingLength, chunkStart, chunkSize);
        }, abstractModel.configurableTensorProvider.get().parallelSplitSize());
        long afterDotProductChunk = System.nanoTime();
        dotprod2.update(Math.abs(afterDotProductChunk - afterForward));

        // Optional model-specific rescaling of every logit
        if (abstractModel.config.logitMultiplier != null) {
            CausualWhisperer.LOGGER.debug("scaling logits logitMultiplier: {}", abstractModel.config.logitMultiplier);
            abstractModel.configurableTensorProvider.get().scale(1.0f / abstractModel.config.logitMultiplier,
                    logits, 0, abstractModel.config.vocabularySize);
        }
        // Best candidate so far: token id, its logit, and the text it produces
        int maxi = Integer.MIN_VALUE;
        double maxv = Double.NEGATIVE_INFINITY;
        String bestMatch = "";


Here is the part where it gets interesting:

As we walk through the logits for every token in the vocabulary, we have to decide how much we "like" each one. In this case the next token has to fit a specific shape (the user is guiding us to an outcome: "pick1" or "pick2"). We have our existing buffer of generated text; if this token can be appended to it and the result is still the start of one of the choices, we are headed in the right direction, so we accept that token and move forward.
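The filtering described above can be thought of as masking: every token whose text would push the buffer off all of the choices gets a logit of negative infinity, and the best remaining token wins. This sketch shows that idea in isolation on a made-up five-token vocabulary. It is a simplified, hypothetical version (plain greedy argmax) and leaves out the longer-match preference that the real loop in deliverance adds.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: guided choice as logit masking plus greedy argmax.
class MaskedGreedy {

    // Copies the logits and sets every token that would break all choices to -infinity.
    static float[] mask(float[] logits, String[] vocab, String current, List<String> choices) {
        float[] masked = Arrays.copyOf(logits, logits.length);
        for (int i = 0; i < vocab.length; i++) {
            String entire = current + vocab[i];
            boolean viable = !vocab[i].isEmpty()
                    && choices.stream().anyMatch(c -> c.startsWith(entire));
            if (!viable) {
                masked[i] = Float.NEGATIVE_INFINITY;
            }
        }
        return masked;
    }

    // Index of the highest finite logit, or -1 if every token was masked.
    static int argmax(float[] logits) {
        int best = -1;
        float bestValue = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < logits.length; i++) {
            if (logits[i] > bestValue) {
                bestValue = logits[i];
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] vocab = {"The", "Gian", "ts", "Jets", "Gi"};
        // Made-up scores: unconstrained, the model would start with "The".
        float[] logits = {9.0f, 4.0f, 1.0f, 3.5f, 2.0f};
        List<String> choices = List.of("Giants", "Jets");

        int next = argmax(mask(logits, vocab, "", choices));
        System.out.println(vocab[next]); // prints Gian
    }
}
```

After masking, "Gian" has the highest surviving score even though "The" scored higher overall. An argmax of -1 means no token can continue any choice, which is the case the real sampler handles with its fallback return.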

        for (int i = 0; i < abstractModel.config.vocabularySize; i++) {
            float v = logits.get(0, i);
            // Optional final logit soft-capping: v = cap * tanh(v / cap)
            if (abstractModel.config.finalLogitSoftCapping != null) {
                v /= abstractModel.config.finalLogitSoftCapping;
                v = (float) FastMath.tanh(v);
                v = v * abstractModel.config.finalLogitSoftCapping;
                logits.set(v, 0, i);
            }
            // Only tokens that beat the best logit so far are decoded and checked
            if (v > maxv) {
                String decodedToken = tokenizer.decode(i);
                String entire = current + decodedToken;
                // Accept only if buffer + token is still the start of some choice
                if (!decodedToken.isEmpty() && currentChoices.stream().anyMatch(ch -> ch.startsWith(entire))) {
                    LOG.debug("candidate found token {} {} {} ", i, decodedToken, entire);
                    // Prefer the candidate that yields the longer match
                    if (entire.length() > bestMatch.length()) {
                        maxi = i;
                        maxv = v;
                        bestMatch = entire;
                    }
                }
            }
        }
        if (maxi != Integer.MIN_VALUE) {
            return maxi;
        }
        // No token fits any choice: fall back to the last vocabulary id
        CausualWhisperer.LOGGER.debug("Reached end returning {}", abstractModel.config.vocabularySize - 1);
        return abstractModel.config.vocabularySize - 1;
    } finally {
        long end = System.nanoTime();
        fullSample.update(Math.abs(end - start));
    }
}
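Written out, the two logit adjustments in this sampler are a rescale by logitMultiplier (m below, assumed positive) and a tanh soft cap by finalLogitSoftCapping (c below). This is only a restatement of what the code does, not an extra step:

```latex
z_i \leftarrow \frac{z_i}{m}, \qquad
z_i \leftarrow c \,\tanh\!\left(\frac{z_i}{c}\right)
```

Because |tanh(x)| < 1, every capped logit lands strictly between -c and c, and for |z_i| much smaller than c the cap barely changes it, since tanh(x) ≈ x near zero. Both maps are increasing, so in exact arithmetic they do not change which token has the highest logit.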

 

Let's look at how we call this from AbstractModel:

if (generatorParameters.guidedChoice.isPresent()) {
    // Constrain sampling to the user's choices, given the text generated so far
    GuidedChoiceSampler sampler1 = new GuidedChoiceSampler(this, output, logits, sampleOutput.getOutputLayerNorm(),
            tokenizer, generatorParameters.guidedChoice.get(), responseText);
    next = sampler1.sample();
} else {
    GeneratorSampler sampler1 = new GeneratorSampler(this, output, temperature, random.nextFloat(), logits,
            sampleOutput.getOutputLayerNorm());
    next = sampler1.sample();
}
output.close();
kvmem.incrementContextPosition();
// The model emitted an end-of-sequence token
if (config.eosTokens.contains(next)) {
    reason = FinishReason.STOP_TOKEN;
    break;
}
...
// Guided choice: stop as soon as the response exactly equals one of the choices
if (generatorParameters.guidedChoice.isPresent()) {
    if (generatorParameters.guidedChoice.get().contains(responseText.toString())) {
        reason = FinishReason.STOP_TOKEN;
        break;
    }
}
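Putting the sampler and the two stop checks together, the control flow looks roughly like this sketch. Everything in it is hypothetical: the fake model returns the same made-up logits at every step, FinishReason is a local stand-in enum rather than deliverance's class, and MAX_TOKENS is an assumed extra reason for when the token budget runs out.

```java
import java.util.List;

// Hypothetical end-to-end sketch of a guided-choice decode loop with a fake model.
class GuidedLoopSketch {

    // Stand-in for the engine's finish reasons (MAX_TOKENS is an assumption).
    enum FinishReason { STOP_TOKEN, MAX_TOKENS }

    record Result(String text, FinishReason reason) {}

    static final String[] VOCAB = {"<eos>", "The", "Gian", "ts", "Jets", "Seahawks"};
    static final int EOS_ID = 0;

    // Fake model: the same made-up logits at every step.
    static float[] fakeLogits() {
        return new float[]{0.5f, 9.0f, 4.0f, 3.0f, 3.5f, 1.0f};
    }

    // Highest-scoring token that keeps `current` a prefix of some choice,
    // or EOS_ID when nothing fits (a design choice of this sketch).
    static int sample(float[] logits, String current, List<String> choices) {
        int best = -1;
        float bestValue = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < VOCAB.length; i++) {
            String entire = current + VOCAB[i];
            if (i != EOS_ID && logits[i] > bestValue
                    && choices.stream().anyMatch(c -> c.startsWith(entire))) {
                best = i;
                bestValue = logits[i];
            }
        }
        return best >= 0 ? best : EOS_ID;
    }

    static Result generate(List<String> choices, int maxTokens) {
        StringBuilder response = new StringBuilder();
        for (int step = 0; step < maxTokens; step++) {
            int next = sample(fakeLogits(), response.toString(), choices);
            // Stop check 1: the sampler produced an end-of-sequence token
            if (next == EOS_ID) {
                return new Result(response.toString(), FinishReason.STOP_TOKEN);
            }
            response.append(VOCAB[next]);
            // Stop check 2: the response now equals one of the choices
            if (choices.contains(response.toString())) {
                return new Result(response.toString(), FinishReason.STOP_TOKEN);
            }
        }
        return new Result(response.toString(), FinishReason.MAX_TOKENS);
    }

    public static void main(String[] args) {
        Result r = generate(List.of("Giants", "Jets", "Seahawks"), 10);
        System.out.println(r.text() + " " + r.reason()); // prints Giants STOP_TOKEN
    }
}
```

With the choices Giants, Jets and Seahawks it emits "Gian", then "ts", and stops with STOP_TOKEN once the response equals "Giants". One design difference: when nothing fits, this sketch returns an EOS id so the first check ends the loop, whereas the sampler above returns vocabularySize - 1, which may or may not be one of the model's EOS tokens.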

That seems too easy, doesn't it? Well, I guess that is the magic of math.

How does one use something like this? I am glad you asked. O mighty LLM, tell us which of these three NFL teams (Giants, Jets, Seahawks) does not play in NY?

@Test
public void gemmaGuidedTestNeg() throws IOException {
    ModelFetcher fetch = new ModelFetcher("tjake", "gemma-2-2b-it-JQ4");
    File f = fetch.maybeDownload();
    MetricRegistry mr = new MetricRegistry();
    TensorCache tensorCache = new TensorCache(mr);
    NativeSimdTensorOperations operation = new NativeSimdTensorOperations(new ConfigurableTensorProvider(tensorCache).get());
    try (AbstractModel m = ModelSupport.loadModel(f, DType.F32, DType.I8, new ConfigurableTensorProvider(operation),
            mr, tensorCache, new KvBufferCacheSettings(true), fetch)) {
        String prompt = "Which NFL franchise does not play in New York?";
        PromptSupport.Builder g = m.promptSupport().get().builder()
                .addUserMessage(prompt);
        var uuid = UUID.randomUUID();

        Response k = m.generate(uuid, g.build(), new GeneratorParameters()
                        .withTemperature(0.0f)
                        .withGuidedChoice(List.of("Giants", "Jets", "Seahawks")),
                new GenerateEvent() {
                    @Override
                    public void emit(int next, String nextRaw, String nextCleaned, float timing) {
                        System.out.println(nextCleaned);
                    }
                });
        assertTrue(k.responseText.contains("Seahawks"));
    }
}

Wow, nice. Interestingly, in my testing I found some other funny stuff. I asked:
"Which franchise does not play in New York"

I didn't dive in, but I am guessing there wasn't enough context. Still, this is strange:

Giants
<unused99>
<unused99>
<unused99>
<unused99>
<unused99>
<unused99>
<unused99> 
 
So this is a result from Gemma 2, which has these unused tokens. I am guessing the model uses them to say it is off the rails :) But when you write a Java inference engine from scratch, you figure things out as you go. That is tomorrow's fun problem!

Also fun: this approach is very sensitive to misspellings and case. Once, I gave the choices "Giant" and "Jets", and it picked "Jets" because "Giant" isn't a football team.
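One possible mitigation for the case-sensitivity part, which the post does not implement, is to compare lower-cased text and report the match in the user's original spelling. A hypothetical sketch (class and method names are made up):

```java
import java.util.List;
import java.util.Locale;
import java.util.Optional;

// Hypothetical: case-insensitive prefix and match checks for guided choice.
class CaseInsensitiveChoice {

    // Lower-case with a fixed locale so results do not depend on the JVM default.
    static String norm(String s) {
        return s.toLowerCase(Locale.ROOT);
    }

    // True if `text` starts at least one choice, ignoring case.
    static boolean isViablePrefix(String text, List<String> choices) {
        String t = norm(text);
        return choices.stream().anyMatch(c -> norm(c).startsWith(t));
    }

    // The choice that `text` spells out, ignoring case, in the user's spelling.
    static Optional<String> match(String text, List<String> choices) {
        String t = norm(text);
        return choices.stream().filter(c -> norm(c).equals(t)).findFirst();
    }

    public static void main(String[] args) {
        List<String> choices = List.of("Giants", "Jets");
        System.out.println(isViablePrefix("gian", choices)); // prints true
        System.out.println(match("GIANTS", choices));        // prints Optional[Giants]
    }
}
```

This does nothing for misspellings: a choice of "Giant" still has to be produced exactly by the tokens the model picks.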
