Large Language Models: Temperature and Relative Performance

When you read about performance and tuning for large language models, the term temperature comes up often. I have found a lot of interesting descriptions of it, from "it controls creativity" to "how hard it thinks." What I like to say is that "temperature at 0.0 is prone to overfit." I avoid saying "high temperature is more likely to hallucinate," because I feel the 0.0 "overfit" output looks a lot like a hallucination as well.
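Here is a small standalone sketch of what temperature does to the numbers. It is not deliverance code, and the logits are made up. It applies softmax to the logits after dividing them by the temperature T. A small T sharpens the distribution toward the top token. As T approaches 0, the distribution collapses onto the argmax, which is why temperature 0 is handled as a plain argmax rather than a division by zero.

```java
import java.util.Arrays;

public class TemperatureDemo {
    // Softmax with temperature: p_i = exp(z_i / T) / sum_j exp(z_j / T).
    // The max logit is subtracted first for numerical stability; it cancels out of the ratio.
    static double[] softmax(double[] logits, double temperature) {
        double max = Arrays.stream(logits).max().orElseThrow();
        double[] p = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp((logits[i] - max) / temperature);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum; // normalize so the probabilities add up to 1
        }
        return p;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.5}; // hypothetical logits for three tokens
        for (double t : new double[] {0.2, 0.6, 1.0, 2.0}) {
            System.out.printf("T=%.1f -> %s%n", t, Arrays.toString(softmax(logits, t)));
        }
    }
}
```

With these logits, the first token gets about 99% of the probability at T=0.2 but less than half at T=2.0.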

I rebuilt some of the deliverance sampling code to give it an object-oriented face-lift and to prepare for adding a repetition penalty.
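The repetition penalty is not part of this commit, so here is one common formulation for context. Treat it as a hedged sketch, not deliverance's planned design. It is the CTRL-style penalty (Keskar et al., 2019) that Hugging Face transformers implements in its RepetitionPenaltyLogitsProcessor. Logits of tokens that were already generated are divided by the penalty if positive and multiplied by it if negative, so either way the token becomes less likely. The class and method names are made up for illustration. A step like this would presumably run on the logits before the temperature and softmax passes.

```java
import java.util.Arrays;
import java.util.Set;

public class RepetitionPenaltySketch {
    // Hypothetical helper: CTRL-style penalty applied in place to already-generated tokens.
    // Positive logits shrink toward 0 (divide); negative logits grow more negative (multiply).
    static void applyRepetitionPenalty(float[] logits, Set<Integer> generated, float penalty) {
        for (int token : generated) {
            float v = logits[token];
            logits[token] = v > 0 ? v / penalty : v * penalty;
        }
    }

    public static void main(String[] args) {
        float[] logits = {2.0f, -2.0f, 1.0f};
        applyRepetitionPenalty(logits, Set.of(0, 1), 2.0f); // tokens 0 and 1 were already emitted
        System.out.println(Arrays.toString(logits)); // [1.0, -4.0, 1.0]
    }
}
```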

As always, feel free to look at the code here:

https://github.com/edwardcapriolo/deliverance/commit/0162c1daa07cba6b43d4e4f75cb95f50b0ff11b2

One thing I was never overjoyed with was the size of the AbstractModel class. Many of its features are tightly coupled, so I wanted to move sampling into its own class. It is not 100% clean, since we still reach back into AbstractModel for config values and other things, but I like to split things out and see how they evolve.

First, let's look at where sampling is called. Here we are in the method that generates text:

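public Response generate(UUID sessionId, PromptContext promptContext,
        GeneratorParameters generatorParameters, GenerateEvent onTokenWithTimings) {
    // ... prompt encoding and setup elided ...

    // Prefill: push the whole prompt through the model in one batch
    AbstractTensor last = batchForward(promptTokens, startPos, kvmem);
    // The first new token is sampled from the last prompt position
    GeneratorSampler sampler = new GeneratorSampler(this, last.slice(last.shape().first() - 1),
            temperature, random.nextFloat(), logits, sampleOutput.getOutputLayerNorm());
    // ... elided (next is presumably set from sampler.sample() here) ...
    for (int i = startPos + promptTokens.length; i < ntokens; i++) {
        // Decode one token at a time against the KV cache (kvmem)
        AbstractTensor output = forward(next, i, kvmem);
        tokensGenerated++;
        GeneratorSampler sampler1 = new GeneratorSampler(this, output,
                temperature, random.nextFloat(), logits, sampleOutput.getOutputLayerNorm());
        next = sampler1.sample();

        output.close();
        // ... stop-token check and callbacks elided ...
    }
    // ... building the Response elided ...
}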

Without diving too much into the code, this makes sense overall. We prefill the prompt to start the process and sample the first token. Then we keep generating one token at a time until we run out of tokens or the model says END.
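Stripped of tensors, the KV cache and timing callbacks, the loop has roughly this shape. This is a toy sketch. `nextToken` is a hypothetical stand-in for forward() plus sampling, and the end token and limits are made up.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntUnaryOperator;

public class GenerationLoopSketch {
    // Hypothetical stand-in: nextToken plays the role of forward() + sampling.
    // The real loop also threads the KV cache, timings, and callbacks through.
    static List<Integer> generate(int firstToken, int maxNewTokens, int endToken,
                                  IntUnaryOperator nextToken) {
        List<Integer> out = new ArrayList<>();
        int next = firstToken; // in the real code, sampled after the prefill
        for (int i = 0; i < maxNewTokens; i++) {
            if (next == endToken) {
                break; // the model "said END"
            }
            out.add(next);
            next = nextToken.applyAsInt(next);
        }
        return out; // stopped either by END or by running out of tokens
    }

    public static void main(String[] args) {
        // Toy "model": counts upward and emits the end token (0) after token 5.
        IntUnaryOperator toyModel = t -> t >= 5 ? 0 : t + 1;
        System.out.println(generate(1, 100, 0, toyModel)); // [1, 2, 3, 4, 5]
        System.out.println(generate(1, 3, 0, toyModel));   // [1, 2, 3]
    }
}
```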

Notice that we also added Dropwizard metrics:

forward1 = abstractModel.metricRegistry.histogram("sample.foward1");
dotprod2 = abstractModel.metricRegistry.histogram("sample.dotproduct2");
fullSample = abstractModel.metricRegistry.histogram("sample.fullsample");

We want to know where most of the time goes, so we created a few histograms. Histograms are great for tracking timings and their percentiles. Each one is updated like so:

long start = System.nanoTime();
try (AbstractTensor embedding = layerNorm.forward(output)) {
    long afterForward = System.nanoTime();
    forward1.update(Math.abs(afterForward - start));
    // ...
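The timing pattern is simple: read System.nanoTime() before and after a step, then feed the difference to a histogram. It can be sketched without Dropwizard. The class below is a toy stand-in, not Dropwizard's Histogram. It keeps every value and uses a nearest-rank percentile, whereas Dropwizard's default histograms sample values into a reservoir. It also records the time in a finally block, the same trick sample() uses so that the total is recorded even on an early return. Dropwizard also ships a Timer metric that packages up this start/stop pattern.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class NanoHistogramSketch {
    // Toy stand-in for a metrics histogram: keeps every value (no sampling reservoir).
    private final List<Long> values = new ArrayList<>();

    void update(long value) {
        values.add(value);
    }

    double mean() {
        return values.stream().mapToLong(Long::longValue).average().orElse(0.0);
    }

    // Nearest-rank percentile for q in (0, 1]; returns 0 when nothing was recorded.
    long percentile(double q) {
        if (values.isEmpty()) {
            return 0;
        }
        List<Long> sorted = new ArrayList<>(values);
        Collections.sort(sorted);
        int rank = (int) Math.ceil(q * sorted.size());
        return sorted.get(Math.max(rank, 1) - 1);
    }

    public static void main(String[] args) throws InterruptedException {
        NanoHistogramSketch fullSample = new NanoHistogramSketch();
        for (int i = 0; i < 5; i++) {
            long start = System.nanoTime();
            try {
                Thread.sleep(2); // stand-in for the work being measured
            } finally {
                // finally records the elapsed time even on an early return or exception
                fullSample.update(System.nanoTime() - start);
            }
        }
        System.out.printf("mean=%.0f ns p99=%d ns%n", fullSample.mean(), fullSample.percentile(0.99));
    }
}
```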

We have one histogram over the entire sample method. We added two others tactically around the steps that looked computationally heavy: the final layer norm (forward1) and the dot product that projects onto the vocabulary (dotprod2). Now here comes the fun part!

public int sample() {
    long start = System.nanoTime();
    // Final layer norm on the hidden state (timed by forward1)
    try (AbstractTensor embedding = layerNorm.forward(output)) {
        long afterForward = System.nanoTime();
        forward1.update(Math.abs(afterForward - start));

        // Project onto the vocabulary: one dot product per token, in parallel chunks (timed by dotprod2)
        VectorMath.pchunk(0, abstractModel.config.vocabularySize, (chunkStart, chunkSize) -> {
            abstractModel.configurableTensorProvider.get()
                    .dotProductChunk(logits, embedding, abstractModel.sampleOutput.getOutputLogitsWeights(), 0,
                            abstractModel.config.embeddingLength, chunkStart, chunkSize);
        }, abstractModel.configurableTensorProvider.get().parallelSplitSize());
        long afterDotProductChunk = System.nanoTime();
        dotprod2.update(Math.abs(afterDotProductChunk - afterForward));

        if (abstractModel.config.logitMultiplier != null) {
            CausualWhisperer.LOGGER.debug("scaling logits logitMultiplier: {}", abstractModel.config.logitMultiplier);
            abstractModel.configurableTensorProvider.get().scale(1.0f / abstractModel.config.logitMultiplier,
                    logits, 0, abstractModel.config.vocabularySize);
        }
        // One pass over the vocabulary: optional soft-capping, while tracking the argmax
        int maxi = Integer.MIN_VALUE;
        double maxv = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < abstractModel.config.vocabularySize; i++) {
            float v = logits.get(0, i);
            if (abstractModel.config.finalLogitSoftCapping != null) {
                v /= abstractModel.config.finalLogitSoftCapping;
                v = (float) FastMath.tanh(v);
                v = v * abstractModel.config.finalLogitSoftCapping;
                logits.set(v, 0, i);
            }
            if (v > maxv) {
                maxi = i;
                maxv = v;
            }
        }
        // Greedy decoding: skip the softmax entirely
        if (temperature == 0.0) {
            CausualWhisperer.LOGGER.debug("temperature at 0 returning maxi {}", maxi);
            return maxi;
        }
        // Softmax numerators, shifted by the max logit for numerical stability
        float sum = 0;
        for (int i = 0; i < abstractModel.config.vocabularySize; i++) {
            float v = (float) FastMath.exp((logits.get(0, i) - maxv) / temperature);
            sum += v;
            logits.set(v, 0, i);
        }
        // Walk the cumulative distribution until it reaches the uniform sample
        float acc = 0;
        for (int i = 0; i < abstractModel.config.vocabularySize; i++) {
            float v = logits.get(0, i) / sum;
            acc += v;
            if (acc >= uniformSample) {
                CausualWhisperer.LOGGER.debug("accumulator {} >= uniformSample {} returning {}", acc, uniformSample, i);
                return i;
            }
        }
        // Only reached if float rounding leaves acc just below uniformSample
        CausualWhisperer.LOGGER.debug("Reached end returning {}", abstractModel.config.vocabularySize - 1);
        return abstractModel.config.vocabularySize - 1;
    } finally {
        // Recorded on every exit path, including the early greedy return
        long end = System.nanoTime();
        fullSample.update(Math.abs(end - start));
    }
}
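The finalLogitSoftCapping branch above is easy to skim past. It computes cap * tanh(v / cap), which leaves small logits nearly unchanged but smoothly squashes large ones into the range -cap to cap. Below is a standalone sketch. It uses Math.tanh instead of FastMath and an illustrative cap of 30.0. I believe that is the value Gemma 2 uses for its final logits, but treat it as an assumption.

```java
public class SoftCapSketch {
    // cap * tanh(v / cap): close to v when |v| is much smaller than cap, bounded by cap in magnitude.
    static float softCap(float v, float cap) {
        return (float) (cap * Math.tanh(v / cap));
    }

    public static void main(String[] args) {
        float cap = 30.0f; // assumed cap, for illustration only
        for (float v : new float[] {1f, 10f, 30f, 100f, 1000f}) {
            System.out.printf("%7.1f -> %6.2f%n", v, softCap(v, cap));
        }
    }
}
```

Running it shows 1.0 staying at about 1.0, while 100 and 1000 are both pinned close to 30.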

It is fairly obvious why temperature = 0.0 is fast: right in the middle of the method we exit early and return the highest-scoring token (greedy decoding).

if (temperature == 0.0) {
    CausualWhisperer.LOGGER.debug("temperature at 0 returning maxi {}", maxi);
    return maxi;
}

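What if the temperature is not 0? First we take a full pass through the vocabulary, turning each logit into an unnormalized probability. The vocabulary can be pretty big: the Gemma 2 model used below has about 256K tokens. Subtracting the max logit before exponentiating keeps exp() from overflowing without changing the final probabilities: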

for (int i = 0; i < abstractModel.config.vocabularySize; i++) {
    float v = (float) FastMath.exp((logits.get(0, i) - maxv) / temperature);
    sum += v;
    logits.set(v, 0, i);
}
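Why subtract the max logit? exp() overflows a float once its argument passes about 88.7, the natural log of Float.MAX_VALUE. Dividing by a temperature below 1 makes the argument larger. Shifting every logit by the max multiplies the numerator and the denominator by the same constant, so the probabilities p_i = exp((z_i - z_max)/T) / sum_j exp((z_j - z_max)/T) are unchanged, but the largest term becomes exp(0) = 1. Here is a small standalone comparison with made-up logits:

```java
public class StableExpSketch {
    // Sum of exp(z_i / T) computed naively: overflows to Infinity for large logits.
    static float naiveSum(float[] logits, float temperature) {
        float sum = 0;
        for (float z : logits) {
            sum += (float) Math.exp(z / temperature);
        }
        return sum;
    }

    // Same sum with the max logit subtracted first; the largest term is exp(0) = 1.
    static float stableSum(float[] logits, float temperature) {
        float max = Float.NEGATIVE_INFINITY;
        for (float z : logits) {
            max = Math.max(max, z);
        }
        float sum = 0;
        for (float z : logits) {
            sum += (float) Math.exp((z - max) / temperature);
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] logits = {100f, 99f, 98f}; // hypothetical large logits
        float t = 0.6f;
        System.out.println("naive:  " + naiveSum(logits, t));
        System.out.println("stable: " + stableSum(logits, t));
    }
}
```

The naive sum comes out as Infinity, and normalizing by it would produce NaN. The shifted sum is about 1.22.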

Then we take a second pass, normalizing as we go, and short-circuit as soon as the running total of probabilities reaches the uniform sample:

float acc = 0;
for (int i = 0; i < abstractModel.config.vocabularySize; i++) {
    float v = logits.get(0, i) / sum;
    acc += v;
    if (acc >= uniformSample) {
        CausualWhisperer.LOGGER.debug("accumulator {} >= uniformSample {} returning {}", acc, uniformSample, i);
        return i;
    }
}

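In case you were wondering, uniformSample is a random number in [0, 1), drawn with random.nextFloat() when the sampler is created. The generator can use a seed to make these random numbers, and therefore the sampled tokens, more or less repeatable. And that is the magic: stopping at the first token whose cumulative probability reaches a uniform random number picks each token with its softmax probability.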
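The two passes above are inverse transform sampling. Walk the cumulative distribution and stop at the first token whose running total reaches a uniform random number u in [0, 1); this selects token i with probability p_i. Below is a standalone sketch of the same shape on a plain float array. It is a simplification that drops the tensor types, logging, soft-capping and logit scaling, and the names are illustrative. It also shows that two Random instances built with the same seed produce the same draws, and hence the same token.

```java
import java.util.Random;

public class InverseCdfSampler {
    // Mirrors the shape of sample(): greedy at T == 0, otherwise walk the CDF.
    static int sample(float[] logits, float temperature, float uniformSample) {
        int maxi = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[maxi]) {
                maxi = i;
            }
        }
        if (temperature == 0.0f) {
            return maxi; // greedy decoding
        }
        float[] weights = new float[logits.length];
        float sum = 0;
        for (int i = 0; i < logits.length; i++) {
            weights[i] = (float) Math.exp((logits[i] - logits[maxi]) / temperature);
            sum += weights[i];
        }
        float acc = 0;
        for (int i = 0; i < logits.length; i++) {
            acc += weights[i] / sum;
            if (acc >= uniformSample) {
                return i; // first index whose cumulative probability reaches u
            }
        }
        return logits.length - 1; // float rounding left acc just below u
    }

    public static void main(String[] args) {
        float[] logits = {0f, 0f, 0f, 0f}; // uniform: each token has probability 0.25
        System.out.println(sample(logits, 1.0f, 0.3f)); // 1
        System.out.println(sample(logits, 1.0f, 0.6f)); // 2
        // Same seed, same uniform draws, same token.
        Random a = new Random(42), b = new Random(42);
        System.out.println(sample(logits, 1.0f, a.nextFloat()) == sample(logits, 1.0f, b.nextFloat())); // true
    }
}
```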

So what does all this crap actually do? Well... O wise model, 'tell me the capital of New York?'

 

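// JUnit 5 and JDK imports; MetricRegistry (Dropwizard) and the deliverance classes come from their own packages
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.io.IOException;
import java.util.UUID;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

@Test
public void gemmaTest() throws IOException {
    // Download (or reuse a cached copy of) the gemma-2-2b-it-JQ4 model
    ModelFetcher fetch = new ModelFetcher("tjake", "gemma-2-2b-it-JQ4");
    File f = fetch.maybeDownload();
    MetricRegistry mr = new MetricRegistry();
    TensorCache tensorCache = new TensorCache(mr);
    NativeSimdTensorOperations operation = new NativeSimdTensorOperations(new ConfigurableTensorProvider(tensorCache).get());
    try (AbstractModel m = ModelSupport.loadModel(f, DType.F32, DType.I8, new ConfigurableTensorProvider(operation),
            mr, tensorCache, new KvBufferCacheSettings(true), fetch)) {
        String prompt = "What is the capital of New York, USA?";
        PromptSupport.Builder g = m.promptSupport().get().builder()
                .addUserMessage(prompt);
        // The chat template wraps the question in Gemma's turn markers
        Assertions.assertEquals("<start_of_turn>user\n" +
                "What is the capital of New York, USA?<end_of_turn>\n" +
                "<start_of_turn>model\n", g.build().getPrompt());
        var uuid = UUID.randomUUID();

        // Non-zero temperature, so the full softmax path in sample() runs
        Response k = m.generate(uuid, g.build(), new GeneratorParameters().withTemperature(0.6f),
                new DoNothingGenerateEvent());
        assertTrue(k.responseText.contains("Albany"));
    }
    // the histogram printing shown further down runs here, while mr is still in scope
}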

 
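The model answers. The trailing <end_of_turn> is Gemma's end-of-turn token, which is how this model says END:

The capital of New York, USA is **Albany**.

<end_of_turn>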

And what of our lovely histograms?

for (String name : new String[] {"sample.fullsample", "sample.forward1", "sample.dotproduct2"}) {
    var snapshot = mr.histogram(name).getSnapshot();
    System.out.println(Arrays.toString(snapshot.getValues())); // recorded values, in nanoseconds
    System.out.println(snapshot.getMean());
    System.out.println(snapshot.get99thPercentile());
}



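sample.fullsample (nanoseconds):
[56236602, 56416227, 57041137, 57419572, 57515010, 57674425, 57845190, 57908894, 58208113, 58335658, 58714224, 58760750, 58827393, 59583618, 59638843, 59859574, 61209940, 62987280, 65168698, 66211778, 68409210, 189747354, 4870599278]
2.582465077790329E8
4.870599278E9

sample.forward1:
[]
0.0
0.0

sample.dotproduct2 (nanoseconds):
[49287787, 49379628, 49599852, 50590065, 50699882, 50710310, 51019596, 51028821, 51109579, 51622137, 51768809, 51832366, 52053602, 52128640, 52225558, 53176911, 54210671, 55711002, 57859406, 58511262, 59936317, 61586098, 4755855591]
2.416556005367109E8
4.755855591E9

For each histogram, the three lines are the recorded values, the mean and the 99th percentile. A few things stand out:

The forward1 histogram is empty because the sampler registers it as "sample.foward1" (missing an "r"), while the test asks for "sample.forward1". MetricRegistry.histogram() simply creates a new, empty histogram under the name it is given.

For the other two histograms, most calls land between about 56 and 68 ms for the full sample and between about 49 and 62 ms for the dot product. Projecting onto the vocabulary is therefore where most of the sampling time goes.

A single outlier of roughly 4.8 seconds, probably a cold first call, drags the means up to about 258 ms and 242 ms. With only 23 values, the 99th percentile is simply that maximum.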
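Because of the multi-second value at the end of each list, the mean is a poor summary of a typical call. The median of the reported values is a quick, more robust alternative. This small standalone sketch copies the two non-empty snapshots above and compares their medians; the class and method names are illustrative. The two medians come from separate distributions rather than per-call pairs, so their ratio is only a rough share.

```java
import java.util.Arrays;

public class MedianFromSnapshot {
    // Median of an odd-length sample: sort a copy and take the middle element.
    static long median(long[] values) {
        long[] sorted = values.clone();
        Arrays.sort(sorted);
        return sorted[sorted.length / 2];
    }

    public static void main(String[] args) {
        // Copied from the sample.fullsample and sample.dotproduct2 snapshots above (nanoseconds).
        long[] fullSample = {56236602, 56416227, 57041137, 57419572, 57515010, 57674425, 57845190,
                57908894, 58208113, 58335658, 58714224, 58760750, 58827393, 59583618, 59638843,
                59859574, 61209940, 62987280, 65168698, 66211778, 68409210, 189747354, 4870599278L};
        long[] dotProduct = {49287787, 49379628, 49599852, 50590065, 50699882, 50710310, 51019596,
                51028821, 51109579, 51622137, 51768809, 51832366, 52053602, 52128640, 52225558,
                53176911, 54210671, 55711002, 57859406, 58511262, 59936317, 61586098, 4755855591L};
        long full = median(fullSample);
        long dot = median(dotProduct);
        System.out.printf("median full sample: %.1f ms%n", full / 1e6);
        System.out.printf("median dot product: %.1f ms%n", dot / 1e6);
        // Separate distributions, so this share is only approximate.
        System.out.printf("dot product share: ~%.0f%%%n", 100.0 * dot / full);
    }
}
```

This prints medians of 58.8 ms and 51.8 ms. By this rough measure, the dot product accounts for about 88% of a typical sample() call.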
