tweaks and add a simple test
This commit is contained in:
@@ -288,19 +288,19 @@ class Transformer(nn.Module):
|
||||
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
|
||||
# forward the model to get the logits for the index in the sequence
|
||||
logits, _ = self(idx_cond)
|
||||
# pluck the logits at the final step and scale by desired temperature
|
||||
logits = logits[:, -1, :] / temperature
|
||||
# optionally crop the logits to only the top k options
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# apply softmax to convert logits to (normalized) probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
logits = logits[:, -1, :] # crop to just the final time step
|
||||
if temperature == 0.0:
|
||||
# sample the most likely index
|
||||
_, idx_next = torch.topk(probs, k=1, dim=-1)
|
||||
# "sample" the single most likely index
|
||||
_, idx_next = torch.topk(logits, k=1, dim=-1)
|
||||
else:
|
||||
# sample from the distribution
|
||||
# pluck the logits at the final step and scale by desired temperature
|
||||
logits = logits / temperature
|
||||
# optionally crop the logits to only the top k options
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# apply softmax to convert logits to (normalized) probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
# append sampled index to the running sequence and continue
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
|
||||
Reference in New Issue
Block a user