tweaks and add a simple test

This commit is contained in:
Andrej Karpathy
2023-07-23 14:52:08 +00:00
parent f499d9d2b5
commit 9414e7a45e
7 changed files with 73 additions and 48 deletions
+11 -11
View File
@@ -288,19 +288,19 @@ class Transformer(nn.Module):
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
logits = logits[:, -1, :] # crop to just the final time step
if temperature == 0.0:
# sample the most likely index
_, idx_next = torch.topk(probs, k=1, dim=-1)
# "sample" the single most likely index
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# sample from the distribution
# pluck the logits at the final step and scale by desired temperature
logits = logits / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)