Exclude Top Choice vs. repetition penalty
A few days ago I hacked top_logprobs into deliverance:
Logprobs show you which other tokens were close to being chosen, and the logprobs (the logarithm of each token's probability) tell you how close they were. For example, a logprob of -0.1 means a probability of about 90%, since e^-0.1 ≈ 0.905.
I had never done something like this before, and the math turned out to be not so bad. Most of what I learned came from this excellent article and this one.
// Converts raw logits into log-probabilities (a log-softmax):
// logprob_i = x_i - log(sum_j exp(x_j))
public static void logSumExpTensor(AbstractTensor result, AbstractTensor input) {
    float logsumexp = (float) logSumExp(input);
    for (int i = 0; i < input.size(); i++) {
        float v = input.get(0, i);
        result.set(v - logsumexp, 0, i);
    }
}

// log(sum_i exp(x_i)) over the tensor's single row
public static double logSumExp(AbstractTensor x) {
    float sum = 0.0f;
    for (int i = 0; i < x.size(); i++) {
        sum += (float) FastMath.exp(x.get(0, i));
    }
    return (float) FastMath.log(sum);
}
All this math stuff made me a little more confident about editing the Sampler code, so I decided to keep going. One thing I have found when using smaller models is that they tend to get into loops frequently. The classic solution for looping is to use Frequency Penalty and Presence Penalty.
I started doing a lot of reading about the two options, and I found a lot of articles explaining their pitfalls. "If the word comes up often, stop saying it" seems on the surface like a great idea. You can not go in a loop, go in a loop, go in a loop: if the word "loop" gets penalized, the repetition is broken. The problem is... well, words like "is" come up often too, and they get penalized. The comma "," token gets penalized, and the model can start using ":" in place of ",". That seems bad too. Now you need a blacklist of common words to exempt from the penalty. A sketch of how these penalties work on the logits is below.
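To make the pitfall concrete, here is a minimal sketch of OpenAI-style frequency and presence penalties, assuming a plain float[] of logits and a map of how many times each token has appeared so far (none of these names come from deliverance):
import java.util.Map;

public class PenaltySketch {
    public static void applyPenalties(float[] logits, Map<Integer, Integer> counts,
                                      float frequencyPenalty, float presencePenalty) {
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            int token = e.getKey();
            int count = e.getValue();
            // frequency penalty grows with every repeat of the token
            logits[token] -= count * frequencyPenalty;
            // presence penalty is a flat hit for having appeared at all
            logits[token] -= presencePenalty;
        }
    }
}
Every token that has appeared gets pushed down, including "," and "is", which is exactly the problem described above.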
Enter XTC: Exclude Top Choice
Well, someone came up with a solution I really like: XTC. The idea is pretty cool: randomly, when generating a token, pick a different one. Which token? The lowest-scoring one that is still over a threshold.
Say the text generated so far is "I am hungry, I would like to go to []". The existing sampler might believe "home" has the highest score. XTC generates the list of possibilities ['shopping': 50, 'to': 60, 'home': 90]; if the threshold were 57, it would pick "to", the lowest candidate above the threshold. A sketch of the selection logic is below.
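Here is a rough sketch of that selection rule, assuming the sampler already has normalized probabilities in a float[] (this is an illustration of the idea, not the actual deliverance code):
import java.util.Random;

public class XtcSketch {
    public static int sample(float[] probs, float threshold, float xtcProbability, Random rng) {
        // most of the time XTC does not trigger and we keep the top choice
        if (rng.nextFloat() >= xtcProbability) {
            return argmax(probs);
        }
        // when it triggers, take the *least* likely token that still clears the threshold
        int pick = -1;
        float lowest = Float.MAX_VALUE;
        for (int i = 0; i < probs.length; i++) {
            if (probs[i] >= threshold && probs[i] < lowest) {
                lowest = probs[i];
                pick = i;
            }
        }
        // if nothing clears the threshold, fall back to the top choice
        return pick >= 0 ? pick : argmax(probs);
    }

    private static int argmax(float[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        return best;
    }
}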
I implemented it here. It is a little hard to annotate, but I'll show you what it does. So here is your normal prompt:
@Test
public void logProbs() throws JsonProcessingException {
    AbstractModel m = Gemma2Suite.getOrCreate();
    String prompt = "Pick a random number between 1 and 9. Replay with only the pick.";
    PromptSupport.Builder g = m.promptSupport().get().builder()
            .addUserMessage(prompt);
    var uuid = UUID.randomUUID();
    Response response = m.generate(uuid, g.build(), new GeneratorParameters()
            .withTemperature(0.0f)
            .withMaxTokens(300)
            .withLogProbs(true)
            .withTopLogProbs(10)
            , new DoNothingGenerateEvent());
    assertEquals("I picked the number **5**! \uD83C\uDFB2 \n" +
            "<end_of_turn>".trim(), response.responseText.trim());
}
We can set the threshold and the probability of the feature turning on:
Response k = m.generate(uuid, g.build(), new GeneratorParameters()
        .withTemperature(0.0f)
        .withMaxTokens(300)
        .withLogProbs(true)
        .withTopLogProbs(10)
        .withXtcThreshold(0.1f)
        .withXtcProbability(0.5f)
        , new DoNothingGenerateEvent());
The effect shows up in the logs. Each time XTC fires, it logs the token it picked and the token (maxi) the sampler would otherwise have chosen:
[main] ERROR io.teknek.deliverance.CausualWhisperer - xtc: picked maxi: '
[main] ERROR io.teknek.deliverance.CausualWhisperer - xtc: 4 maxi: 5
[main] ERROR io.teknek.deliverance.CausualWhisperer - xtc: **. maxi: **
assertEquals("I picked the number **4**. <end_of_turn>".trim(), k.responseText.trim());
Now this should hopefully break the loops without putting penalties on individual words.