diff --git a/README.md b/README.md index ccd77c5..241c822 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - Rust - [llama2.rs](https://github.com/gaxler/llama2.rs) by @[gaxler](https://github.com/gaxler): a Rust port of this project - [llama2.rs](https://github.com/leo-du/llama2.rs) by @[leo-du](https://github.com/leo-du): A Rust port of this project + - [llama2-rs](https://github.com/danielgrittner/llama2-rs) by @[danielgrittner](https://github.com/danielgrittner): a Rust port of this project - Go - [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project - [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project @@ -236,9 +237,10 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - [llama2.java](https://github.com/mukel/llama2.java) by @[mukel](https://github.com/mukel): a Java port of this project - Kotlin - [llama2.kt](https://github.com/madroidmaq/llama2.kt) by @[madroidmaq](https://github.com/madroidmaq): a Kotlin port of this project +- Python + - [llama2.py](https://github.com/tairov/llama2.py) by @[tairov](https://github.com/tairov): a simple one file pure Python port of this project with zero dependencies - [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2 - ## unsorted todos - should calculate freq_cis online in the script run.c instead of loading them diff --git a/model.py b/model.py index 66304e7..f7edbb6 100644 --- a/model.py +++ b/model.py @@ -317,7 +317,7 @@ class Transformer(nn.Module): # if the sequence context is growing too long we must crop it at block_size idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:] # forward the model to get the logits for the index in the sequence - logits, _ = self(idx_cond) + logits = self(idx_cond) logits = logits[:, -1, :] # crop to just the final time step if temperature == 0.0: # "sample" the single most likely index