Type: Bug
Resolution: Unresolved
Priority: P4
Affects Version/s: 21, 25
CPU: x86_64
OS: windows
ADDITIONAL SYSTEM INFORMATION :
Tested under Win11, Ubuntu 22.04, GraalVM and OpenJDK
A DESCRIPTION OF THE PROBLEM :
An IndexOutOfBoundsException occurs in previously working code after integrating new, working code that used ByteBuffer instead of MemorySegment. Issue #11294 has been logged against GraalVM.
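A diagnostic sketch, not a diagnosis: the tensor listing in the log below shows byte offsets above Integer.MAX_VALUE (for example 2575927296 for blk.11.attn_q.bias). ByteBuffer is indexed with int, while MemorySegment accepts long offsets, so the check below flags any range that cannot be addressed with int indices. The class and method names are hypothetical and are not taken from the reporter's code. The Q8_0 size formula assumes the standard GGML layout of 34 bytes per block of 32 elements, which matches the sizes printed in the log. Whether either check relates to the exception is not established by this report.

```java
public class TensorOffsetCheck {
    // Hypothetical helper: true when the byte range [offset, offset + size)
    // can be addressed with int indices, as java.nio.ByteBuffer requires.
    public static boolean fitsIntIndex(long offset, long size) {
        return offset >= 0 && size >= 0 && offset + size <= Integer.MAX_VALUE;
    }

    // Assumption: Q8_0 stores each block of 32 elements in 34 bytes
    // (a 2-byte scale followed by 32 one-byte quantized values).
    public static long q8_0ByteSize(long elementCount) {
        return elementCount / 32 * 34;
    }

    public static void main(String[] args) {
        long offset = 2575927296L; // blk.11.attn_q.bias offset from the log
        long size = 14336L;        // size reported for the same tensor

        System.out.println("fits int index: " + fitsIntIndex(offset, size));
        // A narrowing cast silently wraps to a negative value.
        System.out.println("narrowed offset: " + (int) offset);
        // Math.toIntExact throws instead of wrapping.
        try {
            Math.toIntExact(offset);
        } catch (ArithmeticException e) {
            System.out.println("toIntExact rejected the offset: " + e.getMessage());
        }
        // Element count of blk.19.attn_v.weight from the log.
        System.out.println("Q8_0 bytes: " + q8_0ByteSize(1835008L));
    }
}
```

Running the same check over every tensor offset printed at load time would show which tensors, if any, lie beyond the int-addressable range in this model file.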
STEPS TO FOLLOW TO REPRODUCE THE PROBLEM :
compile_rel.bat
jar_llama3.bat
chat_qwendb.bat
NOTE: The model file qwen2-7b-instruct-q8_0.gguf must be obtained from HuggingFace or a similar source. The failure also occurs with other models. Mistral and Llama models work.
EXPECTED VERSUS ACTUAL BEHAVIOR :
EXPECTED -
Iteration through a buffer backed by a MemorySegment works until an unknown edge case causes a failure.
ACTUAL -
C:\Progra~1\Java\graalvm-jdk-25+20.1\bin\java -server -XX:+UseParallelGC -Xmn26g -Xms26g -Xmx26g --enable-preview --add-modules jdk.incubator.vector -jar Llama3.jar --model qwen2-7b-instruct-q8_0.gguf --chat -n -1
[0.006s][warning][gc,ergo] NewSize (27262976k) is equal to or greater than initial heap size (27262976k). A new NewSize of 27262464k will be used to accomodate an old generation.
[0.006s][warning][gc,ergo] MaxNewSize (27262976k) is equal to or greater than the entire heap (27262976k). A new max generation size of 27262464k will be used.
WARNING: Using incubator modules: jdk.incubator.vector
Parse qwen2-7b-instruct-q8_0.gguf: 1159 millis
GGUF metadata:
{qwen2.block_count=28, tokenizer.ggml.add_bos_token=false, qwen2.embedding_length=3584, tokenizer.ggml.padding_token_id=151643, quantize.imatrix.chunks_count=1937, qwen2.feed_forward_length=18944, quantize.imatrix.entries_count=196, qwen2.attention.layer_norm_rms_epsilon=1.0E-6, tokenizer.ggml.merges=[Ljava.lang.String;@2e5d6d97, tokenizer.ggml.pre=qwen2, qwen2.attention.head_count_kv=4, general.architecture=qwen2, tokenizer.chat_template={% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
You are a helpful assistant.<|im_end|>
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}, general.file_type=7, general.name=qwen2-7b-instruct, general.quantization_version=2, tokenizer.ggml.token_type=[I@238e0d81, tokenizer.ggml.eos_token_id=151645, tokenizer.ggml.bos_token_id=151643, quantize.imatrix.dataset=../sft_2406.txt, qwen2.rope.freq_base=1000000.0, tokenizer.ggml.tokens=[Ljava.lang.String;@31221be2, tokenizer.ggml.model=gpt2, qwen2.context_length=32768, quantize.imatrix.file=../Qwen2/gguf/qwen2-7b-imatrix/imatrix.dat, qwen2.attention.head_count=28}
Tensor:blk.11.attn_q.bias=blk.11.attn_q.bias offset:2575927296 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.attn_norm.weight=blk.3.attn_norm.weight offset:1322035200 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.attn_v.weight=blk.19.attn_v.weight offset:5530279936 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.10.attn_q.bias=blk.10.attn_q.bias offset:2328268800 dims:[3584] number elems:3584 size:14336
Tensor:blk.12.attn_q.bias=blk.12.attn_q.bias offset:2823585792 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.attn_q.bias=blk.13.attn_q.bias offset:3071244288 dims:[3584] number elems:3584 size:14336
Tensor:blk.8.attn_output.weight=blk.8.attn_output.weight offset:3872710656 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.15.attn_output.weight=blk.15.attn_output.weight offset:4512333824 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.11.attn_norm.weight=blk.11.attn_norm.weight offset:2343882752 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.attn_q.bias=blk.19.attn_q.bias offset:5516615680 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.attn_q.bias=blk.18.attn_q.bias offset:5268957184 dims:[3584] number elems:3584 size:14336
Tensor:blk.17.attn_q.bias=blk.17.attn_q.bias offset:5021298688 dims:[3584] number elems:3584 size:14336
Tensor:blk.23.attn_v.weight=blk.23.attn_v.weight offset:7099973632 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_output.weight=blk.0.attn_output.weight offset:797456384 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.10.ffn_norm.weight=blk.10.ffn_norm.weight offset:2312654848 dims:[3584] number elems:3584 size:14336
Tensor:blk.15.attn_q.bias=blk.15.attn_q.bias offset:4525981696 dims:[3584] number elems:3584 size:14336
Tensor:blk.14.attn_q.bias=blk.14.attn_q.bias offset:3174596608 dims:[3584] number elems:3584 size:14336
Tensor:blk.16.attn_q.bias=blk.16.attn_q.bias offset:4773640192 dims:[3584] number elems:3584 size:14336
Tensor:blk.11.attn_q.weight=blk.11.attn_q.weight offset:2575941632 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.ffn_norm.weight=blk.16.ffn_norm.weight offset:4758026240 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.ffn_norm.weight=blk.13.ffn_norm.weight offset:3055630336 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.attn_q.weight=blk.26.attn_q.weight offset:7829299200 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.11.attn_k.weight=blk.11.attn_k.weight offset:2560329728 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.16.attn_k.weight=blk.16.attn_k.weight offset:4758042624 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.14.attn_v.weight=blk.14.attn_v.weight offset:3188260864 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.16.attn_q.weight=blk.16.attn_q.weight offset:4773654528 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.17.attn_q.weight=blk.17.attn_q.weight offset:5021313024 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_q.bias=blk.22.attn_q.bias offset:6187423744 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.ffn_down.weight=blk.26.ffn_down.weight offset:7597254656 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.8.attn_norm.weight=blk.8.attn_norm.weight offset:3654313984 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.attn_q.bias=blk.9.attn_q.bias offset:4134017024 dims:[3584] number elems:3584 size:14336
Tensor:blk.21.attn_q.bias=blk.21.attn_q.bias offset:6011932672 dims:[3584] number elems:3584 size:14336
Tensor:blk.23.attn_q.bias=blk.23.attn_q.bias offset:7086309376 dims:[3584] number elems:3584 size:14336
Tensor:blk.23.attn_k.bias=blk.23.attn_k.bias offset:7070709760 dims:[512] number elems:512 size:2048
Tensor:blk.25.attn_k.bias=blk.25.attn_k.bias offset:7566026752 dims:[512] number elems:512 size:2048
Tensor:blk.8.attn_q.bias=blk.8.attn_q.bias offset:3886358528 dims:[3584] number elems:3584 size:14336
Tensor:blk.20.attn_q.bias=blk.20.attn_q.bias offset:5764274176 dims:[3584] number elems:3584 size:14336
Tensor:output.weight=output.weight offset:6203037696 dims:[3584, 152064] number elems:544997376 size:579059712
Tensor:blk.24.attn_q.bias=blk.24.attn_q.bias offset:7333967872 dims:[3584] number elems:3584 size:14336
Tensor:blk.17.ffn_up.weight=blk.17.ffn_up.weight offset:4933545984 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.3.attn_q.weight=blk.3.attn_q.weight offset:1554094080 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.8.ffn_up.weight=blk.8.ffn_up.weight offset:3798605824 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.22.attn_k.bias=blk.22.attn_k.bias offset:6171824128 dims:[512] number elems:512 size:2048
Tensor:blk.26.attn_k.bias=blk.26.attn_k.bias offset:7813685248 dims:[512] number elems:512 size:2048
Tensor:blk.7.attn_k.bias=blk.7.attn_k.bias offset:3623100416 dims:[512] number elems:512 size:2048
Tensor:blk.9.attn_k.bias=blk.9.attn_k.bias offset:4118417408 dims:[512] number elems:512 size:2048
Tensor:blk.20.ffn_down.weight=blk.20.ffn_down.weight offset:5532243968 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.21.attn_k.bias=blk.21.attn_k.bias offset:5996333056 dims:[512] number elems:512 size:2048
Tensor:blk.23.ffn_down.weight=blk.23.ffn_down.weight offset:6854279168 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.27.attn_k.bias=blk.27.attn_k.bias offset:8061343744 dims:[512] number elems:512 size:2048
Tensor:blk.11.ffn_gate.weight=blk.11.ffn_gate.weight offset:2416035840 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.14.ffn_gate.weight=blk.14.ffn_gate.weight offset:3086858240 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.25.attn_q.weight=blk.25.attn_q.weight offset:7581640704 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.8.attn_k.bias=blk.8.attn_k.bias offset:3870758912 dims:[512] number elems:512 size:2048
Tensor:blk.17.ffn_gate.weight=blk.17.ffn_gate.weight offset:4861407232 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.20.attn_k.bias=blk.20.attn_k.bias offset:5748674560 dims:[512] number elems:512 size:2048
Tensor:blk.2.attn_q.bias=blk.2.attn_q.bias offset:1306421248 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_q.bias=blk.1.attn_q.bias offset:1058762752 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.attn_v.weight=blk.7.attn_v.weight offset:3652364288 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_q.bias=blk.0.attn_q.bias offset:811104256 dims:[3584] number elems:3584 size:14336
Tensor:blk.4.attn_q.weight=blk.4.attn_q.weight offset:1801752576 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.attn_norm.weight=blk.16.attn_norm.weight offset:4541595648 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.ffn_norm.weight=blk.19.ffn_norm.weight offset:5501001728 dims:[3584] number elems:3584 size:14336
Tensor:blk.21.attn_k.weight=blk.21.attn_k.weight offset:5996335104 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.20.ffn_up.weight=blk.20.ffn_up.weight offset:5676521472 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.25.attn_q.bias=blk.25.attn_q.bias offset:7581626368 dims:[3584] number elems:3584 size:14336
Tensor:blk.27.attn_q.bias=blk.27.attn_q.bias offset:8076943360 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.attn_k.weight=blk.9.attn_k.weight offset:4118419456 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.22.ffn_norm.weight=blk.22.ffn_norm.weight offset:6854250496 dims:[3584] number elems:3584 size:14336
Tensor:blk.24.attn_k.bias=blk.24.attn_k.bias offset:7318368256 dims:[512] number elems:512 size:2048
Tensor:blk.25.ffn_norm.weight=blk.25.ffn_norm.weight offset:7566012416 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.attn_q.bias=blk.26.attn_q.bias offset:7829284864 dims:[3584] number elems:3584 size:14336
Tensor:blk.16.ffn_up.weight=blk.16.ffn_up.weight offset:4685887488 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.ffn_gate.weight=blk.5.ffn_gate.weight offset:1889505280 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.8.ffn_gate.weight=blk.8.ffn_gate.weight offset:3726467072 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.2.ffn_gate.weight=blk.2.ffn_gate.weight offset:1146529792 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.21.attn_norm.weight=blk.21.attn_norm.weight offset:5779888128 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.ffn_norm.weight=blk.1.ffn_norm.weight offset:1043148800 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_k.bias=blk.1.attn_k.bias offset:1043163136 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_v.weight=blk.6.attn_v.weight offset:2094274560 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_k.bias=blk.0.attn_k.bias offset:795504640 dims:[512] number elems:512 size:2048
Tensor:blk.8.attn_k.weight=blk.8.attn_k.weight offset:3870760960 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.21.ffn_up.weight=blk.21.ffn_up.weight offset:5924179968 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_norm.weight=blk.4.ffn_norm.weight offset:1786124288 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.attn_q.bias=blk.3.attn_q.bias offset:1554079744 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.attn_k.bias=blk.3.attn_k.bias offset:1538480128 dims:[512] number elems:512 size:2048
Tensor:blk.5.attn_k.bias=blk.5.attn_k.bias offset:2033797120 dims:[512] number elems:512 size:2048
Tensor:blk.9.ffn_up.weight=blk.9.ffn_up.weight offset:4046264320 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.attn_q.bias=blk.4.attn_q.bias offset:1801738240 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.ffn_norm.weight=blk.7.ffn_norm.weight offset:3623086080 dims:[3584] number elems:3584 size:14336
Tensor:blk.2.attn_k.bias=blk.2.attn_k.bias offset:1290821632 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_k.bias=blk.6.attn_k.bias offset:2065010688 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_q.bias=blk.6.attn_q.bias offset:2080610304 dims:[3584] number elems:3584 size:14336
Tensor:blk.5.attn_q.bias=blk.5.attn_q.bias offset:2049396736 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.attn_q.bias=blk.7.attn_q.bias offset:3638700032 dims:[3584] number elems:3584 size:14336
Tensor:blk.20.attn_k.weight=blk.20.attn_k.weight offset:5748676608 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.4.attn_k.bias=blk.4.attn_k.bias offset:1786138624 dims:[512] number elems:512 size:2048
Tensor:blk.23.ffn_gate.weight=blk.23.ffn_gate.weight offset:6926417920 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.11.ffn_down.weight=blk.11.ffn_down.weight offset:2343897088 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.10.ffn_up.weight=blk.10.ffn_up.weight offset:2240516096 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.12.attn_k.bias=blk.12.attn_k.bias offset:2807986176 dims:[512] number elems:512 size:2048
Tensor:blk.14.attn_k.bias=blk.14.attn_k.bias offset:3158996992 dims:[512] number elems:512 size:2048
Tensor:blk.14.ffn_down.weight=blk.14.ffn_down.weight offset:4149645312 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.26.ffn_gate.weight=blk.26.ffn_gate.weight offset:7669393408 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.27.attn_v.weight=blk.27.attn_v.weight offset:8090607616 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.5.attn_v.weight=blk.5.attn_v.weight offset:2063060992 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.10.attn_q.weight=blk.10.attn_q.weight offset:2328283136 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.19.attn_norm.weight=blk.19.attn_norm.weight offset:5284571136 dims:[3584] number elems:3584 size:14336
Tensor:blk.20.ffn_gate.weight=blk.20.ffn_gate.weight offset:5604382720 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.10.attn_k.bias=blk.10.attn_k.bias offset:2312669184 dims:[512] number elems:512 size:2048
Tensor:blk.13.attn_output.weight=blk.13.attn_output.weight offset:3057596416 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.attn_k.bias=blk.16.attn_k.bias offset:4758040576 dims:[512] number elems:512 size:2048
Tensor:blk.18.attn_k.bias=blk.18.attn_k.bias offset:5253357568 dims:[512] number elems:512 size:2048
Tensor:blk.27.attn_norm.weight=blk.27.attn_norm.weight offset:7844898816 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.attn_v.bias=blk.18.attn_v.bias offset:5282619392 dims:[512] number elems:512 size:2048
Tensor:blk.2.attn_norm.weight=blk.2.attn_norm.weight offset:1074376704 dims:[3584] number elems:3584 size:14336
Tensor:blk.2.attn_k.weight=blk.2.attn_k.weight offset:1290823680 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.27.ffn_up.weight=blk.27.ffn_up.weight offset:7989190656 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.7.attn_k.weight=blk.7.attn_k.weight offset:3623102464 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.17.ffn_down.weight=blk.17.ffn_down.weight offset:4789268480 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.0.attn_q.weight=blk.0.attn_q.weight offset:811118592 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.ffn_up.weight=blk.22.ffn_up.weight offset:6099685376 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.0.attn_v.weight=blk.0.attn_v.weight offset:824768512 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_k.weight=blk.15.attn_k.weight offset:4510384128 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_v.weight=blk.15.attn_v.weight offset:4539645952 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_q.weight=blk.15.attn_q.weight offset:4525996032 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.24.ffn_up.weight=blk.24.ffn_up.weight offset:7246215168 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.8.ffn_down.weight=blk.8.ffn_down.weight offset:3654328320 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.2.attn_q.weight=blk.2.attn_q.weight offset:1306435584 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.18.attn_v.weight=blk.18.attn_v.weight offset:5282621440 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.13.attn_v.weight=blk.13.attn_v.weight offset:3084908544 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.17.attn_norm.weight=blk.17.attn_norm.weight offset:4789254144 dims:[3584] number elems:3584 size:14336
Tensor:blk.0.attn_norm.weight=blk.0.attn_norm.weight offset:579059712 dims:[3584] number elems:3584 size:14336
Tensor:blk.6.attn_norm.weight=blk.6.attn_norm.weight offset:3190210560 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.attn_v.bias=blk.9.attn_v.bias offset:4147679232 dims:[512] number elems:512 size:2048
Tensor:blk.21.attn_v.bias=blk.21.attn_v.bias offset:6025594880 dims:[512] number elems:512 size:2048
Tensor:blk.10.attn_v.bias=blk.10.attn_v.bias offset:2341931008 dims:[512] number elems:512 size:2048
Tensor:blk.25.attn_norm.weight=blk.25.attn_norm.weight offset:7349581824 dims:[3584] number elems:3584 size:14336
Tensor:blk.5.attn_k.weight=blk.5.attn_k.weight offset:2033799168 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.23.attn_v.bias=blk.23.attn_v.bias offset:7099971584 dims:[512] number elems:512 size:2048
Tensor:blk.12.attn_v.bias=blk.12.attn_v.bias offset:2837248000 dims:[512] number elems:512 size:2048
Tensor:blk.27.attn_q.weight=blk.27.attn_q.weight offset:8076957696 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.13.ffn_up.weight=blk.13.ffn_up.weight offset:2983491584 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.ffn_down.weight=blk.5.ffn_down.weight offset:1817366528 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.27.attn_v.bias=blk.27.attn_v.bias offset:8090605568 dims:[512] number elems:512 size:2048
Tensor:blk.16.attn_v.bias=blk.16.attn_v.bias offset:4787302400 dims:[512] number elems:512 size:2048
Tensor:blk.22.attn_norm.weight=blk.22.attn_norm.weight offset:6782097408 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.attn_v.bias=blk.25.attn_v.bias offset:7595288576 dims:[512] number elems:512 size:2048
Tensor:blk.2.ffn_down.weight=blk.2.ffn_down.weight offset:1074391040 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.14.attn_v.bias=blk.14.attn_v.bias offset:3188258816 dims:[512] number elems:512 size:2048
Tensor:blk.10.attn_output.weight=blk.10.attn_output.weight offset:2314620928 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.3.attn_v.weight=blk.3.attn_v.weight offset:1567744000 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.attn_output.weight=blk.12.attn_output.weight offset:2809937920 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.12.attn_q.weight=blk.12.attn_q.weight offset:2823600128 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.14.attn_norm.weight=blk.14.attn_norm.weight offset:4149630976 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.ffn_up.weight=blk.25.ffn_up.weight offset:7493873664 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.14.attn_output.weight=blk.14.attn_output.weight offset:3160948736 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.4.attn_k.weight=blk.4.attn_k.weight offset:1786140672 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.ffn_up.weight=blk.12.ffn_up.weight offset:2735833088 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.16.attn_output.weight=blk.16.attn_output.weight offset:4759992320 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.17.attn_v.weight=blk.17.attn_v.weight offset:5034962944 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.13.attn_q.weight=blk.13.attn_q.weight offset:3071258624 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.17.attn_k.weight=blk.17.attn_k.weight offset:5005701120 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.7.attn_v.bias=blk.7.attn_v.bias offset:3652362240 dims:[512] number elems:512 size:2048
Tensor:blk.2.attn_v.weight=blk.2.attn_v.weight offset:1320085504 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.5.attn_v.bias=blk.5.attn_v.bias offset:2063058944 dims:[512] number elems:512 size:2048
Tensor:blk.3.attn_v.bias=blk.3.attn_v.bias offset:1567741952 dims:[512] number elems:512 size:2048
Tensor:blk.23.attn_output.weight=blk.23.attn_output.weight offset:7072661504 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_output.weight=blk.22.attn_output.weight offset:6173775872 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.24.attn_output.weight=blk.24.attn_output.weight offset:7320320000 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.25.attn_output.weight=blk.25.attn_output.weight offset:7567978496 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.19.attn_output.weight=blk.19.attn_output.weight offset:5502967808 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.26.attn_output.weight=blk.26.attn_output.weight offset:7815636992 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.26.ffn_up.weight=blk.26.ffn_up.weight offset:7741532160 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.1.attn_v.bias=blk.1.attn_v.bias offset:1072424960 dims:[512] number elems:512 size:2048
Tensor:blk.11.ffn_up.weight=blk.11.ffn_up.weight offset:2488174592 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.17.attn_output.weight=blk.17.attn_output.weight offset:5007650816 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.18.attn_output.weight=blk.18.attn_output.weight offset:5255309312 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.19.attn_k.weight=blk.19.attn_k.weight offset:5501018112 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.14.attn_q.weight=blk.14.attn_q.weight offset:3174610944 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.attn_v.weight=blk.16.attn_v.weight offset:4787304448 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.18.attn_k.weight=blk.18.attn_k.weight offset:5253359616 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.3.attn_k.weight=blk.3.attn_k.weight offset:1538482176 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.27.attn_output.weight=blk.27.attn_output.weight offset:8063295488 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.21.attn_output.weight=blk.21.attn_output.weight offset:5998284800 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.20.attn_output.weight=blk.20.attn_output.weight offset:5750626304 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.1.attn_v.weight=blk.1.attn_v.weight offset:1072427008 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.1.ffn_up.weight=blk.1.ffn_up.weight offset:971010048 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.24.ffn_norm.weight=blk.24.ffn_norm.weight offset:7318353920 dims:[3584] number elems:3584 size:14336
Tensor:blk.11.attn_output.weight=blk.11.attn_output.weight offset:2562279424 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.23.ffn_up.weight=blk.23.ffn_up.weight offset:6998556672 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.23.attn_k.weight=blk.23.attn_k.weight offset:7070711808 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.19.ffn_gate.weight=blk.19.ffn_gate.weight offset:5356724224 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.27.ffn_norm.weight=blk.27.ffn_norm.weight offset:8061329408 dims:[3584] number elems:3584 size:14336
Tensor:blk.4.attn_output.weight=blk.4.attn_output.weight offset:1788090368 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.6.attn_q.weight=blk.6.attn_q.weight offset:2080624640 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.6.attn_k.weight=blk.6.attn_k.weight offset:2065012736 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.21.ffn_norm.weight=blk.21.ffn_norm.weight offset:5996318720 dims:[3584] number elems:3584 size:14336
Tensor:blk.6.ffn_up.weight=blk.6.ffn_up.weight offset:3334502400 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.19.attn_q.weight=blk.19.attn_q.weight offset:5516630016 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.9.ffn_down.weight=blk.9.ffn_down.weight offset:3901986816 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.6.ffn_down.weight=blk.6.ffn_down.weight offset:3190224896 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.13.ffn_gate.weight=blk.13.ffn_gate.weight offset:2911352832 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.10.ffn_gate.weight=blk.10.ffn_gate.weight offset:2168377344 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.14.ffn_up.weight=blk.14.ffn_up.weight offset:4221784064 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.21.ffn_down.weight=blk.21.ffn_down.weight offset:5779902464 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.24.ffn_down.weight=blk.24.ffn_down.weight offset:7101937664 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.23.attn_q.weight=blk.23.attn_q.weight offset:7086323712 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.27.ffn_down.weight=blk.27.ffn_down.weight offset:7844913152 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.16.ffn_gate.weight=blk.16.ffn_gate.weight offset:4613748736 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.attn_norm.weight=blk.5.attn_norm.weight offset:1817352192 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_norm.weight=blk.1.attn_norm.weight offset:826718208 dims:[3584] number elems:3584 size:14336
Tensor:blk.4.attn_v.weight=blk.4.attn_v.weight offset:1815402496 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.18.ffn_down.weight=blk.18.ffn_down.weight offset:5036926976 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.6.ffn_norm.weight=blk.6.ffn_norm.weight offset:3406641152 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.ffn_norm.weight=blk.9.ffn_norm.weight offset:4118403072 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.ffn_norm.weight=blk.3.ffn_norm.weight offset:1538465792 dims:[3584] number elems:3584 size:14336
Tensor:blk.22.ffn_gate.weight=blk.22.ffn_gate.weight offset:6027546624 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.15.ffn_down.weight=blk.15.ffn_down.weight offset:4293951488 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.12.ffn_down.weight=blk.12.ffn_down.weight offset:2591555584 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.0.ffn_norm.weight=blk.0.ffn_norm.weight offset:795490304 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_k.weight=blk.1.attn_k.weight offset:1043165184 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.1.attn_q.weight=blk.1.attn_q.weight offset:1058777088 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.7.ffn_gate.weight=blk.7.ffn_gate.weight offset:3478808576 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_gate.weight=blk.4.ffn_gate.weight offset:1641846784 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.1.ffn_gate.weight=blk.1.ffn_gate.weight offset:898871296 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.25.attn_v.weight=blk.25.attn_v.weight offset:7595290624 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.attn_v.weight=blk.12.attn_v.weight offset:2837250048 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.ffn_norm.weight=blk.12.ffn_norm.weight offset:2807971840 dims:[3584] number elems:3584 size:14336
Tensor:blk.15.ffn_norm.weight=blk.15.ffn_norm.weight offset:4510367744 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.ffn_norm.weight=blk.18.ffn_norm.weight offset:5253343232 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.ffn_gate.weight=blk.25.ffn_gate.weight offset:7421734912 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.9.attn_norm.weight=blk.9.attn_norm.weight offset:3901972480 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.attn_k.weight=blk.13.attn_k.weight offset:3055646720 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_k.weight=blk.0.attn_k.weight offset:795506688 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.24.attn_norm.weight=blk.24.attn_norm.weight offset:7101923328 dims:[3584] number elems:3584 size:14336
Tensor:blk.24.attn_q.weight=blk.24.attn_q.weight offset:7333982208 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.3.ffn_down.weight=blk.3.ffn_down.weight offset:1322049536 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.18.attn_q.weight=blk.18.attn_q.weight offset:5268971520 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.0.ffn_down.weight=blk.0.ffn_down.weight offset:579074048 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.11.attn_v.weight=blk.11.attn_v.weight offset:2589591552 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.26.attn_v.weight=blk.26.attn_v.weight offset:7842949120 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.ffn_up.weight=blk.0.ffn_up.weight offset:723351552 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.13.attn_norm.weight=blk.13.attn_norm.weight offset:2839199744 dims:[3584] number elems:3584 size:14336
Tensor:blk.14.attn_k.weight=blk.14.attn_k.weight offset:3158999040 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.ffn_up.weight=blk.15.ffn_up.weight offset:4438228992 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.11.attn_k.bias=blk.11.attn_k.bias offset:2560327680 dims:[512] number elems:512 size:2048
Tensor:blk.15.attn_k.bias=blk.15.attn_k.bias offset:4510382080 dims:[512] number elems:512 size:2048
Tensor:blk.17.attn_v.bias=blk.17.attn_v.bias offset:5034960896 dims:[512] number elems:512 size:2048
Tensor:blk.4.attn_norm.weight=blk.4.attn_norm.weight offset:1569693696 dims:[3584] number elems:3584 size:14336
Tensor:blk.10.attn_norm.weight=blk.10.attn_norm.weight offset:2096224256 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.attn_v.bias=blk.19.attn_v.bias offset:5530277888 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_output.weight=blk.6.attn_output.weight offset:2066962432 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.10.attn_k.weight=blk.10.attn_k.weight offset:2312671232 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.17.attn_k.bias=blk.17.attn_k.bias offset:5005699072 dims:[512] number elems:512 size:2048
Tensor:blk.10.attn_v.weight=blk.10.attn_v.weight offset:2341933056 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.2.attn_output.weight=blk.2.attn_output.weight offset:1292773376 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.20.attn_norm.weight=blk.20.attn_norm.weight offset:5532229632 dims:[3584] number elems:3584 size:14336
Tensor:blk.8.ffn_norm.weight=blk.8.ffn_norm.weight offset:3870744576 dims:[3584] number elems:3584 size:14336
Tensor:blk.2.ffn_norm.weight=blk.2.ffn_norm.weight offset:1290807296 dims:[3584] number elems:3584 size:14336
Tensor:blk.5.ffn_norm.weight=blk.5.ffn_norm.weight offset:2033782784 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.attn_k.bias=blk.13.attn_k.bias offset:3055644672 dims:[512] number elems:512 size:2048
Tensor:blk.22.attn_k.weight=blk.22.attn_k.weight offset:6171826176 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.attn_norm.weight=blk.12.attn_norm.weight offset:2591541248 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.ffn_up.weight=blk.7.ffn_up.weight offset:3550947328 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.27.attn_k.weight=blk.27.attn_k.weight offset:8061345792 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.8.attn_v.bias=blk.8.attn_v.bias offset:3900020736 dims:[512] number elems:512 size:2048
Tensor:blk.20.attn_v.bias=blk.20.attn_v.bias offset:5777936384 dims:[512] number elems:512 size:2048
Tensor:blk.11.attn_v.bias=blk.11.attn_v.bias offset:2589589504 dims:[512] number elems:512 size:2048
Tensor:blk.5.attn_q.weight=blk.5.attn_q.weight offset:2049411072 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.12.attn_k.weight=blk.12.attn_k.weight offset:2807988224 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.13.attn_v.bias=blk.13.attn_v.bias offset:3084906496 dims:[512] number elems:512 size:2048
Tensor:blk.22.attn_q.weight=blk.22.attn_q.weight offset:6187438080 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_v.bias=blk.22.attn_v.bias offset:6201085952 dims:[512] number elems:512 size:2048
Tensor:blk.24.attn_v.bias=blk.24.attn_v.bias offset:7347630080 dims:[512] number elems:512 size:2048
Tensor:blk.24.attn_v.weight=blk.24.attn_v.weight offset:7347632128 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.19.attn_k.bias=blk.19.attn_k.bias offset:5501016064 dims:[512] number elems:512 size:2048
Tensor:blk.8.attn_v.weight=blk.8.attn_v.weight offset:3900022784 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.2.ffn_up.weight=blk.2.ffn_up.weight offset:1218668544 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.26.attn_v.bias=blk.26.attn_v.bias offset:7842947072 dims:[512] number elems:512 size:2048
Tensor:blk.20.attn_v.weight=blk.20.attn_v.weight offset:5777938432 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_v.bias=blk.15.attn_v.bias offset:4539643904 dims:[512] number elems:512 size:2048
Tensor:blk.1.ffn_down.weight=blk.1.ffn_down.weight offset:826732544 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.8.attn_q.weight=blk.8.attn_q.weight offset:3886372864 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.15.attn_norm.weight=blk.15.attn_norm.weight offset:4293937152 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.ffn_gate.weight=blk.9.ffn_gate.weight offset:3974125568 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_down.weight=blk.4.ffn_down.weight offset:1569708032 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.7.ffn_down.weight=blk.7.ffn_down.weight offset:3406669824 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:token_embd.weight=token_embd.weight offset:0 dims:[3584, 152064] number elems:544997376 size:579059712
Tensor:blk.3.attn_output.weight=blk.3.attn_output.weight offset:1540431872 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.26.attn_k.weight=blk.26.attn_k.weight offset:7813687296 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.5.attn_output.weight=blk.5.attn_output.weight offset:2035748864 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.9.attn_output.weight=blk.9.attn_output.weight offset:4120369152 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.7.attn_output.weight=blk.7.attn_output.weight offset:3625052160 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.23.attn_norm.weight=blk.23.attn_norm.weight offset:6854264832 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.attn_k.weight=blk.25.attn_k.weight offset:7566028800 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.1.attn_output.weight=blk.1.attn_output.weight offset:1045114880 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.3.ffn_gate.weight=blk.3.ffn_gate.weight offset:1394188288 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.6.ffn_gate.weight=blk.6.ffn_gate.weight offset:3262363648 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.3.ffn_up.weight=blk.3.ffn_up.weight offset:1466327040 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.18.ffn_up.weight=blk.18.ffn_up.weight offset:5181204480 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.9.attn_q.weight=blk.9.attn_q.weight offset:4134031360 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.6.attn_v.bias=blk.6.attn_v.bias offset:2094272512 dims:[512] number elems:512 size:2048
Tensor:blk.21.attn_q.weight=blk.21.attn_q.weight offset:6011947008 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.21.attn_v.weight=blk.21.attn_v.weight offset:6025596928 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.ffn_gate.weight=blk.15.ffn_gate.weight offset:4366090240 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.attn_v.bias=blk.4.attn_v.bias offset:1815400448 dims:[512] number elems:512 size:2048
Tensor:blk.9.attn_v.weight=blk.9.attn_v.weight offset:4147681280 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.20.ffn_norm.weight=blk.20.ffn_norm.weight offset:5748660224 dims:[3584] number elems:3584 size:14336
Tensor:blk.12.ffn_gate.weight=blk.12.ffn_gate.weight offset:2663694336 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.7.attn_q.weight=blk.7.attn_q.weight offset:3638714368 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:output_norm.weight=output_norm.weight offset:8092557312 dims:[3584] number elems:3584 size:14336
Tensor:blk.0.attn_v.bias=blk.0.attn_v.bias offset:824766464 dims:[512] number elems:512 size:2048
Tensor:blk.23.ffn_norm.weight=blk.23.ffn_norm.weight offset:7070695424 dims:[3584] number elems:3584 size:14336
Tensor:blk.0.ffn_gate.weight=blk.0.ffn_gate.weight offset:651212800 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.ffn_up.weight=blk.5.ffn_up.weight offset:1961644032 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.24.attn_k.weight=blk.24.attn_k.weight offset:7318370304 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.2.attn_v.bias=blk.2.attn_v.bias offset:1320083456 dims:[512] number elems:512 size:2048
Tensor:blk.19.ffn_up.weight=blk.19.ffn_up.weight offset:5428862976 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_up.weight=blk.4.ffn_up.weight offset:1713985536 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.11.ffn_norm.weight=blk.11.ffn_norm.weight offset:2560313344 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.ffn_norm.weight=blk.26.ffn_norm.weight offset:7813670912 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.ffn_down.weight=blk.13.ffn_down.weight offset:2839214080 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.14.ffn_norm.weight=blk.14.ffn_norm.weight offset:4293922816 dims:[3584] number elems:3584 size:14336
Tensor:blk.16.ffn_down.weight=blk.16.ffn_down.weight offset:4541609984 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.7.attn_norm.weight=blk.7.attn_norm.weight offset:3406655488 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.attn_norm.weight=blk.18.attn_norm.weight offset:5036912640 dims:[3584] number elems:3584 size:14336
Tensor:blk.10.ffn_down.weight=blk.10.ffn_down.weight offset:2096238592 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.17.ffn_norm.weight=blk.17.ffn_norm.weight offset:5005684736 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.ffn_down.weight=blk.19.ffn_down.weight offset:5284585472 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.26.attn_norm.weight=blk.26.attn_norm.weight offset:7597240320 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.ffn_down.weight=blk.25.ffn_down.weight offset:7349596160 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.18.ffn_gate.weight=blk.18.ffn_gate.weight offset:5109065728 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.20.attn_q.weight=blk.20.attn_q.weight offset:5764288512 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_v.weight=blk.22.attn_v.weight offset:6201088000 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.27.ffn_gate.weight=blk.27.ffn_gate.weight offset:7917051904 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.22.ffn_down.weight=blk.22.ffn_down.weight offset:6782111744 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.21.ffn_gate.weight=blk.21.ffn_gate.weight offset:5852041216 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.24.ffn_gate.weight=blk.24.ffn_gate.weight offset:7174076416 dims:[3584, 18944] number elems:67895296 size:72138752
Load model: 332 millis
>hi
setFloat:0 of size:3584
setFloat:1 of size:3584
setFloat:2 of size:3584
setFloat:...SNIP
Exception in thread "main" java.lang.IndexOutOfBoundsException: Out of bound access on segment MemorySegment{ kind: mapped, address: 0x217e96b3f20, byteSize: 2048 }; new offset = 2048; new length = 4
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.outOfBoundException(AbstractMemorySegmentImpl.java:433)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.apply(AbstractMemorySegmentImpl.java:414)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.apply(AbstractMemorySegmentImpl.java:70)
at java.base/jdk.internal.util.Preconditions.outOfBounds(Preconditions.java:98)
at java.base/jdk.internal.util.Preconditions.outOfBoundsCheckIndex(Preconditions.java:124)
at java.base/jdk.internal.util.Preconditions.checkIndex(Preconditions.java:448)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.checkBounds(AbstractMemorySegmentImpl.java:403)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.checkAccess(AbstractMemorySegmentImpl.java:357)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.checkEnclosingLayout(AbstractMemorySegmentImpl.java:362)
at java.base/java.lang.invoke.SegmentVarHandle.checkSegment(SegmentVarHandle.java:92)
at java.base/java.lang.invoke.VarHandleSegmentAsFloats.get(VarHandleSegmentAsFloats.java:59)
at java.base/java.lang.invoke.VarHandleSegmentAsFloats.get(VarHandleSegmentAsFloats.java:53)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.get(AbstractMemorySegmentImpl.java:754)
at com.llama4j.FloatTensor.readFloat(Llama3.java:2174)
at com.llama4j.F32FloatTensor.getFloat(Llama3.java:2946)
at com.llama4j.FloatTensor.lambda$addInPlace$0(Llama3.java:2320)
at com.llama4j.FloatTensor.mapWithIndexInPlace(Llama3.java:2314)
at com.llama4j.FloatTensor.addInPlace(Llama3.java:2320)
at com.llama4j.FloatTensor.addInPlace(Llama3.java:2324)
at com.llama4j.Llama.forwardQwen(Llama3.java:1392)
at com.llama4j.Llama.generateTokensQwen(Llama3.java:1596)
at com.llama4j.Llama3.runInteractive(Llama3.java:149)
at com.llama4j.Llama3.main(Llama3.java:338)
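One possible reading of the trace above, offered as an assumption rather than a confirmed diagnosis: the failing segment has byteSize 2048, which matches the dims:[512] bias tensors in the listing (512 floats * 4 bytes), while the setFloat lines report a destination of size 3584. If addInPlace walks the destination's 3584 elements and reads the same index from a 512-element bias, the first bad read is index 512, i.e. offset 2048 with length 4, which is exactly what the exception reports. The sketch below reproduces that arithmetic in isolation. The class and method names (BiasBoundsSketch, readFloat, firstFailingIndex) are hypothetical and not taken from Llama3.java, and it assumes a JDK where java.lang.foreign is final (22+), or 21 with --enable-preview.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

// Hypothetical stand-alone sketch; not part of the reproducer source.
final class BiasBoundsSketch {
    // Sizes taken from the log: destination of 3584 floats, bias tensor of 512 floats.
    static final int DEST_SIZE = 3584;
    static final int BIAS_SIZE = 512;

    // Reads float number `index` from the segment, scaling the index to a byte offset.
    static float readFloat(MemorySegment segment, long index) {
        return segment.get(ValueLayout.JAVA_FLOAT_UNALIGNED, index * Float.BYTES);
    }

    // Returns the first index whose read is rejected by the bounds check, or -1 if none is.
    static int firstFailingIndex(MemorySegment segment, int iterations) {
        for (int i = 0; i < iterations; i++) {
            try {
                readFloat(segment, i);
            } catch (IndexOutOfBoundsException e) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        try (Arena arena = Arena.ofConfined()) {
            // 512 floats * 4 bytes = 2048 bytes, the byteSize shown in the exception message.
            MemorySegment bias = arena.allocate((long) BIAS_SIZE * Float.BYTES);
            System.out.println("bias byteSize = " + bias.byteSize());
            System.out.println("iterating over destination size fails at index "
                    + firstFailingIndex(bias, DEST_SIZE));
            System.out.println("iterating over bias size fails at index "
                    + firstFailingIndex(bias, BIAS_SIZE));
        }
    }
}
```

If this reading holds, the bounds check in MemorySegment is behaving as specified, and the loop bound in the caller (the destination size instead of the size of the tensor being added) would be the place to look.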
---------- BEGIN SOURCE ----------
///usr/bin/env jbang "$0" "$@" ; exit $?
//JAVA 21+
//PREVIEW
//COMPILE_OPTIONS --add-modules=jdk.incubator.vector
//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0
//MAIN com.llama4j.Llama3
// Practical Llama 3 (and 3.1) inference in a single Java file
// Author: Alfonso² Peterssen
// Based on Andrej Karpathy's llama2.c and minbpe projects
//
// Supports llama.cpp's GGUF format, restricted to Q4_0 and Q8_0 quantized models
// Multi-threaded matrix vector multiplication routines implemented using Java's Vector API
// Simple CLI with --chat and --instruct mode
//
// To run just:
// jbang Llama3.java --help
//
// Remember: Llama models use GPT2 vocabulary while non-Llama models use Llama vocabulary!
// Enjoy!
package com.llama4j;
import jdk.incubator.vector.*;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.reflect.Field;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import java.util.function.LongConsumer;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
public class Llama3 {
// Batch-size used in prompt evaluation.
private static int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
private final static boolean DEBUG = false;
static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
Sampler sampler;
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
sampler = Sampler.ARGMAX;
} else {
// we sample from this distribution to get the next token
RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed);
Sampler innerSampler;
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
innerSampler = new CategoricalSampler(rng);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
innerSampler = new ToppSampler(vocabularySize, topp, rng);
}
sampler = logits -> {
// apply the temperature to the logits
logits.divideInPlace(0, logits.size(), temperature);
// apply softmax to the logits to get the probabilities for next token
logits.softmaxInPlace(0, logits.size());
return innerSampler.sampleToken(logits);
};
}
return sampler;
}
static void runInteractive(Llama model, Sampler sampler, Options options) {
Llama.State state = null;
List<Integer> conversationTokens = new ArrayList<>();
ChatFormatInterface chatFormat;
// The chat format appears to depend solely on the individual model, so the model loader extracts a name from the metadata key general.name.
if (ModelLoader.name.equals("mistral")) {
chatFormat = new MistralChatFormat(model.tokenizer());
} else if (ModelLoader.name.equals("llama")) {
chatFormat = new ChatFormat(model.tokenizer());
} else if (ModelLoader.name.equals("qwen")) {
BATCH_SIZE = 1;
chatFormat = new ChatMLFormat(model.tokenizer());
} else {
throw new IllegalArgumentException("expected metadata general.name containing mistral, llama, or qwen but found "+ModelLoader.name);
}
conversationTokens.add(chatFormat.getBeginOfText());
if (options.systemPrompt() != null) {
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
int startPosition = 0;
Scanner in = new Scanner(System.in);
loop: while (true) {
boolean storeDb = true;
System.out.print("> ");
System.out.flush();
String userText = in.nextLine();
switch (userText) {
case "/quit":
case "/exit": break loop;
case "/context": {
System.out.printf("%d out of %d context tokens used (%d tokens remaining)%n",
conversationTokens.size(),
options.maxTokens(),
options.maxTokens() - conversationTokens.size());
continue;
}
}
if (state == null) {
state = model.createNewState(BATCH_SIZE, chatFormat.getBeginOfText());
}
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
Set<Integer> stopTokens = chatFormat.getStopTokens();
List<Integer> responseTokens;
if(ModelLoader.name.equals("qwen")) {
responseTokens = Llama.generateTokensQwen(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
if (options.stream()) {
int tokenType = model.tokenizer().getTokenType(token);
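// Print only ordinary text tokens. Assuming the llama.cpp token-type numbering, 1 = normal and 6 = byte.
if (tokenType == 1 || tokenType == 6) {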
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
} else {
responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
}
// Include stop token in the prompt history, but not in the response displayed to the user.
conversationTokens.addAll(responseTokens);
startPosition = conversationTokens.size();
Integer stopToken = null;
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
stopToken = responseTokens.getLast();
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
if (stopToken == null) {
System.err.println("Ran out of context length...");
break;
}
}
}
}
static void runInstructOnce(Llama model, Sampler sampler, Options options) {
ChatFormatInterface chatFormat;
// The chat format appears to depend solely on the individual model, so the model loader extracts a name from the metadata key general.name.
if (ModelLoader.name.equals("mistral")) {
chatFormat = new MistralChatFormat(model.tokenizer());
} else if (ModelLoader.name.equals("llama")) {
chatFormat = new ChatFormat(model.tokenizer());
} else if (ModelLoader.name.equals("qwen")) {
chatFormat = new ChatMLFormat(model.tokenizer());
} else {
throw new IllegalArgumentException("expected metadata general.name containing mistral, llama, or qwen but found "+ModelLoader.name);
}
Llama.State state = model.createNewState(BATCH_SIZE, chatFormat.getBeginOfText());
List<Integer> promptTokens = new ArrayList<>();
promptTokens.add(chatFormat.getBeginOfText());
if (options.systemPrompt() != null) {
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
Set<Integer> stopTokens = chatFormat.getStopTokens();
List<Integer> responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}
}
record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
String localNode, String remoteNode, int remotePort) {
static final int DEFAULT_MAX_TOKENS = 512;
Options {
require(modelPath != null, "Missing argument: --model <path> is required");
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
static void require(boolean condition, String messageFormat, Object... args) {
if (!condition) {
System.out.println("ERROR " + messageFormat.formatted(args));
System.out.println();
printUsage(System.out);
System.exit(-1);
}
}
static void printUsage(PrintStream out) {
out.println("Usage: jbang Llama3.java [options]");
out.println();
out.println("Options:");
out.println(" --model, -m <path> required, path to .gguf file");
out.println(" --interactive, --chat, -i run in chat mode");
out.println(" --instruct run in instruct (once) mode, default mode");
out.println(" --prompt, -p <string> input prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt");
out.println(" --temperature, --temp <float> temperature in [0,inf], default 0.1");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --seed <long> random seed, default System.nanoTime()");
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(" --echo <boolean> print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
out.println(" --localNode <string> local database client node");
out.println(" --remoteNode <string> remote database client node");
out.println(" --remotePort <int> remote database port");
out.println();
out.println("Examples:");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Tell me a joke\"");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was Marie Curie?\"");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Answer concisely\" --chat");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --chat");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false");
}
static Options parseOptions(String[] args) {
String prompt = null;
String systemPrompt = null;
float temperature = 0.1f;
float topp = 0.95f;
Path modelPath = null;
long seed = System.nanoTime();
// Keep max context length small for low-memory devices.
int maxTokens = DEFAULT_MAX_TOKENS;
boolean interactive = false;
boolean stream = true;
boolean echo = false;
String localNode = null;
String remoteNode = null;
int remotePort = 0;
for (int i = 0; i < args.length; i++) {
String optionName = args[i];
require(optionName.startsWith("-"), "Invalid option %s", optionName);
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
}
default -> {
String nextArg;
if (optionName.contains("=")) {
String[] parts = optionName.split("=", 2);
optionName = parts[0];
nextArg = parts[1];
} else {
require(i + 1 < args.length, "Missing argument for option %s", optionName);
nextArg = args[i + 1];
i += 1; // skip arg
}
switch (optionName) {
case "--prompt", "-p" -> prompt = nextArg;
case "--system-prompt", "-sp" -> systemPrompt = nextArg;
case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg);
case "--top-p" -> topp = Float.parseFloat(nextArg);
case "--model", "-m" -> modelPath = Paths.get(nextArg);
case "--seed", "-s" -> seed = Long.parseLong(nextArg);
case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
case "--stream" -> stream = Boolean.parseBoolean(nextArg);
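case "--echo" -> echo = Boolean.parseBoolean(nextArg);
// The usage text advertises these options, so accept them instead of rejecting them as unknown.
case "--localNode" -> localNode = nextArg;
case "--remoteNode" -> remoteNode = nextArg;
case "--remotePort" -> remotePort = Integer.parseInt(nextArg);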
default -> require(false, "Unknown option: %s", optionName);
}
}
}
}
return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo, localNode, remoteNode, remotePort);
}
}
public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
Llama model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
if(model == null)
model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
if (options.interactive()) {
runInteractive(model, sampler, options);
} else {
runInstructOnce(model, sampler, options);
}
}
}
final class GGUF {
private static final int GGUF_MAGIC = 0x46554747;
private static final int DEFAULT_ALIGNMENT = 32; // must be a power of 2
private static final List<Integer> SUPPORTED_GGUF_VERSIONS = List.of(2, 3);
private int magic;
private int version;
private int tensorCount; // uint64_t
private int alignment;
private int metadata_kv_count; // uint64_t
private Map<String, Object> metadata;
public Map<String, GGUFTensorInfo> getTensorInfos() {
return tensorInfos;
}
private Map<String, GGUFTensorInfo> tensorInfos;
private long tensorDataOffset;
public long getTensorDataOffset() {
return tensorDataOffset;
}
public Map<String, Object> getMetadata() {
return metadata;
}
private final ByteBuffer BB_1 = ByteBuffer.allocate(Byte.BYTES).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer BB_2 = ByteBuffer.allocate(Short.BYTES).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer BB_4 = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer BB_8 = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN);
public static GGUF loadModel(Path modelPath) throws IOException {
try (FileChannel fileChannel = FileChannel.open(modelPath);
var ignored = Timer.log("Parse " + modelPath)) {
GGUF gguf = new GGUF();
gguf.loadModelImpl(fileChannel);
return gguf;
}
}
enum MetadataValueType {
// The value is an 8-bit unsigned integer.
UINT8(1),
// The value is an 8-bit signed integer.
INT8(1),
// The value is a 16-bit unsigned little-endian integer.
UINT16(2),
// The value is a 16-bit signed little-endian integer.
INT16(2),
// The value is a 32-bit unsigned little-endian integer.
UINT32(4),
// The value is a 32-bit signed little-endian integer.
INT32(4),
// The value is a 32-bit IEEE754 floating point number.
FLOAT32(4),
// The value is a boolean.
// 1-byte value where 0 is false and 1 is true.
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
BOOL(1),
// The value is a UTF-8 non-null-terminated string, with length prepended.
STRING(-8),
// The value is an array of other values, with the length and type prepended.
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
ARRAY(-8),
// The value is a 64-bit unsigned little-endian integer.
UINT64(8),
// The value is a 64-bit signed little-endian integer.
INT64(8),
// The value is a 64-bit IEEE754 floating point number.
FLOAT64(8);
private final int byteSize;
MetadataValueType(int byteSize) {
this.byteSize = byteSize;
}
private static final MetadataValueType[] VALUES = values();
public static MetadataValueType fromIndex(int index) {
return VALUES[index];
}
public int byteSize() {
return byteSize;
}
}
private void loadModelImpl(FileChannel fileChannel) throws IOException {
// The header of the file.
readHeader(fileChannel); // gguf_header_t header;
// Tensor infos, which can be used to locate the tensor data.
// gguf_tensor_info_t tensor_infos[header.tensor_count];
this.tensorInfos = HashMap.newHashMap(tensorCount);
for (int i = 0; i < tensorCount; ++i) {
GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
assert !tensorInfos.containsKey(ti.name);
tensorInfos.put(ti.name, ti);
}
// Padding to the nearest multiple of `ALIGNMENT`.
// uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)];
//long _padding = -fileChannel.position() & (ALIGNMENT - 1);
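// Skip only the bytes needed to reach the next multiple of the alignment; skip nothing when the position is already aligned.
long _padding = (getAlignment() - (fileChannel.position() % getAlignment())) % getAlignment();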
fileChannel.position(fileChannel.position() + _padding);
// Tensor data.
//
// This is arbitrary binary data corresponding to the weights of the model. This data should be close
// or identical to the data in the original model file, but may be different due to quantization or
// other optimizations for inference. Any such deviations should be recorded in the metadata or as
// part of the architecture definition.
//
// Each tensor's data must be stored within this array, and located through its `tensor_infos` entry.
// The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors
// should be padded to `ALIGNMENT` bytes.
// uint8_t tensor_data[];
this.tensorDataOffset = fileChannel.position();
}
public static Map<String, GGMLTensorEntry> loadTensors(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
Arena arena = Arena.ofAuto();
MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena);
Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
GGUFTensorInfo ti = entry.getValue();
int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
System.out.println("Tensor:"+entry.getKey()+"="+ti.name+" offset:"+ti.offset+" dims:"+Arrays.toString(ti.dimensions)+" number elems:"+numberOfElements+" size:"+sizeInBytes);
MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes);
tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
}
return tensorEntries;
}
public record GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) {
}
private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
int ggmlTypeId = readInt(fileChannel); // ggml_type type;
return GGMLType.fromId(ggmlTypeId);
}
private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
// The name of the tensor. It is a standard GGUF string, with the caveat that
// it must be at most 64 bytes long.
String name = readString(fileChannel); // gguf_string_t name;
assert name.length() <= 64;
// The number of dimensions in the tensor.
// Currently at most 4, but this may change in the future.
int n_dimensions = readInt(fileChannel); // uint32_t n_dimensions;
assert n_dimensions <= 4;
// The dimensions of the tensor.
int[] dimensions = new int[n_dimensions]; // uint64_t dimensions[n_dimensions];
for (int i = 0; i < n_dimensions; ++i) {
dimensions[i] = Math.toIntExact(readLong(fileChannel));
}
// The type of the tensor.
GGMLType ggmlType = readGGMLType(fileChannel); // ggml_type type;
// The offset of the tensor's data in this file in bytes.
// This offset is relative to `tensor_data`, not to the start
// of the file, to make it easier for writers to write the file.
// Readers should consider exposing this offset relative to the
// file to make it easier to read the data.
// Must be a multiple of `ALIGNMENT`.
long offset = readLong(fileChannel); // uint64_t offset;
assert offset % getAlignment() == 0;
return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset);
}
private String readString(FileChannel fileChannel) throws IOException {
// A string in GGUF.
// The length of the string, in bytes.
int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len;
// The string as a UTF-8 non-null-terminated string.
byte[] bytes = new byte[len]; // char string[len];
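// A single read may return fewer bytes than requested, so loop until the buffer is full.
ByteBuffer buffer = ByteBuffer.wrap(bytes);
while (buffer.hasRemaining()) {
if (fileChannel.read(buffer) < 0) {
throw new IOException("unexpected end of file while reading a GGUF string of " + len + " bytes");
}
}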
return new String(bytes, StandardCharsets.UTF_8);
}
private Pair<String, Object> readKeyValuePair(FileChannel fileChannel) throws IOException {
// The key of the metadata. It is a standard GGUF string, with the following caveats:
// - It must be a valid ASCII string.
// - It must be a hierarchical key, where each segment is `lower_snake_case` and separated by a `.`.
// - It must be at most 2^16-1/65535 bytes long.
// Any keys that do not follow these rules are invalid.
String key = readString(fileChannel); // gguf_string_t key;
assert key.length() < (1 << 16);
assert key.codePoints().allMatch(cp -> ('a' <= cp && cp <= 'z') || ('0' <= cp && cp <= '9') || cp == '_' || cp == '.');
Object value = readMetadataValue(fileChannel);
return new Pair<>(key, value);
}
private Object readMetadataValue(FileChannel fileChannel) throws IOException {
// The type of the value.
// Must be one of the `gguf_metadata_value_type` values.
MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type value_type;
// The value.
return readMetadataValueOfType(value_type, fileChannel); // gguf_metadata_value_t value;
}
void readHeader(FileChannel fileChannel) throws IOException {
// Magic number to announce that this is a GGUF file.
// Must be `GGUF` at the byte level: `0x47` `0x47` `0x55` `0x46`.
// Your executor might use little-endian byte order, so it might be
// checking for 0x46554747 and letting the endianness cancel out.
// Consider being *very* explicit about the byte order here.
this.magic = readInt(fileChannel); // uint32_t magic;
if (magic != GGUF_MAGIC) {
throw new IllegalArgumentException("unsupported header.magic " + magic);
}
// The version of the format implemented.
// Must be `3` for version described in this spec.
//
// This version should only be increased for structural changes to the format.
// Changes that do not affect the structure of the file should instead update the metadata
// to signify the change.
this.version = readInt(fileChannel); // uint32_t version;
if (!SUPPORTED_GGUF_VERSIONS.contains(version)) {
throw new IllegalArgumentException("unsupported header.version " + version);
}
// The number of tensors in the file.
// This is explicit, instead of being included in the metadata, to ensure it is always present
// for loading the tensors.
this.tensorCount = Math.toIntExact(readLong(fileChannel)); // uint64_t tensor_count;
// The number of metadata key-value pairs.
this.metadata_kv_count = Math.toIntExact(readLong(fileChannel)); // uint64_t metadata_kv_count;
// The metadata key-value pairs.
// gguf_metadata_kv_t metadata_kv[metadata_kv_count];
this.metadata = HashMap.newHashMap(metadata_kv_count);
for (int i = 0; i < metadata_kv_count; ++i) {
Pair<String, Object> keyValue = readKeyValuePair(fileChannel);
assert !metadata.containsKey(keyValue.first());
metadata.put(keyValue.first(), keyValue.second());
}
}
private Object readArray(FileChannel fileChannel) throws IOException {
// Any value type is valid, including arrays.
MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type type;
// Number of elements, not bytes
int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len;
// The array of values.
// gguf_metadata_value_t array[len];
switch (value_type) {
case UINT8, INT8 -> {
byte[] bytes = new byte[len];
for (int i = 0; i < len; ++i) {
bytes[i] = readByte(fileChannel);
}
return bytes;
}
case UINT16, INT16 -> {
short[] shorts = new short[len];
for (int i = 0; i < len; ++i) {
shorts[i] = readShort(fileChannel);
}
return shorts;
}
case UINT32, INT32 -> {
int[] ints = new int[len];
for (int i = 0; i < len; ++i) {
ints[i] = readInt(fileChannel);
}
return ints;
}
case FLOAT32 -> {
float[] floats = new float[len];
for (int i = 0; i < len; ++i) {
floats[i] = readFloat(fileChannel);
}
return floats;
}
case BOOL -> {
boolean[] booleans = new boolean[len];
for (int i = 0; i < len; ++i) {
booleans[i] = readBoolean(fileChannel);
}
return booleans;
}
case STRING -> {
String[] strings = new String[len];
for (int i = 0; i < len; ++i) {
strings[i] = readString(fileChannel);
}
return strings;
}
case ARRAY -> {
Object[] arrays = new Object[len];
for (int i = 0; i < len; ++i) {
arrays[i] = readArray(fileChannel);
}
return arrays;
}
default -> throw new UnsupportedOperationException("read array of " + value_type);
}
}
private Object readMetadataValueOfType(MetadataValueType valueType, FileChannel fileChannel) throws IOException {
return switch (valueType) {
case UINT8, INT8 -> readByte(fileChannel);
case UINT16, INT16 -> readShort(fileChannel);
case UINT32, INT32 -> readInt(fileChannel);
case FLOAT32 -> readFloat(fileChannel);
case UINT64, INT64 -> readLong(fileChannel);
case FLOAT64 -> readDouble(fileChannel);
case BOOL -> readBoolean(fileChannel);
case STRING -> readString(fileChannel);
case ARRAY -> readArray(fileChannel);
};
}
private byte readByte(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_1);
assert bytesRead == 1;
return BB_1.clear().get(0);
}
private boolean readBoolean(FileChannel fileChannel) throws IOException {
return readByte(fileChannel) != 0;
}
private short readShort(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_2);
assert bytesRead == 2;
return BB_2.clear().getShort(0);
}
private int readInt(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_4);
assert bytesRead == 4;
return BB_4.clear().getInt(0);
}
private long readLong(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_8);
assert bytesRead == 8;
return BB_8.clear().getLong(0);
}
private float readFloat(FileChannel fileChannel) throws IOException {
return Float.intBitsToFloat(readInt(fileChannel));
}
private double readDouble(FileChannel fileChannel) throws IOException {
return Double.longBitsToDouble(readLong(fileChannel));
}
private MetadataValueType readMetadataValueType(FileChannel fileChannel) throws IOException {
int index = readInt(fileChannel);
return MetadataValueType.fromIndex(index);
}
public int getAlignment() {
if (alignment != 0) {
return alignment;
}
alignment = (int) metadata.getOrDefault("general.alignment", DEFAULT_ALIGNMENT);
assert Integer.bitCount(alignment) == 1 : "alignment must be a power of two";
return alignment;
}
}
interface Timer extends AutoCloseable {
@Override
void close(); // no Exception
static Timer log(String label) {
return log(label, TimeUnit.MILLISECONDS);
}
static Timer log(String label, TimeUnit timeUnit) {
return new Timer() {
final long startNanos = System.nanoTime();
@Override
public void close() {
long elapsedNanos = System.nanoTime() - startNanos;
System.err.println(label + ": "
+ timeUnit.convert(elapsedNanos, TimeUnit.NANOSECONDS) + " "
+ timeUnit.toChronoUnit().name().toLowerCase());
}
};
}
}
/**
* Loads a model: reads the GGUF metadata, loads the vocabulary, creates the tokenizer and the configuration,
* and, if loadWeights is set, loads the tensors and the weights. Finally creates a Llama from the config, tokenizer and weights.
*/
final class ModelLoader {
static final String TOKENIZER_GPT2_MODEL = "gpt2"; // Llama3 uses gpt2!
static final String TOKENIZER_LLAMA_MODEL = "llama"; // non Llama uses llama!
public static String model = "gpt2"; // default for Llama models!
public static String name = null; // Name is based solely on name of model, they all seem to have their own ChatFormat not based on model
private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
private static Vocabulary loadVocabulary(Map<String, Object> metadata) {
model = (String) metadata.get("tokenizer.ggml.model");
name = (String) metadata.get("general.name");
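if (name == null) {
throw new IllegalArgumentException("metadata general.name is missing");
}
String lowerName = name.toLowerCase();
if (lowerName.contains("llama")) { // Meta Llama etc.
name = "llama";
} else if (lowerName.contains("mistral")) { // models--mistralai etc.
name = "mistral";
} else if (lowerName.contains("qwen")) {
name = "qwen";
}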
String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
if(TOKENIZER_LLAMA_MODEL.equals(model)) {
float[] scores = (float[]) metadata.get("tokenizer.ggml.scores");
return new Vocabulary(tokens, scores);
} else {
if(TOKENIZER_GPT2_MODEL.equals(model)) {
return new Vocabulary(tokens, null);
} else {
throw new IllegalArgumentException("expected " + TOKENIZER_GPT2_MODEL + " or "+ TOKENIZER_LLAMA_MODEL+ " but found " + model);
}
}
}
public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
GGUF gguf = GGUF.loadModel(ggufPath);
FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
return loadModel(fileChannel, gguf, contextLength, loadWeights);
}
public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException {
try (var ignored = Timer.log("Load model")) {
Map<String, Object> metadata = gguf.getMetadata();
System.out.println("GGUF metadata:\r\n"+metadata);
Vocabulary vocabulary = loadVocabulary(metadata);
TokenizerInterface tokenizer;
Llama.Configuration config;
Llama.Weights weights = null;
String arch = (String) metadata.get("general.architecture");
if(ModelLoader.name.equals("mistral")) {
tokenizer = createLlamaTokenizer(metadata, vocabulary);
config = createConfig(arch, metadata, vocabulary, contextLength);
if (loadWeights) {
// loadTensors corresponds to getTensorEntries in old version
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadLlamaWeights(tensorEntries, config);
}
} else {
if(ModelLoader.name.equals("llama")) {
tokenizer = createGPT2Tokenizer(metadata, vocabulary);
config = createConfig(arch, metadata, vocabulary, contextLength);
if (loadWeights) {
// loadTensors corresponds to getTensorEntries in old version
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadGPT2Weights(tensorEntries, config);
}
} else {
if(ModelLoader.name.equals("qwen")) {
tokenizer = createQwen2Tokenizer(metadata, vocabulary);
config = createConfig(arch, metadata, vocabulary, contextLength);
if (loadWeights) {
// loadTensors corresponds to getTensorEntries in old version
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadQwenWeights(tensorEntries, config);
}
} else {
throw new IllegalArgumentException("expected metadata general.name containing mistral, llama, or qwen but found "+ModelLoader.name);
}
}
}
return new Llama(config, tokenizer, weights);
}
}
static Llama.Configuration createConfig(String arch, Map<String, Object> metadata, Vocabulary vocabulary, int contextLength) {
Llama.Configuration config = new Llama.Configuration(
(int) metadata.get(arch+".embedding_length"),
(int) metadata.get(arch+".feed_forward_length"),
(int) metadata.get(arch+".block_count"),
(int) metadata.get(arch+".attention.head_count"),
metadata.containsKey(arch+".attention.head_count_kv")
? (int) metadata.get(arch+".attention.head_count_kv")
: (int) metadata.get(arch+".attention.head_count"),
vocabulary.size(),
(int) metadata.get(arch+".context_length"),
(float) metadata.getOrDefault(arch+".attention.layer_norm_rms_epsilon", 1e-5f),
(float) metadata.getOrDefault(arch+".rope.freq_base", 10000f)
).withContextLength(contextLength);
return config;
}
/**
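* Called from AOT.tryUsePreloaded and ModelLoader.loadModel.
* @param tensorEntries tensor entries keyed by GGUF tensor name
* @param config model configuration (number of layers, head size, context length, RoPE theta)
* @return the weights, including the precomputed RoPE frequencies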
*/
static Llama.Weights loadGPT2Weights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
float scaleFactor = 8;
float loFreqFactor = 1;
float hiFreqFactor = 3;
int oldContextLength = 8192;
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta,
ropeScaling, scaleFactor, loFreqFactor, hiFreqFactor, oldContextLength);
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
Llama.Weights qw = new Llama.Weights(
loadQuantized(tokenEmbeddings),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
toFloatBuffer(tensorEntries.get("output_norm.weight")),
FloatBuffer.wrap(ropeFreqsReal),
FloatBuffer.wrap(ropeFreqsImag),
// If "output.weight" is not present then the embedding weights are tied/shared with the decoder.
// This is commonly referred to as "tied word embeddings".
loadQuantized(tensorEntries.getOrDefault("output.weight", tokenEmbeddings))
);
return qw;
}
static Llama.Weights loadLlamaWeights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta);
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
Llama.Weights qw = new Llama.Weights(
loadQuantized(tensorEntries.get("token_embd.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
toFloatBuffer(tensorEntries.get("output_norm.weight")),
FloatBuffer.wrap(ropeFreqsReal),
FloatBuffer.wrap(ropeFreqsImag),
loadQuantized(tensorEntries.get("output.weight"))
);
return qw;
}
static Llama.Weights loadQwenWeights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta);
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
Llama.Weights qw = new Llama.Weights(
loadQuantized(tensorEntries.get("token_embd.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
toFloatBuffer(tensorEntries.get("output_norm.weight")),
FloatBuffer.wrap(ropeFreqsReal),
FloatBuffer.wrap(ropeFreqsImag),
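// Fall back to the token embeddings when "output.weight" is absent (tied word embeddings), as in loadGPT2Weights.
loadQuantized(tensorEntries.getOrDefault("output.weight", tensorEntries.get("token_embd.weight")))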
);
return qw;
}
private final static String QWEN2_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
private static Tokenizer createQwen2Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
.map(line -> line.split(" "))
.map(parts ->
new Pair<>(
vocabulary.getIndex(parts[0]).orElseThrow(),
vocabulary.getIndex(parts[1]).orElseThrow())
).toList();
int allTokens = vocabulary.size();
int baseTokens = vocabulary.getIndex("<|endoftext|>").orElseThrow(); // assume all tokens after the base ones are special.
int reservedSpecialTokens = allTokens - baseTokens;
List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
Map<String, Integer> specialTokens =
IntStream.range(0, specialTokensList.size())
.boxed()
.collect(Collectors.toMap(
i -> specialTokensList.get(i),
i -> baseTokens + i)
);
return new Tokenizer(vocabulary, merges, QWEN2_PATTERN, specialTokens, tokenTypes);
}
private static Tokenizer createGPT2Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
.map(line -> line.split(" "))
.map(parts ->
new Pair<>(
vocabulary.getIndex(parts[0]).orElseThrow(),
vocabulary.getIndex(parts[1]).orElseThrow())
).toList();
int allTokens = vocabulary.size();
int baseTokens = 128000; // assume all tokens after the base ones are special.
int reservedSpecialTokens = allTokens - baseTokens;
List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
Map<String, Integer> specialTokens =
IntStream.range(0, specialTokensList.size())
.boxed()
.collect(Collectors.toMap(
i -> specialTokensList.get(i),
i -> baseTokens + i)
);
return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
}
private static MistralTokenizer createLlamaTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList();
Map<String, Integer> specialTokens =
// Stream the filtered token ids themselves; range(0, size) is only correct when the special tokens occupy the first ids.
specialTokensList.stream()
.collect(Collectors.toMap(
t -> vocabulary.get(t),
t -> t)
);
return new MistralTokenizer(vocabulary, null, specialTokens, tokenTypes);
}
public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
GGMLType ggmlType = entry.ggmlType();
return switch (ggmlType) {
case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
};
}
public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatTensor[] array = new FloatTensor[size];
for (int i = 0; i < size; i++) {
array[i] = loadQuantized(getTensorEntry.apply(i));
}
return array;
}
public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatBuffer[] array = new FloatBuffer[size];
for (int i = 0; i < size; i++) {
array[i] = toFloatBuffer(getTensorEntry.apply(i));
}
return array;
}
public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
GGMLType ggmlType = tensorEntry.ggmlType();
return switch (ggmlType) {
case F32 -> tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
default -> throw new UnsupportedOperationException("Conversion to " + ggmlType);
};
}
}
record Llama(Configuration configuration, TokenizerInterface tokenizer, Weights weights) {
public State createNewState(int batchsize, int beginOfText) {
State state = new State(configuration(), batchsize);
state.latestToken = beginOfText; // was tokenizer.getSpecialTokens().get("<|begin_of_text|>");, now we get from ChatFormat.beginOfText() which does the same
return state;
}
public static final class Configuration {
public final int dim; // transformer dimension
public final int hiddenDim; // for ffn layers
public final int numberOfLayers; // number of layers
public final int numberOfHeads; // number of query heads
public final int numberOfKeyValueHeads; // number of key/value heads (can be < query heads because of multiquery)
public final int vocabularySize; // vocabulary size (number of tokens in the tokenizer vocabulary)
public final int contextLength; // max sequence length
public final float rmsNormEps;
public final float ropeTheta;
public final int headSize;
Configuration withContextLength(int newContextLength) {
if (newContextLength < 0) {
return this; // no change
}
return new Configuration(this.dim, this.hiddenDim, this.numberOfLayers, this.numberOfHeads, this.numberOfKeyValueHeads, this.vocabularySize, newContextLength, this.rmsNormEps, this.ropeTheta);
}
public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) {
this.dim = dim;
this.hiddenDim = hiddenDim;
this.numberOfLayers = numberOfLayers;
this.numberOfHeads = numberOfHeads;
this.numberOfKeyValueHeads = numberOfKeyValueHeads;
this.vocabularySize = vocabularySize;
this.contextLength = contextLength;
this.rmsNormEps = rmsNormEps;
this.ropeTheta = ropeTheta;
this.headSize = dim / numberOfHeads;
}
}
public static final class Weights {
// token embedding table
public final FloatTensor token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights
// weights for matmuls
public final FloatTensor[] wq; // (layer, n_heads * head_size)
public final FloatTensor[] wk; // (layer, n_kv_heads * head_size)
public final FloatTensor[] wv; // (layer, n_kv_heads * head_size)
public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
// next 3: qwen - Groff from Qwen2.java
public FloatTensor[] q_bias = null; // (layer, dim)
public FloatTensor[] k_bias = null; // (layer, kv_dim)
public FloatTensor[] v_bias = null; // (layer, kv_dim)
public final FloatBuffer[] rms_ffn_weight; // (layer, dim)
// weights for ffn
public final FloatTensor[] w1; // (layer, hidden_dim, dim)
public final FloatTensor[] w2; // (layer, dim, hidden_dim)
public final FloatTensor[] w3; // (layer, hidden_dim, dim)
// final rmsnorm
public final FloatBuffer rms_final_weight; // (dim,)
// freq_cis for RoPE relative positional embeddings
public final FloatBuffer freq_cis_real; // (seq_len, head_size/2)
public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
public final FloatTensor wcls; // (vocab_size, dim)
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls) {
this.token_embedding_table = token_embedding_table;
this.rms_att_weight = rms_att_weight;
this.wq = wq;
this.wk = wk;
this.wv = wv;
this.wo = wo;
this.rms_ffn_weight = rms_ffn_weight;
this.w1 = w1;
this.w2 = w2;
this.w3 = w3;
this.rms_final_weight = rms_final_weight;
this.freq_cis_real = freq_cis_real;
this.freq_cis_imag = freq_cis_imag;
this.wcls = wcls;
}
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatTensor[] q, FloatTensor[] k, FloatTensor[] v, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls) {
this.token_embedding_table = token_embedding_table;
this.rms_att_weight = rms_att_weight;
this.wq = wq;
this.wk = wk;
this.wv = wv;
this.wo = wo;
this.q_bias = q;
this.k_bias = k;
this.v_bias = v;
this.rms_ffn_weight = rms_ffn_weight;
this.w1 = w1;
this.w2 = w2;
this.w3 = w3;
this.rms_final_weight = rms_final_weight;
this.freq_cis_real = freq_cis_real;
this.freq_cis_imag = freq_cis_imag;
this.wcls = wcls;
}
}
public static final class State {
// current wave of activations
public final int batchsize;
public final FloatTensor[] x; // activation at current time stamp (dim,)
public final FloatTensor[] xb; // same, but inside a residual branch (dim,)
public final FloatTensor[] xb2; // an additional buffer just for convenience (dim,)
public final FloatTensor[] hb; // buffer for hidden dimension in the ffn (hidden_dim,)
public final FloatTensor[] hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
public final FloatTensor[] q; // query (dim,)
public final FloatTensor[] k; // key (dim,)
public final FloatTensor[] v; // value (dim,)
public final FloatTensor[] att; // buffer for scores/attention values (n_heads, seq_len)
public final FloatTensor logits; // output logits
// kv cache
public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)
/** last index in previous block */
int idxPrevBlock;
public int latestToken;
State(Configuration config, int batchsize) {
this.batchsize = batchsize;
this.x = allocate(batchsize, config.dim);
this.xb = allocate(batchsize, config.dim);
this.xb2 = allocate(batchsize, config.dim);
this.hb = allocate(batchsize, config.hiddenDim);
this.hb2 = allocate(batchsize, config.hiddenDim);
this.q = allocate(batchsize, config.dim);
this.k = allocate(batchsize, config.dim);
this.v = allocate(batchsize, config.dim);
this.att = allocate(batchsize, config.numberOfHeads, config.contextLength);
idxPrevBlock = -1;
this.logits = ArrayFloatTensor.allocate(config.vocabularySize);
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
}
}
static FloatTensor[] allocate(int numTokens, int... dims) {
return IntStream.range(0, numTokens)
.mapToObj(i -> ArrayFloatTensor.allocate(dims))
.toArray(FloatTensor[]::new);
}
static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
// calculate sum of squares
float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
ss /= size;
ss += rmsNormEps;
ss = (float) (1.0 / Math.sqrt(ss));
// normalize and scale
final float finalss = ss; // for the lambda
out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
}
static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) {
// a few convenience variables
Configuration config = model.configuration();
Weights weights = model.weights();
int dim = config.dim;
int headSize = config.headSize;
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(headSize);
final int nTokens = tokens.length;
// copy the token embedding into x
Parallel.parallelFor(0, nTokens, t ->
weights.token_embedding_table.copyTo(tokens[t] * dim, state.x[t], 0, dim)
);
// forward all the layers
for (int l = 0; l < config.numberOfLayers; l++) {
// attention rmsnorm
// rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps);
final int curLayer = l;
Parallel.parallelFor(0, nTokens, t ->
rmsnorm(state.xb[t], state.x[t], weights.rms_att_weight[curLayer], dim, config.rmsNormEps)
);
// qkv matmuls for this position
weights.wq[l].matmul(nTokens, state.xb, state.q, dim, dim);
weights.wk[l].matmul(nTokens, state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(nTokens, state.xb, state.v, kvDim, dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
Parallel.parallelFor(0, nTokens, t -> {
for (int i = 0; i < dim; i += 2) {
int head_dim = i % headSize;
float fcr = weights.freq_cis_real.get((position + t) * (headSize / 2) + (head_dim / 2));
float fci = weights.freq_cis_imag.get((position + t) * (headSize / 2) + (head_dim / 2));
int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int vi = 0; vi < rotn; vi++) {
FloatTensor vec = vi == 0 ? state.q[t] : state.k[t]; // the vector to rotate (query or key)
float v0 = vec.getFloat(i);
float v1 = vec.getFloat(i + 1);
vec.setFloat(i, v0 * fcr - v1 * fci);
vec.setFloat(i + 1, v0 * fci + v1 * fcr);
}
}
});
// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
Parallel.parallelFor(0, nTokens, t -> {
state.k[t].copyTo(0, state.keyCache[curLayer], (position + t) * kvDim, kvDim);
state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim);
});
// If the logits are not required, the attention and FFN of the last layer can be skipped entirely.
if (!computeLogits && curLayer == config.numberOfLayers - 1) {
state.idxPrevBlock = nTokens - 1;
return null;
}
// multihead attention. iterate over all heads
Parallel.parallelForLong(0, (long) nTokens * (long) config.numberOfHeads, ht -> {
int token = (int) (ht / config.numberOfHeads);
int h = (int) (ht % config.numberOfHeads);
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;
// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength;
// iterate over all timesteps, including the current one
for (int t = 0; t <= position + token; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q[token].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att[token].setFloat(attOffset + t, score);
}
// softmax the scores to get attention weights, from 0..position inclusively
state.att[token].softmaxInPlace(attOffset, position + token + 1);
// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb[token].fillInPlace(xbOffset, headSize, 0f);
for (int t = 0; t <= position + token; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att[token].getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb[token].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});
// final matmul to get the output of the attention
weights.wo[l].matmul(nTokens, state.xb, state.xb2, dim, dim);
// residual connection back into x
Parallel.parallelFor(0, nTokens, t -> {
state.x[t].addInPlace(state.xb2[t]);
});
// ffn rmsnorm
Parallel.parallelFor(0, nTokens, t -> {
rmsnorm(state.xb[t], state.x[t], weights.rms_ffn_weight[curLayer], dim, config.rmsNormEps);
});
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(nTokens, state.xb, state.hb, config.hiddenDim, dim);
weights.w3[l].matmul(nTokens, state.xb, state.hb2, config.hiddenDim, dim);
// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
Parallel.parallelFor(0, nTokens, t -> {
state.hb[t].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
});
// elementwise multiply with w3(x)
Parallel.parallelFor(0, nTokens, t -> {
state.hb[t].multiplyInPlace(state.hb2[t]);
});
// final matmul to get the output of the ffn
weights.w2[l].matmul(nTokens, state.hb, state.xb, dim, config.hiddenDim);
// residual connection
Parallel.parallelFor(0, nTokens, t -> {
state.x[t].addInPlace(state.xb[t]);
});
}
// final rmsnorm
Parallel.parallelFor(0, nTokens, t -> {
rmsnorm(state.x[t], state.x[t], weights.rms_final_weight, dim, config.rmsNormEps);
});
if(false) {
SuperBit sb = null;
try (Timer timer = Timer.log("SuperBits:"+state.x[nTokens-1].size())) {
//sb = new SuperBit(state.x[nTokens-1].size());
sb = new SuperBit(100);
}
try (Timer timer = Timer.log("Signature")) {
sb.signature(state.x[nTokens-1]);
}
}
// classifier into logits
weights.wcls.matmul(state.x[nTokens - 1], state.logits, config.vocabularySize, dim);
state.idxPrevBlock = nTokens - 1;
return state.logits;
}
static FloatTensor forwardQwen(Llama model, State state, int token, int position) {
// a few convenience variables
Llama.Configuration config = model.configuration();
Llama.Weights weights = model.weights();
int dim = config.dim;
int headSize = config.headSize;
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(headSize);
// copy the token embedding into x
weights.token_embedding_table.copyTo(token * dim, state.x[0], 0, dim);
// forward all the layers
for (int l = 0; l < config.numberOfLayers; l++) {
// attention rmsnorm
rmsnorm(state.xb[0], state.x[0], weights.rms_att_weight[l], dim, config.rmsNormEps);
// qkv matmuls for this position
weights.wq[l].matmul(state.xb[0], state.q[0], dim, dim);
if (weights.q_bias != null && weights.q_bias[l] != null) {
//state.q[0].addInPlace(weights.q_bias[l]);
System.out.println("state:"+state.q[0].size());
state.q[0].verify();
System.out.println("weights:"+weights.q_bias[l].size());
weights.q_bias[l].verify();
state.q[0].addInPlace(weights.q_bias[l]);
}
weights.wk[l].matmul(state.xb[0], state.k[0], kvDim, dim);
if (weights.k_bias != null && weights.k_bias[l] != null) {
state.k[0].addInPlace(weights.k_bias[l]);
}
weights.wv[l].matmul(state.xb[0], state.v[0], kvDim, dim);
if (weights.v_bias != null && weights.v_bias[l] != null) {
state.v[0].addInPlace(weights.v_bias[l]);
}
// RoPE relative positional encoding: complex-valued rotate q and k in each head
// GPT-NeoX style RoPE, real/imaginary components are stored with a headSize/2 offset per head, instead of consecutive.
for (int h = 0; h < config.numberOfHeads; ++h) {
int rotn = h < config.numberOfKeyValueHeads ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
int poffset = h * headSize;
for (int i0 = 0; i0 < headSize; i0 += 2) {
int ic = i0 / 2;
float fcr = weights.freq_cis_real.get(position * (headSize / 2) + ic);
float fci = weights.freq_cis_imag.get(position * (headSize / 2) + ic);
for (int v = 0; v < rotn; v++) {
FloatTensor vec = v == 0 ? state.q[0] : state.k[0]; // the vector to rotate (query or key)
float v0 = vec.getFloat(poffset + ic);
float v1 = vec.getFloat(poffset + ic + headSize/2);
vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr);
}
}
}
// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
state.k[0].copyTo(0, state.keyCache[l], position * kvDim, kvDim);
state.v[0].copyTo(0, state.valueCache[l], position * kvDim, kvDim);
int curLayer = l;
// multihead attention. iterate over all heads
Parallel.parallelFor(0, config.numberOfHeads, h -> {
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;
// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength;
// iterate over all timesteps, including the current one
for (int t = 0; t <= position; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q[0].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att[0].setFloat(attOffset + t, score);
}
// softmax the scores to get attention weights, from 0..position inclusively
state.att[0].softmaxInPlace(attOffset, position + 1);
// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb[0].fillInPlace(xbOffset, headSize, 0f);
for (int t = 0; t <= position; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att[0].getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb[0].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});
// final matmul to get the output of the attention
weights.wo[l].matmul(state.xb[0], state.xb2[0], dim, dim);
// residual connection back into x
state.x[0].addInPlace(state.xb2[0]);
// ffn rmsnorm
rmsnorm(state.xb[0], state.x[0], weights.rms_ffn_weight[l], dim, config.rmsNormEps);
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(state.xb[0], state.hb[0], config.hiddenDim, dim);
weights.w3[l].matmul(state.xb[0], state.hb2[0], config.hiddenDim, dim);
// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
state.hb[0].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
// elementwise multiply with w3(x)
state.hb[0].multiplyInPlace(state.hb2[0]);
// final matmul to get the output of the ffn
weights.w2[l].matmul(state.hb[0], state.xb[0], dim, config.hiddenDim);
// residual connection
state.x[0].addInPlace(state.xb[0]);
}
// final rmsnorm
rmsnorm(state.x[0], state.x[0], weights.rms_final_weight, dim, config.rmsNormEps);
// classifier into logits
weights.wcls.matmul(state.x[0], state.logits, config.vocabularySize, dim);
return state.logits;
}
/**
* LLM generation entry point, ingests prompt tokens and generates new tokens.
*
* <p>
* All prompt tokens are ingested first, then inference starts, until a stop token is found.
* The returned tokens only include generated/inferred tokens.
*
* @param model model to run inference (including weights, configuration, tokenizer ...)
* @param state state of the model e.g. key/value caches ... this is mutated by this call
* @param startPosition start prompt ingestion + inference at this position in the context e.g. useful if state was kept across calls (chained generation). 0 implies run with no previous context.
* @param promptTokens prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left in the context
* @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion
* @param maxTokens maximum number of tokens; falls back to the {@link Configuration#contextLength context length}
* if this value is negative or greater than the {@link Configuration#contextLength context length}
* @param sampler {@link Sampler strategy} used to select tokens
* @param echo debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr}
* @param onTokenGenerated callback, if non-null, it's called every time a token is inferred i.e. it's not called when ingesting prompt tokens
* @return list of generated/inferred tokens, including the stop token, if any i.e. does not include any token from the prompt
*/
public static List<Integer> generateTokens(Llama model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
long startNanos = System.nanoTime();
long startGen = 0;
if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
maxTokens = model.configuration().contextLength;
}
List<Integer> generatedTokens = new ArrayList<>(maxTokens);
int token = state.latestToken; // BOS?
int nextToken;
int promptIndex = 0;
for (int position = startPosition; position < maxTokens; ++position) {
if (promptIndex < promptTokens.size()) {
final int nTokens = Math.min(maxTokens - position, Math.min(promptTokens.size() - promptIndex, state.batchsize));
final int[] tokens = new int[nTokens];
for (int i = 0; i < nTokens; i++) {
tokens[i] = promptTokens.get(promptIndex + i);
if (echo) {
// log prompt token (different color?)
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(tokens[i]))));
}
}
if (echo) {
System.out.format("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n", position, promptIndex, promptTokens.size(), Arrays.toString(tokens));
}
// Only compute logits on the very last batch.
boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
forward(model, state, tokens, position, computeLogits);
position += nTokens - 1; // -1 -> incremented later in the for loop
promptIndex += nTokens;
if (promptIndex < promptTokens.size()) {
continue;
}
startGen = System.nanoTime();
} else {
forward(model, state, new int[]{token}, position, true);
}
nextToken = sampler.sampleToken(state.logits);
if (echo) {
// log inferred token
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
generatedTokens.add(nextToken);
if (onTokenGenerated != null) {
onTokenGenerated.accept(nextToken);
}
if (stopTokens.contains(nextToken)) {
break;
}
state.latestToken = token = nextToken;
}
long elapsedNanos = System.nanoTime() - startNanos;
long promptNanos = startGen - startNanos;
long genNanos = elapsedNanos - startGen + startNanos;
System.err.printf("%ncontext: %d/%d prompt: %.2f tokens/s (%d) generation: %.2f tokens/s (%d)%n",
startPosition + promptIndex + generatedTokens.size(), model.configuration().contextLength,
promptTokens.size() / (promptNanos / 1_000_000_000.0), promptTokens.size(),
generatedTokens.size() / (genNanos / 1_000_000_000.0), generatedTokens.size());
return generatedTokens;
}
/**
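* Qwen-specific generation entry point: same contract as {@link #generateTokens}, but it calls
* {@link #forwardQwen} once per position, prompt tokens included, instead of ingesting the prompt in batches.
*
* @param model model to run inference (including weights, configuration, tokenizer ...)
* @param state state of the model e.g. key/value caches ... this is mutated by this call
* @param startPosition start prompt ingestion + inference at this position in the context, 0 implies run with no previous context
* @param promptTokens prompt tokens to ingest before sampling starts
* @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion
* @param maxTokens maximum number of tokens; falls back to the context length if this value is negative or greater than the context length
* @param sampler {@link Sampler strategy} used to select tokens
* @param echo debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr}
* @param onTokenGenerated callback, if non-null, it's called every time a token is inferred
* @return list of generated/inferred tokens, including the stop token, if any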
*/
public static List<Integer> generateTokensQwen(Llama model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
long startNanos = System.nanoTime();
if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
maxTokens = model.configuration().contextLength;
}
List<Integer> generatedTokens = new ArrayList<>(maxTokens);
int token = state.latestToken; // BOS?
int nextToken;
int promptIndex = 0;
for (int position = startPosition; position < maxTokens; ++position) {
forwardQwen(model, state, token, position);
if (promptIndex < promptTokens.size()) {
// Force-pick token from prompt.
nextToken = promptTokens.get(promptIndex++);
if (echo) {
// log prompt token (different color?)
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
} else {
nextToken = sampler.sampleToken(state.logits);
if (echo) {
// log inferred token
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
generatedTokens.add(nextToken);
if (onTokenGenerated != null) {
onTokenGenerated.accept(nextToken);
}
if (stopTokens.contains(nextToken)) {
break;
}
}
state.latestToken = token = nextToken;
}
long elapsedNanos = System.nanoTime() - startNanos;
int totalTokens = promptIndex + generatedTokens.size();
System.err.printf("%n%.2f tokens/s (%d)%n", totalTokens / (elapsedNanos / 1_000_000_000.0), totalTokens);
return generatedTokens;
}
}
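// Illustrative sketch, not part of the original reproducer: the GPT-NeoX style rotation that
// forwardQwen applies, written against a plain float[] so the index pairing is easy to follow.
// RopeNeoxSketch, rotateHead and the cos/sin parameters are hypothetical names; cos and sin
// stand in for one position's slice of freq_cis_real and freq_cis_imag. Element ic of a head
// is rotated together with element ic + headSize / 2 of the same head, not with its neighbour.
final class RopeNeoxSketch {
    static void rotateHead(float[] vec, int headOffset, int headSize, float[] cos, float[] sin) {
        int half = headSize / 2;
        for (int ic = 0; ic < half; ic++) {
            float v0 = vec[headOffset + ic];        // "real" component
            float v1 = vec[headOffset + ic + half]; // "imaginary" component
            vec[headOffset + ic] = v0 * cos[ic] - v1 * sin[ic];
            vec[headOffset + ic + half] = v0 * sin[ic] + v1 * cos[ic];
        }
    }
}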
interface TokenizerInterface {
public Map<String, Integer> getSpecialTokens();
public boolean isSpecialToken(int tokenIndex);
public String decode(List<Integer> tokens);
public List<Integer> encodeAsList(String text);
public int getTokenType(int tokenIndex);
}
/**
* Byte Pair Encoding tokenizer.
* <p>
* Based on <a href="https://github.com/karpathy/minbpe">minbpe</a>, algorithmically follows along the
* <a href="https://github.com/openai/gpt-2/blob/master/src/encoder.py">GPT 2 tokenizer</a>
*/
class Tokenizer implements TokenizerInterface {
private final Pattern compiledPattern;
private final Vocabulary vocabulary;
private final Map<Pair<Integer, Integer>, Integer> merges;
private final Map<String, Integer> specialTokens;
private int[] tokenTypes; // qwen2
public String regexPattern() {
if (compiledPattern == null) {
return null;
}
return compiledPattern.pattern();
}
@Override
public Map<String, Integer> getSpecialTokens() {
return specialTokens;
}
@Override
public boolean isSpecialToken(int tokenIndex) {
return specialTokens.containsValue(tokenIndex);
}
@Override
public int getTokenType(int tokenIndex) {
return tokenTypes[tokenIndex];
}
public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern, Map<String, Integer> specialTokens) {
this.vocabulary = vocabulary;
this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
this.specialTokens = new HashMap<>(specialTokens);
this.merges = new HashMap<>();
for (Pair<Integer, Integer> pair : merges) {
int firstIndex = pair.first();
int secondIndex = pair.second();
int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
this.merges.put(pair, mergeIndex);
}
}
public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern, Map<String, Integer> specialTokens, int[] tokenTypes) {
this(vocabulary, merges, regexPattern, specialTokens);
this.tokenTypes = tokenTypes;
}
private int[] encodeImpl(String text) {
return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
}
/**
* Unlike {@link #encodeOrdinary(String)}, this function handles special tokens.
* {@code allowedSpecial} is the set of special tokens that may be matched in the text;
* it must be a subset of {@link #getSpecialTokens()}.
* An empty set falls back to {@link #encodeOrdinary(String)}.
*/
List<Integer> encode(String text, Set<String> allowedSpecial) {
// decode the user desire w.r.t. handling of special tokens
Set<String> special = allowedSpecial;
assert getSpecialTokens().keySet().containsAll(special);
if (special.isEmpty()) {
// shortcut: if no special tokens, just use the ordinary encoding
return encodeOrdinary(text);
}
// otherwise, we have to be careful with potential special tokens in text
// we handle special tokens by splitting the text
// based on the occurrence of any exact match with any of the special tokens
// note: unlike Python's re.split, String.split drops the delimiters even when the
// pattern is a capturing group, so we walk the matches and keep each special token
String specialPattern = special
.stream()
.map(Pattern::quote)
.collect(Collectors.joining("|", "(", ")"));
Matcher specialMatcher = Pattern.compile(specialPattern).matcher(text);
// all chunks of text are encoded separately, then results are joined
List<Integer> ids = new ArrayList<>();
int last = 0;
while (specialMatcher.find()) {
// this is an ordinary sequence preceding the special token, encode it normally
ids.addAll(encodeOrdinary(text.substring(last, specialMatcher.start())));
// this is a special token, encode it separately as a special case
ids.add(getSpecialTokens().get(specialMatcher.group()));
last = specialMatcher.end();
}
// trailing ordinary sequence after the last special token
ids.addAll(encodeOrdinary(text.substring(last)));
return ids;
}
private static List<String> findAll(Pattern pattern, String text) {
List<String> allMatches = new ArrayList<>();
Matcher matcher = pattern.matcher(text);
while (matcher.find()) {
allMatches.add(matcher.group());
}
return allMatches;
}
/**
* Encoding that ignores any special tokens.
*/
public List<Integer> encodeOrdinary(String text) {
// split text into chunks of text by categories defined in regex pattern
List<String> textChunks = findAll(compiledPattern, text);
// all chunks of text are encoded separately, then results are joined
List<Integer> ids = new ArrayList<>();
for (String chunk : textChunks) {
List<Integer> chunkIds = encodeChunk(chunk);
ids.addAll(chunkIds);
}
return ids;
}
private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
for (int i = 0; i + 1 < ids.size(); i++) {
Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
map.put(key, map.getOrDefault(key, 0) + 1);
}
return map;
}
private List<Integer> encodeChunk(String chunk) {
// return the token ids
// let's begin. first, map each character of the chunk to the id of its single-character token
List<Integer> ids = new ArrayList<>();
for (int b : chunk.toCharArray()) {
int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
ids.add(tokenIndex);
}
while (ids.size() >= 2) {
// find the pair with the lowest merge index
Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
Pair<Integer, Integer> pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
// subtle: if there are no more merges available, the key will
// result in an inf for every single pair, and the min will be
// just the first pair in the list, arbitrarily
// we can detect this terminating case by a membership check
if (!this.merges.containsKey(pair)) {
break; // nothing else can be merged anymore
}
// otherwise let's merge the best pair (lowest merge index)
int idx = this.merges.get(pair);
ids = merge(ids, pair, idx);
}
return ids;
}
private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
List<Integer> newids = new ArrayList<>();
int i = 0;
while (i < ids.size()) {
// if not at the very last position AND the pair matches, replace it
if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
newids.add(idx);
i += 2;
} else {
newids.add(ids.get(i));
i += 1;
}
}
return newids;
}
public String decodeImpl(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
String tokenString = vocabulary.get(token);
sb.append(tokenString);
}
return sb.toString();
}
/**
* Returns a mapping from utf-8 byte values to unicode code points.
* The reversible bpe codes work on unicode strings.
* This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
* When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
* This is a significant percentage of your normal, say, 32K bpe vocab.
* To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
* The mapping also avoids whitespace/control characters that the bpe code barfs on.
*/
private static Map<Integer, Integer> bytesToUnicode() {
List<Integer> bs = new ArrayList<>();
IntStream.rangeClosed('!', '~').forEach(bs::add);
IntStream.rangeClosed('¡', '¬').forEach(bs::add);
IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
List<Integer> cs = new ArrayList<>(bs);
int n = 0;
for (int b = 0; b < 256; ++b) {
if (!bs.contains(b)) {
bs.add(b);
cs.add(256 + n);
n += 1;
}
}
// return dict(zip(bs, cs))
return IntStream.range(0, bs.size())
.boxed()
.collect(Collectors.toMap(bs::get, cs::get));
}
static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
public int[] encode(String text) {
StringBuilder sb = new StringBuilder();
byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
for (byte b : bytes) {
sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
}
return encodeImpl(sb.toString());
}
public static String replaceControlCharacters(int[] codePoints) {
// we don't want to print control characters
// which distort the output (e.g. \n or much worse)
// https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
// http://www.unicode.org/reports/tr44/#GC_Values_Table
StringBuilder chars = new StringBuilder();
for (int cp : codePoints) {
if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
} else {
chars.appendCodePoint(cp); // this character is ok
}
}
return chars.toString();
}
public static String replaceControlCharacters(String str) {
return replaceControlCharacters(str.codePoints().toArray());
}
@Override
public List<Integer> encodeAsList(String text) {
return Arrays.stream(encode(text)).boxed().toList();
}
@Override
public String decode(List<Integer> tokens) {
String decoded = decodeImpl(tokens);
int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
byte[] rawBytes = new byte[decodedBytesAsInts.length];
for (int i = 0; i < decodedBytesAsInts.length; i++) {
rawBytes[i] = (byte) decodedBytesAsInts[i];
}
return new String(rawBytes, StandardCharsets.UTF_8);
}
}
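// Illustrative sketch, not part of the original reproducer: the byte-to-code-point rule that
// bytesToUnicode builds as a lookup table, written as plain arithmetic. ByteLevelSketch and
// its method names are hypothetical. Bytes in '!'..'~', 0xA1..0xAC and 0xAE..0xFF map to
// themselves; every other byte maps to 256 + n, where n counts the remapped bytes below it.
final class ByteLevelSketch {
    // True for the byte values that are kept as-is.
    static boolean keptAsIs(int b) {
        return (b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF);
    }

    // Maps one byte value (0..255) to the code point used for it in the vocabulary.
    static int toCodePoint(int b) {
        if (keptAsIs(b)) {
            return b;
        }
        int n = 0;
        for (int i = 0; i < b; i++) {
            if (!keptAsIs(i)) {
                n++;
            }
        }
        return 256 + n;
    }

    // Rewrites the UTF-8 bytes of text as the visible string that the BPE merges operate on.
    static String toVisible(String text) {
        StringBuilder sb = new StringBuilder();
        for (byte b : text.getBytes(java.nio.charset.StandardCharsets.UTF_8)) {
            sb.appendCodePoint(toCodePoint(Byte.toUnsignedInt(b)));
        }
        return sb.toString();
    }
}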
/**
* Tokenizer for Mistral models.
* Llama models report metadata.get("tokenizer.ggml.model") = gpt2,
* whereas Mistral reports metadata.get("tokenizer.ggml.model") = llama.
*/
class MistralTokenizer implements TokenizerInterface {
private final Pattern compiledPattern;
private final Vocabulary vocabulary;
private final Map<String, Integer> specialTokens;
private final int[] tokenType;
private final int byte0;
public String regexPattern() {
if (compiledPattern == null) {
return null;
}
return compiledPattern.pattern();
}
@Override
public Map<String, Integer> getSpecialTokens() {
return specialTokens;
}
@Override
public boolean isSpecialToken(int tokenIndex) {
return getTokenType(tokenIndex) != 1;
}
@Override
public int getTokenType(int tokenIndex) {
return tokenType[tokenIndex];
}
public MistralTokenizer(Vocabulary vocabulary, String regexPattern, Map<String, Integer> specialTokens, int[] tokenType) {
this.vocabulary = vocabulary;
this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
this.specialTokens = new HashMap<>(specialTokens);
this.tokenType = tokenType;
this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow();
}
List<Integer> encode(String text) {
return encodeImpl(text.replace(' ', '▁'));
}
private List<Integer> encodeImpl(String text) {
List<Integer> tokens = new ArrayList<>();
// first encode every individual codepoint in the input string
for (int i = 0, cpi; i < text.length(); i += Character.charCount(cpi)) {
cpi = text.codePointAt(i);
String singleCodepoint = Character.toString(cpi);
int id = vocabulary.getIndex(singleCodepoint).orElse(-1);
if (id != -1) {
// we found this codepoint in vocab, add it as a token
tokens.add(id);
} else {
// byte_fallback encoding: just encode each byte as a token
// +byte0 here to skip all the control and special tokens e.g. <unk>, <s>, </s>
// so the individual bytes only start at token <0x00>
for (byte b : singleCodepoint.getBytes(StandardCharsets.UTF_8)) {
tokens.add(Byte.toUnsignedInt(b) + byte0);
}
}
}
// merge the best consecutive pair each iteration, according to the scores in vocab_scores
while (true) {
float best_score = -1e10f;
int best_id = -1;
int best_idx = -1;
for (int i = 0; i < tokens.size() - 1; ++i) {
// check if we can merge the pair (tokens[i], tokens[i+1])
String str_buffer = vocabulary.get(tokens.get(i)) + vocabulary.get(tokens.get(i + 1));
int id = vocabulary.getIndex(str_buffer).orElse(-1);
if (id != -1 && vocabulary.getScore(id) > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = vocabulary.getScore(id);
best_id = id;
best_idx = i;
}
}
if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
}
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens.set(best_idx, best_id);
tokens.remove(best_idx + 1);
}
return tokens;
}
@Override
public String decode(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
String tokenString = vocabulary.get(token);
if (isSpecialToken(token)) {
// some tokens designate raw bytes e.g. '<0x10>'
String prefix = "<0x";
String suffix = ">";
if (tokenString.length() == 6 && tokenString.startsWith(prefix) && tokenString.endsWith(suffix)) {
String code = tokenString.substring(prefix.length(), tokenString.length() - suffix.length());
int cp = Integer.parseInt(code, 16);
tokenString = Character.toString(cp);
}
} else {
tokenString = tokenString.replace('▁', ' ');
}
sb.append(tokenString);
}
return sb.toString();
}
public static String replaceControlCharacters(int[] codePoints) {
// we don't want to print control characters
// which distort the output (e.g. \n or much worse)
// https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
// http://www.unicode.org/reports/tr44/#GC_Values_Table
StringBuilder chars = new StringBuilder();
for (int cp : codePoints) {
if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
} else {
chars.appendCodePoint(cp); // this character is ok
}
}
return chars.toString();
}
public static String replaceControlCharacters(String str) {
return replaceControlCharacters(str.codePoints().toArray());
}
@Override
public List<Integer> encodeAsList(String text) {
return encode(text);
}
}
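// Illustrative sketch, not part of the original reproducer: the byte_fallback step from
// MistralTokenizer.encodeImpl in isolation. ByteFallbackSketch and fallbackIds are
// hypothetical names, and byte0 is assumed to be the vocabulary index of "<0x00>". A code
// point missing from the vocabulary becomes one token per UTF-8 byte, each offset by byte0.
final class ByteFallbackSketch {
    static java.util.List<Integer> fallbackIds(int codePoint, int byte0) {
        java.util.List<Integer> ids = new java.util.ArrayList<>();
        byte[] utf8 = Character.toString(codePoint).getBytes(java.nio.charset.StandardCharsets.UTF_8);
        for (byte b : utf8) {
            ids.add(Byte.toUnsignedInt(b) + byte0); // byte token ids start at "<0x00>"
        }
        return ids;
    }
}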
final class Parallel {
public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) {
if (startInclusive == 0 && endExclusive == 1) {
action.accept(0);
return;
}
IntStream.range(startInclusive, endExclusive).parallel().forEach(action);
}
public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) {
if (startInclusive == 0 && endExclusive == 1) {
action.accept(0);
return;
}
LongStream.range(startInclusive, endExclusive).parallel().forEach(action);
}
}
record Pair<First, Second>(First first, Second second) {
}
record GGMLTensorEntry(MemorySegment mappedFile, String name, GGMLType ggmlType, int[] shape,
MemorySegment memorySegment) {
}
enum GGMLType {
F32(Float.BYTES),
F16(GGMLType.FLOAT16_BYTES),
Q4_0(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
Q4_1(2 * GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
UNSUPPORTED_Q4_2(Integer.MAX_VALUE), // support has been removed
UNSUPPORTED_Q4_3(Integer.MAX_VALUE), // support has been removed
Q5_0(Integer.MAX_VALUE),
Q5_1(Integer.MAX_VALUE),
Q8_0(GGMLType.FLOAT16_BYTES + 32 * Byte.BYTES, 32),
Q8_1(32 * Byte.BYTES + 2 * Float.BYTES, 32),
// k-quantizations
Q2_K(Integer.MAX_VALUE),
Q3_K(Integer.MAX_VALUE),
Q4_K(2 * GGMLType.FLOAT16_BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 2, GGMLType.QK_K),
Q5_K(2 * GGMLType.FLOAT16_BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 8 + GGMLType.QK_K / 2, GGMLType.QK_K),
Q6_K(GGMLType.QK_K / 2 + GGMLType.QK_K / 4 + GGMLType.QK_K / 16 + GGMLType.FLOAT16_BYTES, GGMLType.QK_K),
Q8_K(Integer.MAX_VALUE),
IQ2_XXS(Integer.MAX_VALUE),
IQ2_XS(Integer.MAX_VALUE),
IQ3_XXS(Integer.MAX_VALUE),
IQ1_S(Integer.MAX_VALUE),
IQ4_NL(Integer.MAX_VALUE),
IQ3_S(Integer.MAX_VALUE),
IQ2_S(Integer.MAX_VALUE),
IQ4_XS(Integer.MAX_VALUE),
I8(Byte.BYTES),
I16(Short.BYTES),
I32(Integer.BYTES),
I64(Long.BYTES),
F64(Double.BYTES),
IQ1_M(Integer.MAX_VALUE),
BF16(GGMLType.BFLOAT16_BYTES),
Q4_0_4_4(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
Q4_0_4_8(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
Q4_0_8_8(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
TQ1_0(Integer.MAX_VALUE),
TQ2_0(Integer.MAX_VALUE);
public static final int BFLOAT16_BYTES = 2;
public static final int FLOAT16_BYTES = 2;
private static final GGMLType[] VALUES = values();
private final int typeSize;
private final int blockSize;
public int getTypeSize() {
return typeSize;
}
public int getBlockSize() {
return blockSize;
}
public static GGMLType fromId(int id) {
return VALUES[id];
}
GGMLType(int typeSize) {
this(typeSize, 1);
}
public long byteSizeFor(int numberOfElements) {
long t = numberOfElements * (long) getTypeSize();
assert t % getBlockSize() == 0;
return Math.toIntExact(t / getBlockSize());
}
public static final int QK_K = 256; // or 64?
GGMLType(int typeSize, int blockSize) {
assert blockSize > 0;
assert typeSize > 0;
assert isPowerOf2(blockSize);
this.typeSize = typeSize;
this.blockSize = blockSize;
}
private static boolean isPowerOf2(int n) {
return n > 0 && (n & (n - 1)) == 0;
}
}
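// Illustrative sketch, not part of the original reproducer: the block arithmetic behind
// GGMLType.byteSizeFor, with the result kept as a long instead of being narrowed through
// Math.toIntExact. GgmlSizeSketch is a hypothetical name. For Q8_0 as declared above, each
// block of 32 elements takes FLOAT16_BYTES + 32 = 34 bytes, so 3584 elements take
// 3584 / 32 * 34 bytes.
final class GgmlSizeSketch {
    static long byteSizeFor(long numberOfElements, int typeSize, int blockSize) {
        long t = Math.multiplyExact(numberOfElements, (long) typeSize);
        if (t % blockSize != 0) {
            throw new IllegalArgumentException("element count is not a whole number of blocks");
        }
        return t / blockSize;
    }
}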
/**
* Over-simplified, shapeless, float tensor.
* <p>
* Not a strict tensor, but rather just a sequence of floats, not required to be backed by memory
* e.g. can represent a sequence of quantized floats.
*/
abstract class FloatTensor implements Externalizable, Comparable {
static final int VECTOR_BIT_SIZE = Integer.getInteger("llama.VectorBitSize", VectorShape.preferredShape().vectorBitSize());
static final boolean USE_VECTOR_API = VECTOR_BIT_SIZE != 0;
static short readShort(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_SHORT, offset);
//return UNSAFE.getShort(memorySegment.address() + offset);
}
static int readInt(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_INT, offset);
//return UNSAFE.getInt(memorySegment.address() + offset);
}
static float readFloat(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_FLOAT, offset);
//return UNSAFE.getFloat(memorySegment.address() + offset);
}
static byte readByte(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_BYTE, offset);
//return UNSAFE.getByte(memorySegment.address() + offset);
}
// Preferred vector size for the fast multiplication routines.
// (Apple Silicon) NEON only supports up-to 128bit vectors.
static final VectorSpecies<Float> F_SPECIES;
static final VectorSpecies<Integer> I_SPECIES;
static final VectorSpecies<Short> S_SPECIES_HALF;
static {
if (USE_VECTOR_API) {
F_SPECIES = VectorShape.forBitSize(VECTOR_BIT_SIZE).withLanes(float.class);
I_SPECIES = F_SPECIES.withLanes(int.class);
S_SPECIES_HALF = VectorShape.forBitSize(F_SPECIES.vectorBitSize() / 2).withLanes(short.class);
assert F_SPECIES.length() == S_SPECIES_HALF.length();
} else {
F_SPECIES = null;
I_SPECIES = null;
S_SPECIES_HALF = null;
}
}
abstract int size();
abstract float getFloat(int index);
abstract void setFloat(int index, float value);
abstract FloatVector getFloatVector(VectorSpecies<Float> species, int offset);
abstract GGMLType type();
public static int numberOfElements(int... dimensions) {
assert Arrays.stream(dimensions).allMatch(i -> i > 0);
return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow();
}
static float scalarDot(FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) {
float result = 0f;
for (int j = 0; j < size; j++) {
result += thiz.getFloat(thisOffset + j) * that.getFloat(thatOffset + j);
}
return result;
}
float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
return scalarDot(this, thisOffset, that, thatOffset, size);
}
void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) {
Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1)));
}
void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) {
if (that.length != out.length) {
throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length));
}
// Widen before multiplying so dim0 * context cannot overflow int.
Parallel.parallelForLong(0, (long) dim0 * context, ti -> {
int idxArr = (int) (ti / dim0);
int i = (int) (ti % dim0);
out[idxArr].setFloat(i, dot(i * dim1, that[idxArr], 0, dim1));
});
}
@FunctionalInterface
interface AggregateFunction {
float apply(float acc, float value);
}
float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) {
float result = seed;
for (int i = 0; i < size; ++i) {
result = reduce.apply(result, getFloat(thisOffset + i));
}
return result;
}
float sum(int thisOffset, int size) {
return reduce(thisOffset, size, 0f, Float::sum);
}
float max(int thisOffset, int size) {
return reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max);
}
void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) {
that.mapWithIndexInPlace(thatOffset, size, (value, index) -> this.getFloat(index - thatOffset + thisOffset));
}
int argmax(int thisOffset, int size) {
assert size > 0;
int maxIndex = thisOffset;
float maxValue = this.getFloat(maxIndex);
int endIndex = thisOffset + size;
for (int i = thisOffset; i < endIndex; ++i) {
float f = this.getFloat(i);
if (f > maxValue) {
maxValue = f;
maxIndex = i;
}
}
return maxIndex;
}
int argmax() {
return argmax(0, size());
}
@FunctionalInterface
interface MapFunction {
float apply(float value);
}
@FunctionalInterface
interface MapWithIndexFunction {
float apply(float value, int index);
}
FloatTensor mapInPlace(int thisOffset, int size, MapFunction mapFunction) {
int endIndex = thisOffset + size;
for (int i = thisOffset; i < endIndex; ++i) {
setFloat(i, mapFunction.apply(getFloat(i)));
}
return this;
}
FloatTensor mapInPlace(MapFunction mapFunction) {
return mapInPlace(0, size(), mapFunction);
}
FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) {
int endOffset = thisOffset + size;
for (int i = thisOffset; i < endOffset; ++i) {
System.out.println("setFloat:"+i+" of size:"+size);
setFloat(i, mapWithIndexFunction.apply(getFloat(i), i));
}
return this;
}
FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
return mapWithIndexInPlace(thisOffset, size, (value, index) -> value + that.getFloat(index - thisOffset + thatOffset));
}
FloatTensor addInPlace(FloatTensor that) {
return addInPlace(0, that, 0, size());
}
FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
return mapWithIndexInPlace(thisOffset, size, (value, index) -> value * that.getFloat(index - thisOffset + thatOffset));
}
FloatTensor multiplyInPlace(FloatTensor that) {
return multiplyInPlace(0, that, 0, size());
}
FloatTensor divideInPlace(int thisOffset, int size, float value) {
return mapInPlace(thisOffset, size, f -> f / value);
}
FloatTensor fillInPlace(int thisOffset, int size, float value) {
return mapInPlace(thisOffset, size, unused -> value);
}
FloatTensor softmaxInPlace(int thisOffset, int size) {
// find max value (for numerical stability)
float maxVal = max(thisOffset, size);
// exp and sum
mapInPlace(thisOffset, size, f -> (float) Math.exp(f - maxVal));
float sum = sum(thisOffset, size);
// normalize
return divideInPlace(thisOffset, size, sum);
}
FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) {
// this[thisOffset ... thisOffset + size) = a * that[thatOffset ... thatOffset + size) + this[thisOffset ... thisOffset + size)
for (int i = 0; i < size; ++i) {
setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i));
}
return this;
}
float cosineSimilarity(FloatTensor a, FloatTensor b) {
float dotProduct = a.dot(0, b, 0, a.size());
DoubleAdder aNormAdder = new DoubleAdder();
DoubleAdder bNormAdder = new DoubleAdder();
Parallel.parallelFor(0, a.size(), t -> {
aNormAdder.add(a.getFloat(t) * a.getFloat(t));
bNormAdder.add(b.getFloat(t) * b.getFloat(t));
});
float aNorm = (float) Math.sqrt(aNormAdder.sum());
float bNorm = (float) Math.sqrt(bNormAdder.sum());
return (dotProduct / (aNorm * bNorm));
}
public void verify() {
System.out.println("size:"+size());
System.out.println("Verified via String of length:"+toString().length());
}
public String toString() {
// Always emits the closing bracket, including for an empty tensor.
StringBuilder sb = new StringBuilder("[");
for(int i = 0; i < size(); i++) {
if(i > 0)
sb.append(",");
sb.append(getFloat(i));
}
return sb.append("]").toString();
}
}
/**
* {@link FloatTensor} quantized in the {@link GGMLType#Q4_0} format.
* <p>
* This tensor is read-only: {@link #setFloat(int, float)} and
* {@link #getFloatVector(VectorSpecies, int)} are unsupported. {@link #dot(int, FloatTensor, int, int)}
* has a vectorized implementation for the case where the second argument is an {@link ArrayFloatTensor}.
*/
final class Q4_0FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public Q4_0FloatTensor() {}
public Q4_0FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.Q4_0;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
int blockIndex = index / GGMLType.Q4_0.getBlockSize();
int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize();
float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
byte quant;
int modIndex = index % GGMLType.Q4_0.getBlockSize();
if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) {
quant = (byte) (readByte(memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + modIndex) & 0x0F);
} else {
quant = (byte) ((readByte(memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F);
}
quant -= 8;
return quant * scale;
}
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
// The vectorized path reads `that` through getFloatVector, which only ArrayFloatTensor supports.
if (FloatTensor.USE_VECTOR_API && that instanceof ArrayFloatTensor arrayThat) {
return vectorDot(this, thisOffset, arrayThat, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
float result = 0f;
int j = 0;
// Align thisOffset + j to type().getBlockSize().
assert Integer.bitCount(GGMLType.Q4_0.getBlockSize()) == 1 : "power of 2";
int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q4_0.getBlockSize() - 1));
if (alignmentBound > 0) {
result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound);
j += alignmentBound;
}
assert (thisOffset + j) % GGMLType.Q4_0.getBlockSize() == 0;
FloatVector val = FloatVector.zero(F_SPECIES);
int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize();
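// Count whole blocks in what remains after the alignment prefix, so the vector loop never runs past size.
int upperBound = j + (size - j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize();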
for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) {
float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset));
var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue);
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var loBytes = wBytes.and((byte) 0xF).sub((byte) 8);
var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8);
switch (F_SPECIES.vectorBitSize()) {
case 512 -> {
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 0));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 0));
val = sum0.add(sum2).fma(wScale, val);
}
case 256 -> {
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 0));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 1));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
case 128 -> {
// This loop cannot be unrolled, why?
for (int i = 0; i < 2; ++i) {
var tmp = i == 0 ? loBytes : hiBytes;
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
}
default -> throw new UnsupportedOperationException(F_SPECIES.toString());
}
}
result += val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (j < size) {
result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
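public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
// readFully throws EOFException on a short stream instead of silently storing -1 bytes.
byte[] bytes = new byte[Math.toIntExact(bs)];
in.readFully(bytes);
memorySegment = Arena.ofAuto().allocate(bs, 1);
MemorySegment.copy(bytes, 0, memorySegment, ValueLayout.JAVA_BYTE, 0, bytes.length);
}
@Override
public int compareTo(Object o) {
MemorySegment other = ((Q4_0FloatTensor)o).memorySegment;
// Compare only the bytes both segments have, so a shorter segment is never read out of bounds.
long common = Math.min(memorySegment.byteSize(), other.byteSize());
for(long i = 0; i < common; i++) {
int c = Byte.compare(memorySegment.get(ValueLayout.JAVA_BYTE, i), other.get(ValueLayout.JAVA_BYTE, i));
if(c != 0)
return c < 0 ? -1 : 1;
}
// Equal common prefix: the shorter segment sorts first.
return Long.compare(memorySegment.byteSize(), other.byteSize());
}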
}
final class Q8_0FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public Q8_0FloatTensor() {}
public Q8_0FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.Q8_0;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
int blockIndex = index / GGMLType.Q8_0.getBlockSize();
int withinBlockIndex = index % GGMLType.Q8_0.getBlockSize();
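// long arithmetic: blockIndex * typeSize can exceed Integer.MAX_VALUE for very large tensors.
long blockOffset = (long) blockIndex * GGMLType.Q8_0.getTypeSize();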
byte quant = readByte(memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + withinBlockIndex);
float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
return quant * scale;
}
public static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN);
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
// The vectorized path reads `that` through getFloatVector, which only ArrayFloatTensor supports.
if (FloatTensor.USE_VECTOR_API && that instanceof ArrayFloatTensor arrayThat) {
return vectorDot(this, thisOffset, arrayThat, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
float result = 0f;
int j = 0;
// Align thisOffset + j to type().getBlockSize().
assert Integer.bitCount(GGMLType.Q8_0.getBlockSize()) == 1 : "power of 2";
int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q8_0.getBlockSize() - 1));
if (alignmentBound > 0) {
result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound);
j += alignmentBound;
}
assert (thisOffset + j) % GGMLType.Q8_0.getBlockSize() == 0;
FloatVector val = FloatVector.zero(F_SPECIES);
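// Byte offset kept as long: blocks * typeSize can exceed Integer.MAX_VALUE for very large tensors.
long blockOffset = (long) (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize();
// Count whole blocks in what remains after the alignment prefix, so the vector loop never runs past size.
int upperBound = j + (size - j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize();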
for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) {
float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset));
var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue);
switch (F_SPECIES.vectorBitSize()) {
case 512 -> {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1));
val = sum0.add(sum1).fma(wScale, val);
}
case 256 -> {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
case 128 -> {
// This loop cannot be unrolled, why?
for (int i = 0; i < 2; ++i) {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + i * ByteVector.SPECIES_128.vectorByteSize(), ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
}
default -> throw new UnsupportedOperationException(F_SPECIES.toString());
}
}
result += val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (j < size) {
result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
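public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
// readFully throws EOFException on a short stream instead of silently storing -1 bytes.
byte[] bytes = new byte[Math.toIntExact(bs)];
in.readFully(bytes);
memorySegment = Arena.ofAuto().allocate(bs, 1);
MemorySegment.copy(bytes, 0, memorySegment, ValueLayout.JAVA_BYTE, 0, bytes.length);
}
@Override
public int compareTo(Object o) {
MemorySegment other = ((Q8_0FloatTensor)o).memorySegment;
// Compare only the bytes both segments have, so a shorter segment is never read out of bounds.
long common = Math.min(memorySegment.byteSize(), other.byteSize());
for(long i = 0; i < common; i++) {
int c = Byte.compare(memorySegment.get(ValueLayout.JAVA_BYTE, i), other.get(ValueLayout.JAVA_BYTE, i));
if(c != 0)
return c < 0 ? -1 : 1;
}
// Equal common prefix: the shorter segment sorts first.
return Long.compare(memorySegment.byteSize(), other.byteSize());
}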
}
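// Illustrative sketch, not part of the original source: how Q8_0FloatTensor.getFloat above maps
// an element index to byte offsets. The constants restate the layout that method implies
// (2-byte float16 scale, then 32 one-byte quants); the class and method names are hypothetical.
final class Q8BlockAddressSketch {
    static final int BLOCK_SIZE = 32;                        // elements per block
    static final int SCALE_BYTES = 2;                        // float16 scale at the start of each block
    static final int TYPE_SIZE = SCALE_BYTES + BLOCK_SIZE;   // 34 bytes per block

    /** Byte offset of the scale of the block that holds element index. */
    static long scaleOffset(int index) {
        return (long) (index / BLOCK_SIZE) * TYPE_SIZE;
    }

    /** Byte offset of the quantized byte for element index. */
    static long quantOffset(int index) {
        return scaleOffset(index) + SCALE_BYTES + index % BLOCK_SIZE;
    }
}
// Example: element 40 lives in block 1, so its scale is at byte 34 and its quant at byte 34 + 2 + 8 = 44.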
final class BF16FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public BF16FloatTensor() {}
public BF16FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.BF16;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
return bfloat16ToFloat(readShort(memorySegment, index * (long) GGMLType.BFLOAT16_BYTES));
}
private float bfloat16ToFloat(short bfloat16) {
return Float.intBitsToFloat(bfloat16 << 16);
}
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
// The vectorized path reads `that` through getFloatVector, which only ArrayFloatTensor supports.
if (FloatTensor.USE_VECTOR_API && that instanceof ArrayFloatTensor arrayThat) {
return vectorDot(this, thisOffset, arrayThat, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(BF16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
assert S_SPECIES_HALF.length() == F_SPECIES.length();
FloatVector val = FloatVector.zero(F_SPECIES);
int upperBound = F_SPECIES.loopBound(size);
for (int i = 0; i < upperBound; i += F_SPECIES.length()) {
FloatVector thatVector = that.getFloatVector(F_SPECIES, thatOffset + i);
ShortVector bfloat16 = ShortVector.fromMemorySegment(S_SPECIES_HALF, thiz.memorySegment, (thisOffset + i) * (long) GGMLType.BFLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
// BFloat16 to Float32 Conversion:
//
// ┌─[15]─┬─[14]───····───[7]─┬─[6]────····────[0]─┐
// │ Sign │ Exponent (8 bits) │ Mantissa (7 bits) │ BFloat16 Layout (16 bits)
// └──────┴───────────────────┴────────────────────┘
// │ │ │
// ▼ ▼ ▼
// ┌─[31]─┬─[30]───···───[23]─┬─[22]────···────[0]─┐
// │ Sign │ Exponent (8 bits) │ Mantissa (23 bits) │ Float32 Layout (32 bits)
// └──────┴───────────────────┴────────────────────┘
FloatVector thizVector = bfloat16
.castShape(I_SPECIES, 0) // (int) vi
.lanewise(VectorOperators.LSHL, 16) // vi <<= 16
.reinterpretAsFloats(); // Float.intBitsToFloat(vi)
val = thizVector.fma(thatVector, val);
}
float result = val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (upperBound < size) {
result += scalarDot(thiz, thisOffset + upperBound, that, thatOffset + upperBound, size - upperBound);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
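public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
// readFully throws EOFException on a short stream instead of silently storing -1 bytes.
byte[] bytes = new byte[Math.toIntExact(bs)];
in.readFully(bytes);
memorySegment = Arena.ofAuto().allocate(bs, 1);
MemorySegment.copy(bytes, 0, memorySegment, ValueLayout.JAVA_BYTE, 0, bytes.length);
}
@Override
public int compareTo(Object o) {
MemorySegment other = ((BF16FloatTensor)o).memorySegment;
// Compare only the bytes both segments have, so a shorter segment is never read out of bounds.
long common = Math.min(memorySegment.byteSize(), other.byteSize());
for(long i = 0; i < common; i++) {
int c = Byte.compare(memorySegment.get(ValueLayout.JAVA_BYTE, i), other.get(ValueLayout.JAVA_BYTE, i));
if(c != 0)
return c < 0 ? -1 : 1;
}
// Equal common prefix: the shorter segment sorts first.
return Long.compare(memorySegment.byteSize(), other.byteSize());
}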
}
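// Illustrative sketch, not part of the original source: the scalar bfloat16 -> float32 widening
// used by BF16FloatTensor above, with two worked values. A bfloat16 is the upper 16 bits of a
// float32, so shifting left by 16 restores the float bit pattern (the dropped mantissa bits are
// zero). The class and method names are hypothetical.
final class BFloat16Sketch {
    static float toFloat(short bits) {
        // The short is sign-extended to int first; the shift then discards the extension bits.
        return Float.intBitsToFloat(bits << 16);
    }
}
// Example: 0x3F80 -> 0x3F800000 -> 1.0f, and 0xC000 -> 0xC0000000 -> -2.0f.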
final class F16FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public F16FloatTensor() {}
public F16FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.F16;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
return Float.float16ToFloat(readShort(memorySegment, index * (long) GGMLType.FLOAT16_BYTES));
}
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
// The vectorized path reads `that` through getFloatVector, which only ArrayFloatTensor supports.
if (FloatTensor.USE_VECTOR_API && that instanceof ArrayFloatTensor arrayThat) {
return vectorDot(this, thisOffset, arrayThat, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(F16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
assert S_SPECIES_HALF.length() == F_SPECIES.length();
FloatVector val = FloatVector.zero(F_SPECIES);
int upperBound = F_SPECIES.loopBound(size);
for (int i = 0; i < upperBound; i += F_SPECIES.length()) {
FloatVector thatVector = that.getFloatVector(F_SPECIES, thatOffset + i);
ShortVector bits16 = ShortVector.fromMemorySegment(S_SPECIES_HALF, thiz.memorySegment, (thisOffset + i) * (long) GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var bits32 = bits16.castShape(I_SPECIES, 0).reinterpretAsInts(); // (int) bits16
// Does not support infinities nor NaNs, preserves sign, emulate DAZ (denormals-are-zero).
// Expects well-formed float16 values only (e.g. model weights).
// Fast Float16 to Float32 Conversion:
//
// ┌─[15]─┬─[14]───···───[10]─┬─[9]────····────[0]─┐
// │ Sign │ Exponent (5 bits) │ Mantissa (10 bits) │ Float16 Layout (16 bits)
// └──────┴───────────────────┴────────────────────┘
// │ │ │
// ▼ ▼ ▼
// ┌─[31]─┬─[30]───···───[23]─┬─[22]────···────[0]─┐
// │ Sign │ Exponent (8 bits) │ Mantissa (23 bits) │ Float32 Layout (32 bits)
// └──────┴───────────────────┴────────────────────┘
//
// Shifts and adjustments:
// - Sign: float16[15] -> float32[31] (shift 16 bits up)
// - Exponent: float16[10-14] -> float32[23-30] (+ bias adjustment)
// - Mantissa: float16[0-9] -> float32[13-22] (shift 13 bits up)
//
// exp = bits32 & 0x7C00
// zeroExponentMask = exp == 0 ? 0 : ~0
var zeroExponentMask = bits32.and(0x7C00).neg().lanewise(VectorOperators.ASHR, 31); // = (-exp) >> 31
bits32 = bits32.and(0x8000).lanewise(VectorOperators.LSHL, 16) // sign
.or(
// exponent and mantissa combined
bits32.and(0x7FFF).add(0x1C000).lanewise(VectorOperators.LSHL, 13)
.and(zeroExponentMask) // -0, +0 and DAZ (denormals-are-zero)
);
FloatVector thizVector = bits32.reinterpretAsFloats(); // Float.intBitsToFloat(vi)
val = thizVector.fma(thatVector, val);
}
float result = val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (upperBound < size) {
result += scalarDot(thiz, thisOffset + upperBound, that, thatOffset + upperBound, size - upperBound);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
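public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
// readFully throws EOFException on a short stream instead of silently storing -1 bytes.
byte[] bytes = new byte[Math.toIntExact(bs)];
in.readFully(bytes);
memorySegment = Arena.ofAuto().allocate(bs, 1);
MemorySegment.copy(bytes, 0, memorySegment, ValueLayout.JAVA_BYTE, 0, bytes.length);
}
@Override
public int compareTo(Object o) {
MemorySegment other = ((F16FloatTensor)o).memorySegment;
// Compare only the bytes both segments have, so a shorter segment is never read out of bounds.
long common = Math.min(memorySegment.byteSize(), other.byteSize());
for(long i = 0; i < common; i++) {
int c = Byte.compare(memorySegment.get(ValueLayout.JAVA_BYTE, i), other.get(ValueLayout.JAVA_BYTE, i));
if(c != 0)
return c < 0 ? -1 : 1;
}
// Equal common prefix: the shorter segment sorts first.
return Long.compare(memorySegment.byteSize(), other.byteSize());
}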
}
final class F32FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public F32FloatTensor() {}
public F32FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
float getFloat(int index) {
assert 0 <= index && index < size;
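// long arithmetic: index * 4 overflows int once index reaches 2^29 (about 537 million elements).
return readFloat(memorySegment, index * (long) Float.BYTES);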
}
@Override
void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int offset) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
GGMLType type() {
return GGMLType.F32;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
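public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
// readFully throws EOFException on a short stream instead of silently storing -1 bytes.
byte[] bytes = new byte[Math.toIntExact(bs)];
in.readFully(bytes);
memorySegment = Arena.ofAuto().allocate(bs, 1);
MemorySegment.copy(bytes, 0, memorySegment, ValueLayout.JAVA_BYTE, 0, bytes.length);
}
@Override
public int compareTo(Object o) {
MemorySegment other = ((F32FloatTensor)o).memorySegment;
// Compare only the bytes both segments have, so a shorter segment is never read out of bounds.
long common = Math.min(memorySegment.byteSize(), other.byteSize());
for(long i = 0; i < common; i++) {
int c = Byte.compare(memorySegment.get(ValueLayout.JAVA_BYTE, i), other.get(ValueLayout.JAVA_BYTE, i));
if(c != 0)
return c < 0 ? -1 : 1;
}
// Equal common prefix: the shorter segment sorts first.
return Long.compare(memorySegment.byteSize(), other.byteSize());
}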
}
final class ArrayFloatTensor extends FloatTensor implements Externalizable, Comparable {
float[] values;
public ArrayFloatTensor() {}
ArrayFloatTensor(float[] values) {
this.values = values;
}
public static FloatTensor allocate(int... dims) {
int numberOfElements = FloatTensor.numberOfElements(dims);
return new ArrayFloatTensor(new float[numberOfElements]);
}
@Override
public int size() {
return values.length;
}
@Override
public float getFloat(int index) {
return values[index];
}
@Override
public void setFloat(int index, float value) {
values[index] = value;
}
@Override
public GGMLType type() {
return GGMLType.F32;
}
@Override
public FloatTensor fillInPlace(int thisOffset, int size, float value) {
Arrays.fill(values, thisOffset, thisOffset + size, value);
return this;
}
@Override
public FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
if (!USE_VECTOR_API) {
throw new UnsupportedOperationException();
}
return FloatVector.fromArray(species, values, index);
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(values.length);
for(float v: values)
out.writeFloat(v);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
int vsize = in.readInt();
values = new float[vsize];
for(int i = 0; i < vsize; i++)
values[i]= in.readFloat();
}
@Override
public int compareTo(Object o) {
return Arrays.compare(values,((ArrayFloatTensor)o).values);
}
}
final class RoPE {
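/**
* Precomputes the cosine and sine tables for RoPE, optionally with Llama 3.1 frequency scaling.
* For GPT2 vocab.
* @param contextLength number of positions to precompute
* @param headSize attention head size; must be even
* @param theta base of the frequency schedule
* @param ropeScaling whether to apply the Llama 3.1 scaling
* @param scaleFactor divisor applied to low-frequency components
* @param loFreqFactor sets the low-frequency wavelength threshold, oldContextLength / loFreqFactor
* @param hiFreqFactor sets the high-frequency wavelength threshold, oldContextLength / hiFreqFactor
* @param oldContextLength context length from which both thresholds are derived
* @return pair of (cosine, sine) tables, each of length contextLength * (headSize / 2)
*/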
public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta,
boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) {
assert headSize % 2 == 0;
float[] cr = new float[contextLength * (headSize / 2)];
float[] ci = new float[contextLength * (headSize / 2)];
int n = 0;
for (int pos = 0; pos < contextLength; ++pos) {
for (int i = 0; i < headSize; i += 2) {
float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
if (ropeScaling) {
// Llama 3.1 scaling
float loFreqWavelen = oldContextLength / loFreqFactor;
float hiFreqWavelen = oldContextLength / hiFreqFactor;
float wavelen = (float) (2.0 * Math.PI / freq);
if (wavelen < hiFreqWavelen) {
// High-frequency components are left unscaled.
} else if (wavelen > loFreqWavelen) {
freq = freq / scaleFactor;
} else {
float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq;
}
}
float val = pos * freq;
cr[n] = (float) Math.cos(val);
ci[n] = (float) Math.sin(val);
n++;
}
}
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}
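/**
* Precomputes the cosine and sine tables for RoPE without frequency scaling.
* For Llama vocab.
* @param contextLength number of positions to precompute
* @param headSize attention head size; must be even
* @param theta base of the frequency schedule
* @return pair of (cosine, sine) tables, each of length contextLength * (headSize / 2)
*/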
public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta) {
assert headSize % 2 == 0;
float[] cr = new float[contextLength * (headSize / 2)];
float[] ci = new float[contextLength * (headSize / 2)];
int n = 0;
for (int pos = 0; pos < contextLength; ++pos) {
for (int i = 0; i < headSize; i += 2) {
float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
float val = pos * freq;
cr[n] = (float) Math.cos(val);
ci[n] = (float) Math.sin(val);
n++;
}
}
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}
}
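// Illustrative sketch, not part of the original source: the unscaled per-pair frequency that both
// precomputeFreqsCis overloads above start from, freq = theta^(-i / headSize) for even i.
// The class and method names are hypothetical, and theta = 10000 is only an example value.
final class RopeFrequencySketch {
    /** Rotation frequency for the dimension pair starting at even index i. */
    static double frequency(double theta, int i, int headSize) {
        return 1.0 / Math.pow(theta, i / (double) headSize);
    }

    /** Cosine table entry for a position and pair index, mirroring cr[n] in the code above. */
    static float cosEntry(double theta, int pos, int i, int headSize) {
        return (float) Math.cos(pos * (float) frequency(theta, i, headSize));
    }
}
// Example: with theta = 10000 and headSize = 128, pair i = 0 rotates at frequency 1.0 and
// pair i = 64 at roughly 0.01, so later pairs rotate far more slowly per position.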
record Vocabulary(String[] tokens, float[] scores, Map<String, Integer> tokenToIndex) {
public Vocabulary(String[] vocabulary, float[] scores) {
this(vocabulary, scores,
IntStream.range(0, vocabulary.length)
.boxed()
.collect(Collectors.toMap(i -> vocabulary[i], i -> i))
);
}
public String get(int tokenIndex) {
return tokens[tokenIndex];
}
public OptionalInt getIndex(String token) {
Integer value = tokenToIndex.get(token);
return value != null ? OptionalInt.of(value) : OptionalInt.empty();
}
public int size() {
return tokens.length;
}
/**
* Added from Mistral Vocabulary - Groff
* @param tokenIndex index of the token in the vocabulary
* @return the score stored for that token
*/
public float getScore(int tokenIndex) {
return scores[tokenIndex];
}
public boolean scoresNull() {
return scores == null;
}
}
@FunctionalInterface
interface Sampler {
int sampleToken(FloatTensor logits);
Sampler ARGMAX = FloatTensor::argmax;
}
record CategoricalSampler(RandomGenerator rng) implements Sampler {
@Override
public int sampleToken(FloatTensor logits) {
// sample index from probabilities (they must sum to 1!)
float random0to1 = rng.nextFloat(1f);
float cdf = 0.0f;
for (int i = 0; i < logits.size(); i++) {
cdf += logits.getFloat(i);
if (random0to1 < cdf) {
return i;
}
}
return logits.size() - 1; // in case of rounding errors
}
}
final class ToppSampler implements Sampler {
final int[] indices;
final float topp;
final RandomGenerator rng;
public ToppSampler(int maxNumberOfElements, float topp, RandomGenerator rng) {
this.indices = new int[maxNumberOfElements];
this.topp = topp;
this.rng = rng;
}
static void swap(int[] array, int from, int to) {
int tmp = array[from];
array[from] = array[to];
array[to] = tmp;
}
static void siftDown(int[] array, int from, int n, Comparator<Integer> comparator) {
int prev = from, next;
while ((next = 2 * prev + 1) < n) {
int r = 2 * prev + 2;
if (r < n && comparator.compare(array[r], array[next]) < 0) {
next = r;
}
if (comparator.compare(array[next], array[prev]) < 0) {
swap(array, prev, next);
prev = next;
} else {
break;
}
}
}
@Override
public int sampleToken(FloatTensor logits) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
Comparator<Integer> comparator = Comparator.comparingDouble(logits::getFloat).reversed();
int n = logits.size();
int head = 0;
int tail = n - 1;
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
// so for efficiency we crop these out as candidates before sorting
float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < indices.length; i++) {
if (logits.getFloat(i) >= cutoff) {
indices[head++] = i;
} else {
indices[tail--] = i;
}
}
int n0 = head;
// build heap O(n0)
for (int i = n0 / 2 - 1; i >= 0; --i) {
siftDown(indices, i, n0, comparator);
}
// truncate the list where cumulative probability of the largest k elements exceeds topp
// O(k lg n0)
float cumulativeProb = 0.0f;
int lastIndex = 0;
for (int i = n0 - 1; i >= 0; i--) {
swap(indices, 0, i);
cumulativeProb += logits.getFloat(indices[i]);
if (cumulativeProb > topp) {
lastIndex = i;
break; // we've exceeded topp by including lastIndex
}
// after the swap the remaining heap occupies [0, i), so i is the exclusive bound
siftDown(indices, 0, i, comparator);
}
// sample from the truncated list
float r = rng.nextFloat(1f) * cumulativeProb;
float cdf = 0.0f;
for (int i = n0 - 1; i >= lastIndex; i--) {
cdf += logits.getFloat(indices[i]);
if (r < cdf) {
return indices[i];
}
}
return indices[lastIndex]; // in case of rounding errors
}
}
interface ChatFormatInterface {
public TokenizerInterface getTokenizer();
public Set<Integer> getStopTokens();
public List<Integer> encodeHeader(ChatFormat.Message message);
public List<Integer> encodeMessage(ChatFormat.Message message);
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog);
public int getBeginOfText();
}
/**
* Utility tailored for Llama 3 instruct prompt format.
*/
class ChatFormat implements ChatFormatInterface {
final Tokenizer tokenizer;
final int beginOfText;
final int endHeader;
final int startHeader;
final int endOfTurn;
final int endOfText;
final int endOfMessage;
final Set<Integer> stopTokens;
public ChatFormat(TokenizerInterface tokenizer) {
this.tokenizer = (Tokenizer)tokenizer;
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
this.beginOfText = specialTokens.get("<|begin_of_text|>");
this.startHeader = specialTokens.get("<|start_header_id|>");
this.endHeader = specialTokens.get("<|end_header_id|>");
this.endOfTurn = specialTokens.get("<|eot_id|>");
this.endOfText = specialTokens.get("<|end_of_text|>");
this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1
this.stopTokens = Set.of(endOfText, endOfTurn);
}
@Override
public TokenizerInterface getTokenizer() {
return tokenizer;
}
@Override
public Set<Integer> getStopTokens() {
return stopTokens;
}
@Override
public int getBeginOfText() {
return beginOfText;
}
@Override
public List<Integer> encodeHeader(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(startHeader);
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
tokens.add(endHeader);
tokens.addAll(this.tokenizer.encodeAsList("\n"));
return tokens;
}
@Override
public List<Integer> encodeMessage(ChatFormat.Message message) {
List<Integer> tokens = this.encodeHeader(message);
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
tokens.add(endOfTurn);
return tokens;
}
@Override
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
List<Integer> tokens = new ArrayList<>();
tokens.add(beginOfText);
for (ChatFormat.Message message : dialog) {
tokens.addAll(this.encodeMessage(message));
}
if (appendAssistantTurn) {
// Add the start of an assistant message for the model to complete.
tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
}
return tokens;
}
public record Message(ChatFormat.Role role, String content) {
}
public record Role(String name) {
public static final ChatFormat.Role SYSTEM = new ChatFormat.Role("system");
public static final ChatFormat.Role USER = new ChatFormat.Role("user");
public static final ChatFormat.Role ASSISTANT = new ChatFormat.Role("assistant");
@Override
public String toString() {
return name;
}
}
}
/**
* Utility tailored for Mistral v0.3 instruct prompt format.
*/
final class MistralChatFormat implements ChatFormatInterface {
protected final MistralTokenizer tokenizer;
protected final int unknownToken;
protected final int beginOfText;
protected final int endOfText;
protected final int beginOfInstruction;
protected final int endOfInstruction;
protected final int toolCalls;
protected final int beginOfAvailableTools;
protected final int endOfAvailableTools;
protected final int beginOfToolResults;
protected final int endOfToolResults;
protected final int prefix;
protected final int middle;
protected final int suffix;
public MistralChatFormat(TokenizerInterface tokenizer) {
this.tokenizer = (MistralTokenizer)tokenizer;
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
this.unknownToken = specialTokens.get("<unk>");
this.beginOfText = specialTokens.get("<s>");
this.endOfText = specialTokens.get("</s>");
this.beginOfInstruction = specialTokens.get("[INST]");
this.endOfInstruction = specialTokens.get("[/INST]");
this.toolCalls = specialTokens.get("[TOOL_CALLS]");
this.beginOfAvailableTools = specialTokens.get("[AVAILABLE_TOOLS]");
this.endOfAvailableTools = specialTokens.get("[/AVAILABLE_TOOLS]");
this.beginOfToolResults = specialTokens.get("[TOOL_RESULTS]");
this.endOfToolResults = specialTokens.get("[/TOOL_RESULTS]");
// Only Codestral supports FIM tokens.
this.prefix = specialTokens.getOrDefault("[PREFIX]", unknownToken);
this.suffix = specialTokens.getOrDefault("[SUFFIX]", unknownToken);
this.middle = specialTokens.getOrDefault("[MIDDLE]", unknownToken);
}
@Override
public TokenizerInterface getTokenizer() {
return tokenizer;
}
@Override
public Set<Integer> getStopTokens() {
return Set.of(endOfText);
}
@Override
public int getBeginOfText() {
return beginOfText;
}
public List<Integer> encodeMessage(String userMessage, boolean addHeader, boolean addFooter) {
List<Integer> tokens = new ArrayList<>();
if (addHeader) {
tokens.add(this.beginOfInstruction);
}
if (userMessage != null) {
tokens.addAll(this.tokenizer.encodeAsList(userMessage.strip()));
}
if (addFooter) {
tokens.add(endOfInstruction);
}
return tokens;
}
public List<Integer> encodeFillInTheMiddle(String prefix, String suffix) {
List<Integer> tokens = new ArrayList<>();
tokens.add(this.suffix);
tokens.addAll(tokenizer.encode(suffix));
tokens.add(this.prefix);
tokens.addAll(tokenizer.encode(prefix));
return tokens;
}
@Override
public List<Integer> encodeHeader(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(this.beginOfInstruction);
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
tokens.add(endOfInstruction);
return tokens;
}
@Override
public List<Integer> encodeMessage(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(this.beginOfInstruction);
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
tokens.add(endOfInstruction);
return tokens;
}
@Override
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
List<Integer> tokens = new ArrayList<>();
tokens.add(beginOfText);
for (ChatFormat.Message message : dialog) {
tokens.addAll(this.encodeMessage(message));
}
//if (appendAssistantTurn) {
// // Add the start of an assistant message for the model to complete.
// tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
//}
tokens.add(endOfText);
return tokens;
}
}
/**
* Utility tailored for the Chat Markup Language (ChatML) Qwen prompt format.
*/
class ChatMLFormat implements ChatFormatInterface {
protected final TokenizerInterface tokenizer;
protected final int imStart;
protected final int endOfText;
protected final int imEnd;
public ChatMLFormat(TokenizerInterface tokenizer) {
this.tokenizer = tokenizer;
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
this.imStart = specialTokens.get("<|im_start|>");
this.imEnd = specialTokens.get("<|im_end|>");
this.endOfText = specialTokens.get("<|endoftext|>");
}
@Override
public TokenizerInterface getTokenizer() {
return tokenizer;
}
@Override
public Set<Integer> getStopTokens() {
return Set.of(imEnd, endOfText);
}
@Override
public int getBeginOfText() {
return imStart;
}
@Override
public List<Integer> encodeHeader(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(imStart);
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
tokens.addAll(this.tokenizer.encodeAsList("\n"));
return tokens;
}
@Override
public List<Integer> encodeMessage(ChatFormat.Message message) {
List<Integer> tokens = this.encodeHeader(message);
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
tokens.add(imEnd);
return tokens;
}
@Override
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
List<Integer> tokens = new ArrayList<>();
tokens.add(imStart);
for (ChatFormat.Message message : dialog) {
tokens.addAll(this.encodeMessage(message));
}
if (appendAssistantTurn) {
// Add the start of an assistant message for the model to complete.
tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
}
return tokens;
}
}
/**
* Support for AOT preloading of GGUF metadata with GraalVM's Native Image.
*
* <p>
* To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf}
* to the native-image builder command. At runtime, the preloaded model will be used
* iff the specified and preloaded file names (base name) match.
*/
final class AOT {
record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {}
private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
private static PartialModel preLoadGGUF(String modelPath) {
if (modelPath == null || modelPath.isEmpty()) {
return null;
}
try {
Path path = Path.of(modelPath);
if (!Files.exists(path) || !Files.isRegularFile(path)) {
throw new IllegalArgumentException("Cannot pre-load model: " + path);
}
GGUF gguf = GGUF.loadModel(path);
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
return new PartialModel(
path.getFileName().toString(),
ModelLoader.loadModel(fileChannel, gguf, Llama3.Options.DEFAULT_MAX_TOKENS, false),
gguf.getTensorDataOffset(),
gguf.getTensorInfos()
);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Tries to reuse a compatible AOT preloaded model.
* The file name (base name) must match with the preloaded file name.
* No checksum/hash is checked for performance reasons.
*/
public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
if (preLoaded == null) {
return null; // no pre-loaded model stored
}
String optionsModel = modelPath.getFileName().toString();
String preLoadedModel = preLoaded.modelFileName();
if (!Objects.equals(optionsModel, preLoadedModel)) {
// Preloaded and specified model file names didn't match.
return null;
}
Llama baseModel = preLoaded.model();
try (var timer = Timer.log("Load tensors from pre-loaded model");
var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
// Load only the tensors (mmap slices).
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
Llama.Weights weights = ModelLoader.loadGPT2Weights(tensorEntries, baseModel.configuration());
return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights);
}
}
}
/**
* Implementation of Super-Bit Locality-Sensitive Hashing.
* Super-Bit is an improvement of Random Projection LSH.
* It computes an estimation of cosine similarity.
*
* Super-Bit Locality-Sensitive Hashing
* Jianqiu Ji, Jianmin Li, Shuicheng Yan, Bo Zhang, Qi Tian
* http://papers.nips.cc/paper/4847-super-bit-locality-sensitive-hashing.pdf
* Advances in Neural Information Processing Systems 25, 2012
*
* Supported input types:
* - double[]
* @author original:Thibault Debatty
* @author Groff
*/
final class SuperBit implements java.io.Serializable, Comparable {
private static final long serialVersionUID = -1L;
private boolean[] sig;
private transient double[][] hyperplanes;
private static final int DEFAULT_CODE_LENGTH = 10000;
/**
* Initialize SuperBit algorithm.
* Super-Bit depth n must be in [1 .. d] and the number of Super-Bits l must be >= 1.
* The resulting code length is k = n * l.
* The k vectors are orthogonalized in l batches of n vectors.
*
* @param d data space dimension
* @param n Super-Bit depth, in [1 .. d]
* @param l number of Super-Bits, >= 1
*/
public SuperBit(final int d, final int n, final int l) {
this(d, n, l, new Random());
}
/**
* Initialize SuperBit algorithm.
* Super-Bit depth n must be in [1 .. d] and the number of Super-Bits l must be >= 1.
* The resulting code length is k = n * l.
* The k vectors are orthogonalized in l batches of n vectors.
*
* @param d data space dimension
* @param n Super-Bit depth, in [1 .. d]
* @param l number of Super-Bits, >= 1
* @param seed to use for the random number generator
*/
public SuperBit(final int d, final int n, final int l, final long seed) {
this(d, n, l, new Random(seed));
}
private SuperBit(final int d, final int n, final int l, final Random rand) {
if (d <= 0) {
throw new IllegalArgumentException("Dimension d must be >= 1");
}
if (n < 1 || n > d) {
throw new IllegalArgumentException(
"Super-Bit depth N must be 1 <= N <= d");
}
if (l < 1) {
throw new IllegalArgumentException(
"Number of Super-Bit L must be >= 1");
}
// Input: Data space dimension d, Super-Bit depth 1 <= N <= d,
// number of Super-Bit L >= 1,
// resulting code length K = N * L
// Generate a random matrix H with each element sampled independently
// from the normal distribution
// N (0, 1), with each column normalized to unit length.
// Denote H = [v1, v2, ..., vK].
int code_length = n * l;
double[][] v = new double[code_length][d];
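// Draw sequentially: sharing one Random across parallel workers would make the
// assignment of draws to vectors depend on thread scheduling, defeating the seed.
for (int t = 0; t < code_length; t++) {
double[] vector = new double[d];
for (int j = 0; j < d; j++) {
vector[j] = rand.nextGaussian();
}
normalize(vector);
v[t] = vector;
}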
double[][] w = new double[code_length][d];
for (int i = 0; i <= l - 1; i++) {
for (int j = 1; j <= n; j++) {
java.lang.System.arraycopy(
v[i * n + j - 1],
0,
w[i * n + j - 1],
0,
d);
for (int k = 1; k <= (j - 1); k++) {
w[i * n + j - 1] = sub(
w[i * n + j - 1],
product(
dotProduct(
w[i * n + k - 1],
v[ i * n + j - 1]),
w[i * n + k - 1]));
}
normalize(w[i * n + j - 1]);
}
}
this.hyperplanes = w;
}
/**
* Initialize SuperBit algorithm with a fixed seed.
* The code length K is d * (10000 / d) using integer division, i.e. at most 10000.
* The K vectors are orthogonalized in 10000/d batches of d vectors.
* The resulting mean error is 0.01
* @param d The size of the vector we are operating on; must be in [1 .. 10000]
*/
public SuperBit(final int d) {
this(d, d, DEFAULT_CODE_LENGTH / d, 8675309);
}
/**
* Initialize SuperBit algorithm without parameters
* (used only for serialization).
*/
public SuperBit() {}
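/**
* Compute the signature of this vector.
* @param vector input vector of dimension d
* @return one bit per hyperplane: true where the dot product with that hyperplane is >= 0
*/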
public final boolean[] signature(final double[] vector) {
boolean[] sig = new boolean[this.hyperplanes.length];
for (int i = 0; i < this.hyperplanes.length; i++) {
sig[i] = (dotProduct(this.hyperplanes[i], vector) >= 0);
}
return sig;
}
/**
* Compute the signature of the given FloatTensor, set the encapsulated signature for serialization
* @param vector The target FloatTensor
*/
public final void signature(final FloatTensor vector) {
sig = new boolean[this.hyperplanes.length];
for (int i = 0; i < this.hyperplanes.length; i++) {
sig[i] = (dotProduct(this.hyperplanes[i], vector) >= 0);
}
}
public final boolean[] getSignature() {
return sig;
}
/**
* Compute the similarity between two signatures, which is also an
* estimation of the cosine similarity between the two vectors.
* @param sig1 first signature
* @param sig2 second signature, same length as sig1
* @return estimated cosine similarity
*/
public final double similarity(final boolean[] sig1, final boolean[] sig2) {
DoubleAdder agg = new DoubleAdder(); // thread-safe counter of matching bits
Parallel.parallelFor(0, sig1.length, t -> {
if (sig1[t] == sig2[t]) {
agg.add(1);
}
});
double sim = agg.sum() / sig1.length; // fraction of matching bits
return Math.cos((1 - sim) * Math.PI);
}
/**
* Get the hyperplane coefficients used to compute signatures.
* @return the hyperplanes, one row per signature bit
*/
public final double[][] getHyperplanes() {
return this.hyperplanes;
}
/**
* Computes the cosine similarity, v1 dot v2 / (|v1| * |v2|).
* Cosine similarity of two vectors is the cosine of the angle between them.
* It ranges between -1 and +1.
*
* @param v1 first vector
* @param v2 second vector, same length as v1
* @return the cosine similarity of v1 and v2
*/
public static double cosineSimilarity(final double[] v1, final double[] v2) {
return dotProduct(v1, v2) / (norm(v1) * norm(v2));
}
private static double[] product(final double x, final double[] v) {
double[] r = new double[v.length];
Parallel.parallelFor(0, v.length, t -> {
r[t] = x * v[t];
});
return r;
}
private static double[] sub(final double[] a, final double[] b) {
double[] r = new double[a.length];
Parallel.parallelFor(0, a.length, t -> {
r[t] = a[t] - b[t];
});
return r;
}
private static void normalize(final double[] vector) {
final double norm = norm(vector);
Parallel.parallelFor(0, vector.length, t -> vector[t] /= norm);
}
/**
* Returns the L2 norm, sqrt(sum_i(v_i^2)).
* @param v input vector
* @return the L2 norm of v
*/
private static double norm(final double[] v) {
DoubleAdder agg = new DoubleAdder();
Parallel.parallelFor(0, v.length, t -> agg.add(v[t] * v[t]));
return Math.sqrt(agg.sum());
}
private static double dotProduct(final double[] v1, final double[] v2) {
if (v1.length < 10_000) { // Adjust threshold based on benchmarking
return IntStream.range(0, v1.length).mapToDouble(t -> v1[t] * v2[t]).sum();
} else {
DoubleAdder agg = new DoubleAdder();
Parallel.parallelFor(0, v1.length, t -> agg.add(v1[t] * v2[t]));
return agg.sum();
}
}
private static double dotProduct(final double[] v1, final FloatTensor v2) {
if (v1.length < 10_000) { // Adjust threshold based on benchmarking
return IntStream.range(0, v1.length).mapToDouble(t -> v1[t] * v2.getFloat(t)).sum();
} else {
DoubleAdder agg = new DoubleAdder();
Parallel.parallelFor(0, v1.length, t -> agg.add(v1[t] * v2.getFloat(t)));
return agg.sum();
}
}
@Override
public int compareTo(Object o) {
return Double.compare(similarity(this.sig, ((SuperBit)o).sig), 1.0); // 0 if the signatures are identical, negative otherwise; not a total ordering
}
}
---------- END SOURCE ----------
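The comments in ToppSampler.sampleToken in the source above describe top-p (nucleus) sampling as keeping the smallest set of tokens whose cumulative probability exceeds topp. The following is a minimal standalone sketch of that selection step only. It uses a full sort rather than the heap used in the source, and the class and method names (NucleusSketch, nucleus) are hypothetical and not part of the reported code.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical standalone sketch; not part of the reported source. */
final class NucleusSketch {
    /**
     * Returns the indices of the tokens, taken in descending probability
     * order, up to and including the one that pushes the cumulative
     * probability above topp.
     */
    static int[] nucleus(float[] probs, float topp) {
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Full sort for clarity; the source builds a heap and pops only as needed.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        float cumulative = 0f;
        int count = 0;
        while (count < order.length) {
            cumulative += probs[order[count++]];
            if (cumulative > topp) {
                break; // topp exceeded by including the last token
            }
        }
        int[] result = new int[count];
        for (int i = 0; i < count; i++) {
            result[i] = order[i];
        }
        return result;
    }

    public static void main(String[] args) {
        float[] probs = {0.1f, 0.5f, 0.05f, 0.3f, 0.05f};
        // Tokens 1 (0.5) and 3 (0.3) together exceed 0.7.
        System.out.println(Arrays.toString(nucleus(probs, 0.7f))); // prints [1, 3]
    }
}
```

Sampling would then draw from the returned indices in proportion to their probabilities, as the last loop of sampleToken does.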
Tested under Win11, Ubuntu 22.04, GraalVM and OpenJDK
A DESCRIPTION OF THE PROBLEM :
An IndexOutOfBoundsException is thrown in previously working code after integrating new, working code that used ByteBuffer instead of MemorySegment. Issue #11294 has been logged against GraalVM.
STEPS TO FOLLOW TO REPRODUCE THE PROBLEM :
compile_rel.bat
jar_llama3.bat
chat_qwendb.bat
NOTE: The model file qwen2-7b-instruct-q8_0.gguf must be obtained (from HuggingFace etc.) to reproduce. The failure occurs with other models as well, but Mistral and Llama models work.
EXPECTED VERSUS ACTUAL BEHAVIOR :
EXPECTED -
Iteration through the buffer backed by a MemorySegment should complete without error. Instead, it works until an unknown edge case causes the failure.
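For context on the access ranges involved, the tensor listing in this report can be checked with simple arithmetic. The sketch below is not a diagnosis of the failure. It only shows that the Q8_0 sizes in the log follow from the element counts (assuming the ggml Q8_0 layout of 34 bytes per 32 elements) and that many tensor offsets lie beyond the range an int index can address. The class and method names are hypothetical.

```java
/** Hypothetical helper for illustration only; not part of the reported source. */
final class TensorOffsetSketch {
    static final long Q8_0_BLOCK_ELEMS = 32; // elements per Q8_0 block
    static final long Q8_0_BLOCK_BYTES = 34; // 2-byte fp16 scale + 32 one-byte quants (assumed ggml layout)

    /** Byte size of a Q8_0 tensor whose element count is a multiple of 32. */
    static long q8_0ByteSize(long elems) {
        return elems / Q8_0_BLOCK_ELEMS * Q8_0_BLOCK_BYTES;
    }

    /** True if the byte range [offset, offset + size) can be addressed with int indices. */
    static boolean fitsIntRange(long offset, long size) {
        return offset >= 0 && size >= 0 && offset + size <= Integer.MAX_VALUE;
    }

    public static void main(String[] args) {
        // Values copied from the log line for blk.19.attn_v.weight.
        long elems = 1_835_008L, offset = 5_530_279_936L, size = 1_949_696L;
        System.out.println(q8_0ByteSize(elems) == size);  // size matches the Q8_0 layout
        System.out.println(fitsIntRange(offset, size));   // range is beyond int addressing
        // blk.0.attn_output.weight still lies within the int range.
        System.out.println(fitsIntRange(797_456_384L, 13_647_872L));
        // Narrowing the blk.11.attn_q.bias offset to int silently yields a negative value.
        System.out.println((int) 2_575_927_296L);
    }
}
```

Any index arithmetic on these offsets therefore has to stay in long until a slice-relative position has been computed.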
ACTUAL -
C:\Progra~1\Java\graalvm-jdk-25+20.1\bin\java -server -XX:+UseParallelGC -Xmn26g -Xms26g -Xmx26g --enable-preview --add-modules jdk.incubator.vector -jar Llama3.jar --model qwen2-7b-instruct-q8_0.gguf --chat -n -1
[0.006s][warning][gc,ergo] NewSize (27262976k) is equal to or greater than initial heap size (27262976k). A new NewSize of 27262464k will be used to accomodate an old generation.
[0.006s][warning][gc,ergo] MaxNewSize (27262976k) is equal to or greater than the entire heap (27262976k). A new max generation size of 27262464k will be used.
WARNING: Using incubator modules: jdk.incubator.vector
Parse qwen2-7b-instruct-q8_0.gguf: 1159 millis
GGUF metadata:
{qwen2.block_count=28, tokenizer.ggml.add_bos_token=false, qwen2.embedding_length=3584, tokenizer.ggml.padding_token_id=151643, quantize.imatrix.chunks_count=1937, qwen2.feed_forward_length=18944, quantize.imatrix.entries_count=196, qwen2.attention.layer_norm_rms_epsilon=1.0E-6, tokenizer.ggml.merges=[Ljava.lang.String;@2e5d6d97, tokenizer.ggml.pre=qwen2, qwen2.attention.head_count_kv=4, general.architecture=qwen2, tokenizer.chat_template={% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
You are a helpful assistant.<|im_end|>
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}, general.file_type=7, general.name=qwen2-7b-instruct, general.quantization_version=2, tokenizer.ggml.token_type=[I@238e0d81, tokenizer.ggml.eos_token_id=151645, tokenizer.ggml.bos_token_id=151643, quantize.imatrix.dataset=../sft_2406.txt, qwen2.rope.freq_base=1000000.0, tokenizer.ggml.tokens=[Ljava.lang.String;@31221be2, tokenizer.ggml.model=gpt2, qwen2.context_length=32768, quantize.imatrix.file=../Qwen2/gguf/qwen2-7b-imatrix/imatrix.dat, qwen2.attention.head_count=28}
Tensor:blk.11.attn_q.bias=blk.11.attn_q.bias offset:2575927296 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.attn_norm.weight=blk.3.attn_norm.weight offset:1322035200 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.attn_v.weight=blk.19.attn_v.weight offset:5530279936 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.10.attn_q.bias=blk.10.attn_q.bias offset:2328268800 dims:[3584] number elems:3584 size:14336
Tensor:blk.12.attn_q.bias=blk.12.attn_q.bias offset:2823585792 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.attn_q.bias=blk.13.attn_q.bias offset:3071244288 dims:[3584] number elems:3584 size:14336
Tensor:blk.8.attn_output.weight=blk.8.attn_output.weight offset:3872710656 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.15.attn_output.weight=blk.15.attn_output.weight offset:4512333824 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.11.attn_norm.weight=blk.11.attn_norm.weight offset:2343882752 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.attn_q.bias=blk.19.attn_q.bias offset:5516615680 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.attn_q.bias=blk.18.attn_q.bias offset:5268957184 dims:[3584] number elems:3584 size:14336
Tensor:blk.17.attn_q.bias=blk.17.attn_q.bias offset:5021298688 dims:[3584] number elems:3584 size:14336
Tensor:blk.23.attn_v.weight=blk.23.attn_v.weight offset:7099973632 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_output.weight=blk.0.attn_output.weight offset:797456384 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.10.ffn_norm.weight=blk.10.ffn_norm.weight offset:2312654848 dims:[3584] number elems:3584 size:14336
Tensor:blk.15.attn_q.bias=blk.15.attn_q.bias offset:4525981696 dims:[3584] number elems:3584 size:14336
Tensor:blk.14.attn_q.bias=blk.14.attn_q.bias offset:3174596608 dims:[3584] number elems:3584 size:14336
Tensor:blk.16.attn_q.bias=blk.16.attn_q.bias offset:4773640192 dims:[3584] number elems:3584 size:14336
Tensor:blk.11.attn_q.weight=blk.11.attn_q.weight offset:2575941632 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.ffn_norm.weight=blk.16.ffn_norm.weight offset:4758026240 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.ffn_norm.weight=blk.13.ffn_norm.weight offset:3055630336 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.attn_q.weight=blk.26.attn_q.weight offset:7829299200 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.11.attn_k.weight=blk.11.attn_k.weight offset:2560329728 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.16.attn_k.weight=blk.16.attn_k.weight offset:4758042624 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.14.attn_v.weight=blk.14.attn_v.weight offset:3188260864 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.16.attn_q.weight=blk.16.attn_q.weight offset:4773654528 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.17.attn_q.weight=blk.17.attn_q.weight offset:5021313024 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_q.bias=blk.22.attn_q.bias offset:6187423744 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.ffn_down.weight=blk.26.ffn_down.weight offset:7597254656 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.8.attn_norm.weight=blk.8.attn_norm.weight offset:3654313984 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.attn_q.bias=blk.9.attn_q.bias offset:4134017024 dims:[3584] number elems:3584 size:14336
Tensor:blk.21.attn_q.bias=blk.21.attn_q.bias offset:6011932672 dims:[3584] number elems:3584 size:14336
Tensor:blk.23.attn_q.bias=blk.23.attn_q.bias offset:7086309376 dims:[3584] number elems:3584 size:14336
Tensor:blk.23.attn_k.bias=blk.23.attn_k.bias offset:7070709760 dims:[512] number elems:512 size:2048
Tensor:blk.25.attn_k.bias=blk.25.attn_k.bias offset:7566026752 dims:[512] number elems:512 size:2048
Tensor:blk.8.attn_q.bias=blk.8.attn_q.bias offset:3886358528 dims:[3584] number elems:3584 size:14336
Tensor:blk.20.attn_q.bias=blk.20.attn_q.bias offset:5764274176 dims:[3584] number elems:3584 size:14336
Tensor:output.weight=output.weight offset:6203037696 dims:[3584, 152064] number elems:544997376 size:579059712
Tensor:blk.24.attn_q.bias=blk.24.attn_q.bias offset:7333967872 dims:[3584] number elems:3584 size:14336
Tensor:blk.17.ffn_up.weight=blk.17.ffn_up.weight offset:4933545984 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.3.attn_q.weight=blk.3.attn_q.weight offset:1554094080 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.8.ffn_up.weight=blk.8.ffn_up.weight offset:3798605824 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.22.attn_k.bias=blk.22.attn_k.bias offset:6171824128 dims:[512] number elems:512 size:2048
Tensor:blk.26.attn_k.bias=blk.26.attn_k.bias offset:7813685248 dims:[512] number elems:512 size:2048
Tensor:blk.7.attn_k.bias=blk.7.attn_k.bias offset:3623100416 dims:[512] number elems:512 size:2048
Tensor:blk.9.attn_k.bias=blk.9.attn_k.bias offset:4118417408 dims:[512] number elems:512 size:2048
Tensor:blk.20.ffn_down.weight=blk.20.ffn_down.weight offset:5532243968 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.21.attn_k.bias=blk.21.attn_k.bias offset:5996333056 dims:[512] number elems:512 size:2048
Tensor:blk.23.ffn_down.weight=blk.23.ffn_down.weight offset:6854279168 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.27.attn_k.bias=blk.27.attn_k.bias offset:8061343744 dims:[512] number elems:512 size:2048
Tensor:blk.11.ffn_gate.weight=blk.11.ffn_gate.weight offset:2416035840 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.14.ffn_gate.weight=blk.14.ffn_gate.weight offset:3086858240 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.25.attn_q.weight=blk.25.attn_q.weight offset:7581640704 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.8.attn_k.bias=blk.8.attn_k.bias offset:3870758912 dims:[512] number elems:512 size:2048
Tensor:blk.17.ffn_gate.weight=blk.17.ffn_gate.weight offset:4861407232 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.20.attn_k.bias=blk.20.attn_k.bias offset:5748674560 dims:[512] number elems:512 size:2048
Tensor:blk.2.attn_q.bias=blk.2.attn_q.bias offset:1306421248 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_q.bias=blk.1.attn_q.bias offset:1058762752 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.attn_v.weight=blk.7.attn_v.weight offset:3652364288 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_q.bias=blk.0.attn_q.bias offset:811104256 dims:[3584] number elems:3584 size:14336
Tensor:blk.4.attn_q.weight=blk.4.attn_q.weight offset:1801752576 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.attn_norm.weight=blk.16.attn_norm.weight offset:4541595648 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.ffn_norm.weight=blk.19.ffn_norm.weight offset:5501001728 dims:[3584] number elems:3584 size:14336
Tensor:blk.21.attn_k.weight=blk.21.attn_k.weight offset:5996335104 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.20.ffn_up.weight=blk.20.ffn_up.weight offset:5676521472 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.25.attn_q.bias=blk.25.attn_q.bias offset:7581626368 dims:[3584] number elems:3584 size:14336
Tensor:blk.27.attn_q.bias=blk.27.attn_q.bias offset:8076943360 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.attn_k.weight=blk.9.attn_k.weight offset:4118419456 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.22.ffn_norm.weight=blk.22.ffn_norm.weight offset:6854250496 dims:[3584] number elems:3584 size:14336
Tensor:blk.24.attn_k.bias=blk.24.attn_k.bias offset:7318368256 dims:[512] number elems:512 size:2048
Tensor:blk.25.ffn_norm.weight=blk.25.ffn_norm.weight offset:7566012416 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.attn_q.bias=blk.26.attn_q.bias offset:7829284864 dims:[3584] number elems:3584 size:14336
Tensor:blk.16.ffn_up.weight=blk.16.ffn_up.weight offset:4685887488 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.ffn_gate.weight=blk.5.ffn_gate.weight offset:1889505280 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.8.ffn_gate.weight=blk.8.ffn_gate.weight offset:3726467072 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.2.ffn_gate.weight=blk.2.ffn_gate.weight offset:1146529792 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.21.attn_norm.weight=blk.21.attn_norm.weight offset:5779888128 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.ffn_norm.weight=blk.1.ffn_norm.weight offset:1043148800 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_k.bias=blk.1.attn_k.bias offset:1043163136 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_v.weight=blk.6.attn_v.weight offset:2094274560 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_k.bias=blk.0.attn_k.bias offset:795504640 dims:[512] number elems:512 size:2048
Tensor:blk.8.attn_k.weight=blk.8.attn_k.weight offset:3870760960 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.21.ffn_up.weight=blk.21.ffn_up.weight offset:5924179968 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_norm.weight=blk.4.ffn_norm.weight offset:1786124288 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.attn_q.bias=blk.3.attn_q.bias offset:1554079744 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.attn_k.bias=blk.3.attn_k.bias offset:1538480128 dims:[512] number elems:512 size:2048
Tensor:blk.5.attn_k.bias=blk.5.attn_k.bias offset:2033797120 dims:[512] number elems:512 size:2048
Tensor:blk.9.ffn_up.weight=blk.9.ffn_up.weight offset:4046264320 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.attn_q.bias=blk.4.attn_q.bias offset:1801738240 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.ffn_norm.weight=blk.7.ffn_norm.weight offset:3623086080 dims:[3584] number elems:3584 size:14336
Tensor:blk.2.attn_k.bias=blk.2.attn_k.bias offset:1290821632 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_k.bias=blk.6.attn_k.bias offset:2065010688 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_q.bias=blk.6.attn_q.bias offset:2080610304 dims:[3584] number elems:3584 size:14336
Tensor:blk.5.attn_q.bias=blk.5.attn_q.bias offset:2049396736 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.attn_q.bias=blk.7.attn_q.bias offset:3638700032 dims:[3584] number elems:3584 size:14336
Tensor:blk.20.attn_k.weight=blk.20.attn_k.weight offset:5748676608 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.4.attn_k.bias=blk.4.attn_k.bias offset:1786138624 dims:[512] number elems:512 size:2048
Tensor:blk.23.ffn_gate.weight=blk.23.ffn_gate.weight offset:6926417920 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.11.ffn_down.weight=blk.11.ffn_down.weight offset:2343897088 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.10.ffn_up.weight=blk.10.ffn_up.weight offset:2240516096 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.12.attn_k.bias=blk.12.attn_k.bias offset:2807986176 dims:[512] number elems:512 size:2048
Tensor:blk.14.attn_k.bias=blk.14.attn_k.bias offset:3158996992 dims:[512] number elems:512 size:2048
Tensor:blk.14.ffn_down.weight=blk.14.ffn_down.weight offset:4149645312 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.26.ffn_gate.weight=blk.26.ffn_gate.weight offset:7669393408 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.27.attn_v.weight=blk.27.attn_v.weight offset:8090607616 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.5.attn_v.weight=blk.5.attn_v.weight offset:2063060992 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.10.attn_q.weight=blk.10.attn_q.weight offset:2328283136 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.19.attn_norm.weight=blk.19.attn_norm.weight offset:5284571136 dims:[3584] number elems:3584 size:14336
Tensor:blk.20.ffn_gate.weight=blk.20.ffn_gate.weight offset:5604382720 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.10.attn_k.bias=blk.10.attn_k.bias offset:2312669184 dims:[512] number elems:512 size:2048
Tensor:blk.13.attn_output.weight=blk.13.attn_output.weight offset:3057596416 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.attn_k.bias=blk.16.attn_k.bias offset:4758040576 dims:[512] number elems:512 size:2048
Tensor:blk.18.attn_k.bias=blk.18.attn_k.bias offset:5253357568 dims:[512] number elems:512 size:2048
Tensor:blk.27.attn_norm.weight=blk.27.attn_norm.weight offset:7844898816 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.attn_v.bias=blk.18.attn_v.bias offset:5282619392 dims:[512] number elems:512 size:2048
Tensor:blk.2.attn_norm.weight=blk.2.attn_norm.weight offset:1074376704 dims:[3584] number elems:3584 size:14336
Tensor:blk.2.attn_k.weight=blk.2.attn_k.weight offset:1290823680 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.27.ffn_up.weight=blk.27.ffn_up.weight offset:7989190656 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.7.attn_k.weight=blk.7.attn_k.weight offset:3623102464 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.17.ffn_down.weight=blk.17.ffn_down.weight offset:4789268480 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.0.attn_q.weight=blk.0.attn_q.weight offset:811118592 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.ffn_up.weight=blk.22.ffn_up.weight offset:6099685376 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.0.attn_v.weight=blk.0.attn_v.weight offset:824768512 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_k.weight=blk.15.attn_k.weight offset:4510384128 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_v.weight=blk.15.attn_v.weight offset:4539645952 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_q.weight=blk.15.attn_q.weight offset:4525996032 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.24.ffn_up.weight=blk.24.ffn_up.weight offset:7246215168 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.8.ffn_down.weight=blk.8.ffn_down.weight offset:3654328320 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.2.attn_q.weight=blk.2.attn_q.weight offset:1306435584 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.18.attn_v.weight=blk.18.attn_v.weight offset:5282621440 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.13.attn_v.weight=blk.13.attn_v.weight offset:3084908544 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.17.attn_norm.weight=blk.17.attn_norm.weight offset:4789254144 dims:[3584] number elems:3584 size:14336
Tensor:blk.0.attn_norm.weight=blk.0.attn_norm.weight offset:579059712 dims:[3584] number elems:3584 size:14336
Tensor:blk.6.attn_norm.weight=blk.6.attn_norm.weight offset:3190210560 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.attn_v.bias=blk.9.attn_v.bias offset:4147679232 dims:[512] number elems:512 size:2048
Tensor:blk.21.attn_v.bias=blk.21.attn_v.bias offset:6025594880 dims:[512] number elems:512 size:2048
Tensor:blk.10.attn_v.bias=blk.10.attn_v.bias offset:2341931008 dims:[512] number elems:512 size:2048
Tensor:blk.25.attn_norm.weight=blk.25.attn_norm.weight offset:7349581824 dims:[3584] number elems:3584 size:14336
Tensor:blk.5.attn_k.weight=blk.5.attn_k.weight offset:2033799168 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.23.attn_v.bias=blk.23.attn_v.bias offset:7099971584 dims:[512] number elems:512 size:2048
Tensor:blk.12.attn_v.bias=blk.12.attn_v.bias offset:2837248000 dims:[512] number elems:512 size:2048
Tensor:blk.27.attn_q.weight=blk.27.attn_q.weight offset:8076957696 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.13.ffn_up.weight=blk.13.ffn_up.weight offset:2983491584 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.ffn_down.weight=blk.5.ffn_down.weight offset:1817366528 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.27.attn_v.bias=blk.27.attn_v.bias offset:8090605568 dims:[512] number elems:512 size:2048
Tensor:blk.16.attn_v.bias=blk.16.attn_v.bias offset:4787302400 dims:[512] number elems:512 size:2048
Tensor:blk.22.attn_norm.weight=blk.22.attn_norm.weight offset:6782097408 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.attn_v.bias=blk.25.attn_v.bias offset:7595288576 dims:[512] number elems:512 size:2048
Tensor:blk.2.ffn_down.weight=blk.2.ffn_down.weight offset:1074391040 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.14.attn_v.bias=blk.14.attn_v.bias offset:3188258816 dims:[512] number elems:512 size:2048
Tensor:blk.10.attn_output.weight=blk.10.attn_output.weight offset:2314620928 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.3.attn_v.weight=blk.3.attn_v.weight offset:1567744000 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.attn_output.weight=blk.12.attn_output.weight offset:2809937920 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.12.attn_q.weight=blk.12.attn_q.weight offset:2823600128 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.14.attn_norm.weight=blk.14.attn_norm.weight offset:4149630976 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.ffn_up.weight=blk.25.ffn_up.weight offset:7493873664 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.14.attn_output.weight=blk.14.attn_output.weight offset:3160948736 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.4.attn_k.weight=blk.4.attn_k.weight offset:1786140672 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.ffn_up.weight=blk.12.ffn_up.weight offset:2735833088 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.16.attn_output.weight=blk.16.attn_output.weight offset:4759992320 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.17.attn_v.weight=blk.17.attn_v.weight offset:5034962944 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.13.attn_q.weight=blk.13.attn_q.weight offset:3071258624 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.17.attn_k.weight=blk.17.attn_k.weight offset:5005701120 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.7.attn_v.bias=blk.7.attn_v.bias offset:3652362240 dims:[512] number elems:512 size:2048
Tensor:blk.2.attn_v.weight=blk.2.attn_v.weight offset:1320085504 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.5.attn_v.bias=blk.5.attn_v.bias offset:2063058944 dims:[512] number elems:512 size:2048
Tensor:blk.3.attn_v.bias=blk.3.attn_v.bias offset:1567741952 dims:[512] number elems:512 size:2048
Tensor:blk.23.attn_output.weight=blk.23.attn_output.weight offset:7072661504 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_output.weight=blk.22.attn_output.weight offset:6173775872 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.24.attn_output.weight=blk.24.attn_output.weight offset:7320320000 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.25.attn_output.weight=blk.25.attn_output.weight offset:7567978496 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.19.attn_output.weight=blk.19.attn_output.weight offset:5502967808 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.26.attn_output.weight=blk.26.attn_output.weight offset:7815636992 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.26.ffn_up.weight=blk.26.ffn_up.weight offset:7741532160 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.1.attn_v.bias=blk.1.attn_v.bias offset:1072424960 dims:[512] number elems:512 size:2048
Tensor:blk.11.ffn_up.weight=blk.11.ffn_up.weight offset:2488174592 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.17.attn_output.weight=blk.17.attn_output.weight offset:5007650816 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.18.attn_output.weight=blk.18.attn_output.weight offset:5255309312 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.19.attn_k.weight=blk.19.attn_k.weight offset:5501018112 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.14.attn_q.weight=blk.14.attn_q.weight offset:3174610944 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.16.attn_v.weight=blk.16.attn_v.weight offset:4787304448 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.18.attn_k.weight=blk.18.attn_k.weight offset:5253359616 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.3.attn_k.weight=blk.3.attn_k.weight offset:1538482176 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.27.attn_output.weight=blk.27.attn_output.weight offset:8063295488 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.21.attn_output.weight=blk.21.attn_output.weight offset:5998284800 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.20.attn_output.weight=blk.20.attn_output.weight offset:5750626304 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.1.attn_v.weight=blk.1.attn_v.weight offset:1072427008 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.1.ffn_up.weight=blk.1.ffn_up.weight offset:971010048 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.24.ffn_norm.weight=blk.24.ffn_norm.weight offset:7318353920 dims:[3584] number elems:3584 size:14336
Tensor:blk.11.attn_output.weight=blk.11.attn_output.weight offset:2562279424 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.23.ffn_up.weight=blk.23.ffn_up.weight offset:6998556672 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.23.attn_k.weight=blk.23.attn_k.weight offset:7070711808 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.19.ffn_gate.weight=blk.19.ffn_gate.weight offset:5356724224 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.27.ffn_norm.weight=blk.27.ffn_norm.weight offset:8061329408 dims:[3584] number elems:3584 size:14336
Tensor:blk.4.attn_output.weight=blk.4.attn_output.weight offset:1788090368 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.6.attn_q.weight=blk.6.attn_q.weight offset:2080624640 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.6.attn_k.weight=blk.6.attn_k.weight offset:2065012736 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.21.ffn_norm.weight=blk.21.ffn_norm.weight offset:5996318720 dims:[3584] number elems:3584 size:14336
Tensor:blk.6.ffn_up.weight=blk.6.ffn_up.weight offset:3334502400 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.19.attn_q.weight=blk.19.attn_q.weight offset:5516630016 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.9.ffn_down.weight=blk.9.ffn_down.weight offset:3901986816 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.6.ffn_down.weight=blk.6.ffn_down.weight offset:3190224896 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.13.ffn_gate.weight=blk.13.ffn_gate.weight offset:2911352832 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.10.ffn_gate.weight=blk.10.ffn_gate.weight offset:2168377344 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.14.ffn_up.weight=blk.14.ffn_up.weight offset:4221784064 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.21.ffn_down.weight=blk.21.ffn_down.weight offset:5779902464 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.24.ffn_down.weight=blk.24.ffn_down.weight offset:7101937664 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.23.attn_q.weight=blk.23.attn_q.weight offset:7086323712 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.27.ffn_down.weight=blk.27.ffn_down.weight offset:7844913152 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.16.ffn_gate.weight=blk.16.ffn_gate.weight offset:4613748736 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.attn_norm.weight=blk.5.attn_norm.weight offset:1817352192 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_norm.weight=blk.1.attn_norm.weight offset:826718208 dims:[3584] number elems:3584 size:14336
Tensor:blk.4.attn_v.weight=blk.4.attn_v.weight offset:1815402496 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.18.ffn_down.weight=blk.18.ffn_down.weight offset:5036926976 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.6.ffn_norm.weight=blk.6.ffn_norm.weight offset:3406641152 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.ffn_norm.weight=blk.9.ffn_norm.weight offset:4118403072 dims:[3584] number elems:3584 size:14336
Tensor:blk.3.ffn_norm.weight=blk.3.ffn_norm.weight offset:1538465792 dims:[3584] number elems:3584 size:14336
Tensor:blk.22.ffn_gate.weight=blk.22.ffn_gate.weight offset:6027546624 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.15.ffn_down.weight=blk.15.ffn_down.weight offset:4293951488 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.12.ffn_down.weight=blk.12.ffn_down.weight offset:2591555584 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.0.ffn_norm.weight=blk.0.ffn_norm.weight offset:795490304 dims:[3584] number elems:3584 size:14336
Tensor:blk.1.attn_k.weight=blk.1.attn_k.weight offset:1043165184 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.1.attn_q.weight=blk.1.attn_q.weight offset:1058777088 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.7.ffn_gate.weight=blk.7.ffn_gate.weight offset:3478808576 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_gate.weight=blk.4.ffn_gate.weight offset:1641846784 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.1.ffn_gate.weight=blk.1.ffn_gate.weight offset:898871296 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.25.attn_v.weight=blk.25.attn_v.weight offset:7595290624 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.attn_v.weight=blk.12.attn_v.weight offset:2837250048 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.ffn_norm.weight=blk.12.ffn_norm.weight offset:2807971840 dims:[3584] number elems:3584 size:14336
Tensor:blk.15.ffn_norm.weight=blk.15.ffn_norm.weight offset:4510367744 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.ffn_norm.weight=blk.18.ffn_norm.weight offset:5253343232 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.ffn_gate.weight=blk.25.ffn_gate.weight offset:7421734912 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.9.attn_norm.weight=blk.9.attn_norm.weight offset:3901972480 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.attn_k.weight=blk.13.attn_k.weight offset:3055646720 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.attn_k.weight=blk.0.attn_k.weight offset:795506688 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.24.attn_norm.weight=blk.24.attn_norm.weight offset:7101923328 dims:[3584] number elems:3584 size:14336
Tensor:blk.24.attn_q.weight=blk.24.attn_q.weight offset:7333982208 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.3.ffn_down.weight=blk.3.ffn_down.weight offset:1322049536 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.18.attn_q.weight=blk.18.attn_q.weight offset:5268971520 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.0.ffn_down.weight=blk.0.ffn_down.weight offset:579074048 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.11.attn_v.weight=blk.11.attn_v.weight offset:2589591552 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.26.attn_v.weight=blk.26.attn_v.weight offset:7842949120 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.0.ffn_up.weight=blk.0.ffn_up.weight offset:723351552 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.13.attn_norm.weight=blk.13.attn_norm.weight offset:2839199744 dims:[3584] number elems:3584 size:14336
Tensor:blk.14.attn_k.weight=blk.14.attn_k.weight offset:3158999040 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.ffn_up.weight=blk.15.ffn_up.weight offset:4438228992 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.11.attn_k.bias=blk.11.attn_k.bias offset:2560327680 dims:[512] number elems:512 size:2048
Tensor:blk.15.attn_k.bias=blk.15.attn_k.bias offset:4510382080 dims:[512] number elems:512 size:2048
Tensor:blk.17.attn_v.bias=blk.17.attn_v.bias offset:5034960896 dims:[512] number elems:512 size:2048
Tensor:blk.4.attn_norm.weight=blk.4.attn_norm.weight offset:1569693696 dims:[3584] number elems:3584 size:14336
Tensor:blk.10.attn_norm.weight=blk.10.attn_norm.weight offset:2096224256 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.attn_v.bias=blk.19.attn_v.bias offset:5530277888 dims:[512] number elems:512 size:2048
Tensor:blk.6.attn_output.weight=blk.6.attn_output.weight offset:2066962432 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.10.attn_k.weight=blk.10.attn_k.weight offset:2312671232 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.17.attn_k.bias=blk.17.attn_k.bias offset:5005699072 dims:[512] number elems:512 size:2048
Tensor:blk.10.attn_v.weight=blk.10.attn_v.weight offset:2341933056 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.2.attn_output.weight=blk.2.attn_output.weight offset:1292773376 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.20.attn_norm.weight=blk.20.attn_norm.weight offset:5532229632 dims:[3584] number elems:3584 size:14336
Tensor:blk.8.ffn_norm.weight=blk.8.ffn_norm.weight offset:3870744576 dims:[3584] number elems:3584 size:14336
Tensor:blk.2.ffn_norm.weight=blk.2.ffn_norm.weight offset:1290807296 dims:[3584] number elems:3584 size:14336
Tensor:blk.5.ffn_norm.weight=blk.5.ffn_norm.weight offset:2033782784 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.attn_k.bias=blk.13.attn_k.bias offset:3055644672 dims:[512] number elems:512 size:2048
Tensor:blk.22.attn_k.weight=blk.22.attn_k.weight offset:6171826176 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.12.attn_norm.weight=blk.12.attn_norm.weight offset:2591541248 dims:[3584] number elems:3584 size:14336
Tensor:blk.7.ffn_up.weight=blk.7.ffn_up.weight offset:3550947328 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.27.attn_k.weight=blk.27.attn_k.weight offset:8061345792 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.8.attn_v.bias=blk.8.attn_v.bias offset:3900020736 dims:[512] number elems:512 size:2048
Tensor:blk.20.attn_v.bias=blk.20.attn_v.bias offset:5777936384 dims:[512] number elems:512 size:2048
Tensor:blk.11.attn_v.bias=blk.11.attn_v.bias offset:2589589504 dims:[512] number elems:512 size:2048
Tensor:blk.5.attn_q.weight=blk.5.attn_q.weight offset:2049411072 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.12.attn_k.weight=blk.12.attn_k.weight offset:2807988224 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.13.attn_v.bias=blk.13.attn_v.bias offset:3084906496 dims:[512] number elems:512 size:2048
Tensor:blk.22.attn_q.weight=blk.22.attn_q.weight offset:6187438080 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_v.bias=blk.22.attn_v.bias offset:6201085952 dims:[512] number elems:512 size:2048
Tensor:blk.24.attn_v.bias=blk.24.attn_v.bias offset:7347630080 dims:[512] number elems:512 size:2048
Tensor:blk.24.attn_v.weight=blk.24.attn_v.weight offset:7347632128 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.19.attn_k.bias=blk.19.attn_k.bias offset:5501016064 dims:[512] number elems:512 size:2048
Tensor:blk.8.attn_v.weight=blk.8.attn_v.weight offset:3900022784 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.2.ffn_up.weight=blk.2.ffn_up.weight offset:1218668544 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.26.attn_v.bias=blk.26.attn_v.bias offset:7842947072 dims:[512] number elems:512 size:2048
Tensor:blk.20.attn_v.weight=blk.20.attn_v.weight offset:5777938432 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.attn_v.bias=blk.15.attn_v.bias offset:4539643904 dims:[512] number elems:512 size:2048
Tensor:blk.1.ffn_down.weight=blk.1.ffn_down.weight offset:826732544 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.8.attn_q.weight=blk.8.attn_q.weight offset:3886372864 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.15.attn_norm.weight=blk.15.attn_norm.weight offset:4293937152 dims:[3584] number elems:3584 size:14336
Tensor:blk.9.ffn_gate.weight=blk.9.ffn_gate.weight offset:3974125568 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_down.weight=blk.4.ffn_down.weight offset:1569708032 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.7.ffn_down.weight=blk.7.ffn_down.weight offset:3406669824 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:token_embd.weight=token_embd.weight offset:0 dims:[3584, 152064] number elems:544997376 size:579059712
Tensor:blk.3.attn_output.weight=blk.3.attn_output.weight offset:1540431872 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.26.attn_k.weight=blk.26.attn_k.weight offset:7813687296 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.5.attn_output.weight=blk.5.attn_output.weight offset:2035748864 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.9.attn_output.weight=blk.9.attn_output.weight offset:4120369152 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.7.attn_output.weight=blk.7.attn_output.weight offset:3625052160 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.23.attn_norm.weight=blk.23.attn_norm.weight offset:6854264832 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.attn_k.weight=blk.25.attn_k.weight offset:7566028800 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.1.attn_output.weight=blk.1.attn_output.weight offset:1045114880 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.3.ffn_gate.weight=blk.3.ffn_gate.weight offset:1394188288 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.6.ffn_gate.weight=blk.6.ffn_gate.weight offset:3262363648 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.3.ffn_up.weight=blk.3.ffn_up.weight offset:1466327040 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.18.ffn_up.weight=blk.18.ffn_up.weight offset:5181204480 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.9.attn_q.weight=blk.9.attn_q.weight offset:4134031360 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.6.attn_v.bias=blk.6.attn_v.bias offset:2094272512 dims:[512] number elems:512 size:2048
Tensor:blk.21.attn_q.weight=blk.21.attn_q.weight offset:6011947008 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.21.attn_v.weight=blk.21.attn_v.weight offset:6025596928 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.15.ffn_gate.weight=blk.15.ffn_gate.weight offset:4366090240 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.attn_v.bias=blk.4.attn_v.bias offset:1815400448 dims:[512] number elems:512 size:2048
Tensor:blk.9.attn_v.weight=blk.9.attn_v.weight offset:4147681280 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.20.ffn_norm.weight=blk.20.ffn_norm.weight offset:5748660224 dims:[3584] number elems:3584 size:14336
Tensor:blk.12.ffn_gate.weight=blk.12.ffn_gate.weight offset:2663694336 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.7.attn_q.weight=blk.7.attn_q.weight offset:3638714368 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:output_norm.weight=output_norm.weight offset:8092557312 dims:[3584] number elems:3584 size:14336
Tensor:blk.0.attn_v.bias=blk.0.attn_v.bias offset:824766464 dims:[512] number elems:512 size:2048
Tensor:blk.23.ffn_norm.weight=blk.23.ffn_norm.weight offset:7070695424 dims:[3584] number elems:3584 size:14336
Tensor:blk.0.ffn_gate.weight=blk.0.ffn_gate.weight offset:651212800 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.5.ffn_up.weight=blk.5.ffn_up.weight offset:1961644032 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.24.attn_k.weight=blk.24.attn_k.weight offset:7318370304 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.2.attn_v.bias=blk.2.attn_v.bias offset:1320083456 dims:[512] number elems:512 size:2048
Tensor:blk.19.ffn_up.weight=blk.19.ffn_up.weight offset:5428862976 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.4.ffn_up.weight=blk.4.ffn_up.weight offset:1713985536 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.11.ffn_norm.weight=blk.11.ffn_norm.weight offset:2560313344 dims:[3584] number elems:3584 size:14336
Tensor:blk.26.ffn_norm.weight=blk.26.ffn_norm.weight offset:7813670912 dims:[3584] number elems:3584 size:14336
Tensor:blk.13.ffn_down.weight=blk.13.ffn_down.weight offset:2839214080 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.14.ffn_norm.weight=blk.14.ffn_norm.weight offset:4293922816 dims:[3584] number elems:3584 size:14336
Tensor:blk.16.ffn_down.weight=blk.16.ffn_down.weight offset:4541609984 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.7.attn_norm.weight=blk.7.attn_norm.weight offset:3406655488 dims:[3584] number elems:3584 size:14336
Tensor:blk.18.attn_norm.weight=blk.18.attn_norm.weight offset:5036912640 dims:[3584] number elems:3584 size:14336
Tensor:blk.10.ffn_down.weight=blk.10.ffn_down.weight offset:2096238592 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.17.ffn_norm.weight=blk.17.ffn_norm.weight offset:5005684736 dims:[3584] number elems:3584 size:14336
Tensor:blk.19.ffn_down.weight=blk.19.ffn_down.weight offset:5284585472 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.26.attn_norm.weight=blk.26.attn_norm.weight offset:7597240320 dims:[3584] number elems:3584 size:14336
Tensor:blk.25.ffn_down.weight=blk.25.ffn_down.weight offset:7349596160 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.18.ffn_gate.weight=blk.18.ffn_gate.weight offset:5109065728 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.20.attn_q.weight=blk.20.attn_q.weight offset:5764288512 dims:[3584, 3584] number elems:12845056 size:13647872
Tensor:blk.22.attn_v.weight=blk.22.attn_v.weight offset:6201088000 dims:[3584, 512] number elems:1835008 size:1949696
Tensor:blk.27.ffn_gate.weight=blk.27.ffn_gate.weight offset:7917051904 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.22.ffn_down.weight=blk.22.ffn_down.weight offset:6782111744 dims:[18944, 3584] number elems:67895296 size:72138752
Tensor:blk.21.ffn_gate.weight=blk.21.ffn_gate.weight offset:5852041216 dims:[3584, 18944] number elems:67895296 size:72138752
Tensor:blk.24.ffn_gate.weight=blk.24.ffn_gate.weight offset:7174076416 dims:[3584, 18944] number elems:67895296 size:72138752
Load model: 332 millis
>hi
setFloat:0 of size:3584
setFloat:1 of size:3584
setFloat:2 of size:3584
setFloat:...SNIP
Exception in thread "main" java.lang.IndexOutOfBoundsException: Out of bound access on segment MemorySegment{ kind: mapped, address: 0x217e96b3f20, byteSize: 2048 }; new offset = 2048; new length = 4
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.outOfBoundException(AbstractMemorySegmentImpl.java:433)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.apply(AbstractMemorySegmentImpl.java:414)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.apply(AbstractMemorySegmentImpl.java:70)
at java.base/jdk.internal.util.Preconditions.outOfBounds(Preconditions.java:98)
at java.base/jdk.internal.util.Preconditions.outOfBoundsCheckIndex(Preconditions.java:124)
at java.base/jdk.internal.util.Preconditions.checkIndex(Preconditions.java:448)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.checkBounds(AbstractMemorySegmentImpl.java:403)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.checkAccess(AbstractMemorySegmentImpl.java:357)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.checkEnclosingLayout(AbstractMemorySegmentImpl.java:362)
at java.base/java.lang.invoke.SegmentVarHandle.checkSegment(SegmentVarHandle.java:92)
at java.base/java.lang.invoke.VarHandleSegmentAsFloats.get(VarHandleSegmentAsFloats.java:59)
at java.base/java.lang.invoke.VarHandleSegmentAsFloats.get(VarHandleSegmentAsFloats.java:53)
at java.base/jdk.internal.foreign.AbstractMemorySegmentImpl.get(AbstractMemorySegmentImpl.java:754)
at com.llama4j.FloatTensor.readFloat(Llama3.java:2174)
at com.llama4j.F32FloatTensor.getFloat(Llama3.java:2946)
at com.llama4j.FloatTensor.lambda$addInPlace$0(Llama3.java:2320)
at com.llama4j.FloatTensor.mapWithIndexInPlace(Llama3.java:2314)
at com.llama4j.FloatTensor.addInPlace(Llama3.java:2320)
at com.llama4j.FloatTensor.addInPlace(Llama3.java:2324)
at com.llama4j.Llama.forwardQwen(Llama3.java:1392)
at com.llama4j.Llama.generateTokensQwen(Llama3.java:1596)
at com.llama4j.Llama3.runInteractive(Llama3.java:149)
at com.llama4j.Llama3.main(Llama3.java:338)
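NOTE ON THE BOUNDS ARITHMETIC (a reading of the trace above, not a confirmed diagnosis):
The failing segment has byteSize: 2048, which matches the [512]-element bias tensors in the tensor list (size:2048, 4 bytes per float). The preceding setFloat lines report a size of 3584. A read at byte offset 2048 with length 4 is therefore float index 512, one element past the end of a 512-float segment. This would be consistent with addInPlace iterating over the indices of a 3584-element tensor while reading from a 512-element one, but the report does not establish that. The sketch below reproduces only the arithmetic; the class and method names (SegmentBoundsSketch, floatCapacity, readThrows) are hypothetical and not part of Llama3.java.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

// Hypothetical illustration only; not taken from Llama3.java.
public class SegmentBoundsSketch {

    // Number of 4-byte floats that fit in a segment of the given byte size.
    static long floatCapacity(long byteSize) {
        return byteSize / ValueLayout.JAVA_FLOAT.byteSize();
    }

    // Allocates a native segment of byteSize bytes and reports whether
    // reading the float at the given element index is rejected as out of bounds.
    static boolean readThrows(long byteSize, long index) {
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment segment =
                    arena.allocate(byteSize, ValueLayout.JAVA_FLOAT.byteAlignment());
            try {
                // Byte offset accessed is index * 4, with an access length of 4.
                segment.getAtIndex(ValueLayout.JAVA_FLOAT, index);
                return false;
            } catch (IndexOutOfBoundsException e) {
                return true;
            }
        }
    }

    public static void main(String[] args) {
        long byteSize = 2048; // byteSize reported in the exception message
        System.out.println("float capacity: " + floatCapacity(byteSize));
        System.out.println("index 511 throws: " + readThrows(byteSize, 511));
        System.out.println("index 512 throws: " + readThrows(byteSize, 512));
    }
}
```

On JDK 21 the java.lang.foreign API is still a preview feature, so this sketch needs --enable-preview there, as in the command line shown under ACTUAL.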
---------- BEGIN SOURCE ----------
///usr/bin/env jbang "$0" "$@" ; exit $?
//JAVA 21+
//PREVIEW
//COMPILE_OPTIONS --add-modules=jdk.incubator.vector
//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0
//MAIN com.llama4j.Llama3
// Practical Llama 3 (and 3.1) inference in a single Java file
// Author: Alfonso² Peterssen
// Based on Andrej Karpathy's llama2.c and minbpe projects
//
// Supports llama.cpp's GGUF format, restricted to Q4_0 and Q8_0 quantized models
// Multi-threaded matrix vector multiplication routines implemented using Java's Vector API
// Simple CLI with --chat and --instruct mode
//
// To run just:
// jbang Llama3.java --help
//
// Remember: Llama models use GPT2 vocabulary while non-Llama models use Llama vocabulary!
// Enjoy!
package com.llama4j;
import jdk.incubator.vector.*;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.reflect.Field;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import java.util.function.LongConsumer;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
public class Llama3 {
// Batch-size used in prompt evaluation.
private static int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
private final static boolean DEBUG = false;
static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
Sampler sampler;
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
sampler = Sampler.ARGMAX;
} else {
// we sample from this distribution to get the next token
RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed);
Sampler innerSampler;
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
innerSampler = new CategoricalSampler(rng);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
innerSampler = new ToppSampler(vocabularySize, topp, rng);
}
sampler = logits -> {
// apply the temperature to the logits
logits.divideInPlace(0, logits.size(), temperature);
// apply softmax to the logits to get the probabilities for next token
logits.softmaxInPlace(0, logits.size());
return innerSampler.sampleToken(logits);
};
}
return sampler;
}
static void runInteractive(Llama model, Sampler sampler, Options options) {
Llama.State state = null;
List<Integer> conversationTokens = new ArrayList<>();
ChatFormatInterface chatFormat;
// The chat format appears to depend solely on the individual model, so the model loader extracts a name from the metadata key general.name
if(ModelLoader.name.equals("mistral")) {
chatFormat = new MistralChatFormat(model.tokenizer());
} else {
if(ModelLoader.name.equals("llama")) {
chatFormat = new ChatFormat(model.tokenizer());
} else {
if(ModelLoader.name.equals("qwen")) {
BATCH_SIZE = 1;
chatFormat = new ChatMLFormat(model.tokenizer());
} else {
throw new IllegalArgumentException("expected metadata general.name containing mistral, llama, or qwen but found "+ModelLoader.name);
}
}
}
conversationTokens.add(chatFormat.getBeginOfText());
if (options.systemPrompt() != null) {
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
int startPosition = 0;
Scanner in = new Scanner(System.in);
loop: while (true) {
boolean storeDb = true;
System.out.print("> ");
System.out.flush();
String userText = in.nextLine();
switch (userText) {
case "/quit":
case "/exit": break loop;
case "/context": {
System.out.printf("%d out of %d context tokens used (%d tokens remaining)%n",
conversationTokens.size(),
options.maxTokens(),
options.maxTokens() - conversationTokens.size());
continue;
}
}
if (state == null) {
state = model.createNewState(BATCH_SIZE, chatFormat.getBeginOfText());
}
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
Set<Integer> stopTokens = chatFormat.getStopTokens();
List<Integer> responseTokens;
if(ModelLoader.name.equals("qwen")) {
responseTokens = Llama.generateTokensQwen(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
if (options.stream()) {
int tokenType = model.tokenizer().getTokenType(token);
if (tokenType == 1 || tokenType == 6) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
} else {
responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
}
// Include stop token in the prompt history, but not in the response displayed to the user.
conversationTokens.addAll(responseTokens);
startPosition = conversationTokens.size();
Integer stopToken = null;
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
stopToken = responseTokens.getLast();
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}
// Check for context exhaustion in both streaming and non-streaming modes.
if (stopToken == null) {
System.err.println("Ran out of context length...");
break;
}
}
}
static void runInstructOnce(Llama model, Sampler sampler, Options options) {
ChatFormatInterface chatFormat;
// The chat format appears to depend solely on the individual model, so the model loader extracts a name from the metadata key general.name
if(ModelLoader.name.equals("mistral")) {
chatFormat = new MistralChatFormat(model.tokenizer());
} else {
if(ModelLoader.name.equals("llama")) {
chatFormat = new ChatFormat(model.tokenizer());
} else {
if(ModelLoader.name.equals("qwen")) {
chatFormat = new ChatMLFormat(model.tokenizer());
} else {
throw new IllegalArgumentException("expected metadata general.name containing mistral, llama, or qwen but found "+ModelLoader.name);
}
}
}
Llama.State state = model.createNewState(BATCH_SIZE, chatFormat.getBeginOfText());
List<Integer> promptTokens = new ArrayList<>();
promptTokens.add(chatFormat.getBeginOfText());
if (options.systemPrompt() != null) {
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
Set<Integer> stopTokens = chatFormat.getStopTokens();
List<Integer> responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}
}
record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
String localNode, String remoteNode, int remotePort) {
static final int DEFAULT_MAX_TOKENS = 512;
Options {
require(modelPath != null, "Missing argument: --model <path> is required");
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
static void require(boolean condition, String messageFormat, Object... args) {
if (!condition) {
System.out.println("ERROR " + messageFormat.formatted(args));
System.out.println();
printUsage(System.out);
System.exit(-1);
}
}
static void printUsage(PrintStream out) {
out.println("Usage: jbang Llama3.java [options]");
out.println();
out.println("Options:");
out.println(" --model, -m <path> required, path to .gguf file");
out.println(" --interactive, --chat, -i run in chat mode");
out.println(" --instruct run in instruct (once) mode, default mode");
out.println(" --prompt, -p <string> input prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt");
out.println(" --temperature, --temp <float> temperature in [0,inf], default 0.1");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --seed <long> random seed, default System.nanoTime()");
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(" --echo <boolean> print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
out.println(" --localNode <string> local database client node");
out.println(" --remoteNode <string> remote database client node");
out.println(" --remotePort <int> remote database port");
out.println();
out.println("Examples:");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Tell me a joke\"");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was Marie Curie?\"");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Answer concisely\" --chat");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --chat");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false");
}
static Options parseOptions(String[] args) {
String prompt = null;
String systemPrompt = null;
float temperature = 0.1f;
float topp = 0.95f;
Path modelPath = null;
long seed = System.nanoTime();
// Keep max context length small for low-memory devices.
int maxTokens = DEFAULT_MAX_TOKENS;
boolean interactive = false;
boolean stream = true;
boolean echo = false;
String localNode = null;
String remoteNode = null;
int remotePort = 0;
for (int i = 0; i < args.length; i++) {
String optionName = args[i];
require(optionName.startsWith("-"), "Invalid option %s", optionName);
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
}
default -> {
String nextArg;
if (optionName.contains("=")) {
String[] parts = optionName.split("=", 2);
optionName = parts[0];
nextArg = parts[1];
} else {
require(i + 1 < args.length, "Missing argument for option %s", optionName);
nextArg = args[i + 1];
i += 1; // skip arg
}
switch (optionName) {
case "--prompt", "-p" -> prompt = nextArg;
case "--system-prompt", "-sp" -> systemPrompt = nextArg;
case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg);
case "--top-p" -> topp = Float.parseFloat(nextArg);
case "--model", "-m" -> modelPath = Paths.get(nextArg);
case "--seed", "-s" -> seed = Long.parseLong(nextArg);
case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
case "--stream" -> stream = Boolean.parseBoolean(nextArg);
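case "--echo" -> echo = Boolean.parseBoolean(nextArg);
// Options advertised by printUsage but otherwise rejected as unknown.
case "--localNode" -> localNode = nextArg;
case "--remoteNode" -> remoteNode = nextArg;
case "--remotePort" -> remotePort = Integer.parseInt(nextArg);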
default -> require(false, "Unknown option: %s", optionName);
}
}
}
}
return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo, localNode, remoteNode, remotePort);
}
}
public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
Llama model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
if(model == null)
model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
if (options.interactive()) {
runInteractive(model, sampler, options);
} else {
runInstructOnce(model, sampler, options);
}
}
}
final class GGUF {
private static final int GGUF_MAGIC = 0x46554747;
private static final int DEFAULT_ALIGNMENT = 32; // must be a power of 2
private static final List<Integer> SUPPORTED_GGUF_VERSIONS = List.of(2, 3);
private int magic;
private int version;
private int tensorCount; // uint64_t
private int alignment;
private int metadata_kv_count; // uint64_t
private Map<String, Object> metadata;
public Map<String, GGUFTensorInfo> getTensorInfos() {
return tensorInfos;
}
private Map<String, GGUFTensorInfo> tensorInfos;
private long tensorDataOffset;
public long getTensorDataOffset() {
return tensorDataOffset;
}
public Map<String, Object> getMetadata() {
return metadata;
}
private final ByteBuffer BB_1 = ByteBuffer.allocate(Byte.BYTES).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer BB_2 = ByteBuffer.allocate(Short.BYTES).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer BB_4 = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer BB_8 = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN);
public static GGUF loadModel(Path modelPath) throws IOException {
try (FileChannel fileChannel = FileChannel.open(modelPath);
var ignored = Timer.log("Parse " + modelPath)) {
GGUF gguf = new GGUF();
gguf.loadModelImpl(fileChannel);
return gguf;
}
}
enum MetadataValueType {
// The value is an 8-bit unsigned integer.
UINT8(1),
// The value is an 8-bit signed integer.
INT8(1),
// The value is a 16-bit unsigned little-endian integer.
UINT16(2),
// The value is a 16-bit signed little-endian integer.
INT16(2),
// The value is a 32-bit unsigned little-endian integer.
UINT32(4),
// The value is a 32-bit signed little-endian integer.
INT32(4),
// The value is a 32-bit IEEE754 floating point number.
FLOAT32(4),
// The value is a boolean.
// 1-byte value where 0 is false and 1 is true.
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
BOOL(1),
// The value is a UTF-8 non-null-terminated string, with length prepended.
STRING(-8),
// The value is an array of other values, with the length and type prepended.
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
ARRAY(-8),
// The value is a 64-bit unsigned little-endian integer.
UINT64(8),
// The value is a 64-bit signed little-endian integer.
INT64(8),
// The value is a 64-bit IEEE754 floating point number.
FLOAT64(8);
private final int byteSize;
MetadataValueType(int byteSize) {
this.byteSize = byteSize;
}
private static final MetadataValueType[] VALUES = values();
public static MetadataValueType fromIndex(int index) {
return VALUES[index];
}
public int byteSize() {
return byteSize;
}
}
private void loadModelImpl(FileChannel fileChannel) throws IOException {
// The header of the file.
readHeader(fileChannel); // gguf_header_t header;
// Tensor infos, which can be used to locate the tensor data.
// gguf_tensor_info_t tensor_infos[header.tensor_count];
this.tensorInfos = HashMap.newHashMap(tensorCount);
for (int i = 0; i < tensorCount; ++i) {
GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
assert !tensorInfos.containsKey(ti.name);
tensorInfos.put(ti.name, ti);
}
// Padding to the nearest multiple of `ALIGNMENT`.
// uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)];
// Round up to the next multiple of the alignment; add no padding when the position is already aligned.
long _padding = (getAlignment() - (fileChannel.position() % getAlignment())) % getAlignment();
fileChannel.position(fileChannel.position() + _padding);
// Tensor data.
//
// This is arbitrary binary data corresponding to the weights of the model. This data should be close
// or identical to the data in the original model file, but may be different due to quantization or
// other optimizations for inference. Any such deviations should be recorded in the metadata or as
// part of the architecture definition.
//
// Each tensor's data must be stored within this array, and located through its `tensor_infos` entry.
// The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors
// should be padded to `ALIGNMENT` bytes.
// uint8_t tensor_data[];
this.tensorDataOffset = fileChannel.position();
}
public static Map<String, GGMLTensorEntry> loadTensors(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
Arena arena = Arena.ofAuto();
MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena);
Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
GGUFTensorInfo ti = entry.getValue();
int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
System.out.println("Tensor:"+entry.getKey()+"="+ti.name+" offset:"+ti.offset+" dims:"+Arrays.toString(ti.dimensions)+" number elems:"+numberOfElements+" size:"+sizeInBytes);
MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes);
tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
}
return tensorEntries;
}
public record GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) {
}
private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
int ggmlTypeId = readInt(fileChannel); // ggml_type type;
return GGMLType.fromId(ggmlTypeId);
}
private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
// The name of the tensor. It is a standard GGUF string, with the caveat that
// it must be at most 64 bytes long.
String name = readString(fileChannel); // gguf_string_t name;
assert name.length() <= 64;
// The number of dimensions in the tensor.
// Currently at most 4, but this may change in the future.
int n_dimensions = readInt(fileChannel); // uint32_t n_dimensions;
assert n_dimensions <= 4;
// The dimensions of the tensor.
int[] dimensions = new int[n_dimensions]; // uint64_t dimensions[n_dimensions];
for (int i = 0; i < n_dimensions; ++i) {
dimensions[i] = Math.toIntExact(readLong(fileChannel));
}
// The type of the tensor.
GGMLType ggmlType = readGGMLType(fileChannel); // ggml_type type;
// The offset of the tensor's data in this file in bytes.
// This offset is relative to `tensor_data`, not to the start
// of the file, to make it easier for writers to write the file.
// Readers should consider exposing this offset relative to the
// file to make it easier to read the data.
// Must be a multiple of `ALIGNMENT`.
long offset = readLong(fileChannel); // uint64_t offset;
assert offset % getAlignment() == 0;
return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset);
}
private String readString(FileChannel fileChannel) throws IOException {
// A string in GGUF.
// The length of the string, in bytes.
int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len;
// The string as a UTF-8 non-null-terminated string.
byte[] bytes = new byte[len]; // char string[len];
int bytesRead = fileChannel.read(ByteBuffer.wrap(bytes));
assert len == bytesRead;
return new String(bytes, StandardCharsets.UTF_8);
}
private Pair<String, Object> readKeyValuePair(FileChannel fileChannel) throws IOException {
// The key of the metadata. It is a standard GGUF string, with the following caveats:
// - It must be a valid ASCII string.
// - It must be a hierarchical key, where each segment is `lower_snake_case` and separated by a `.`.
// - It must be at most 2^16-1/65535 bytes long.
// Any keys that do not follow these rules are invalid.
String key = readString(fileChannel); // gguf_string_t key;
assert key.length() < (1 << 16);
assert key.codePoints().allMatch(cp -> ('a' <= cp && cp <= 'z') || ('0' <= cp && cp <= '9') || cp == '_' || cp == '.');
Object value = readMetadataValue(fileChannel);
return new Pair<>(key, value);
}
private Object readMetadataValue(FileChannel fileChannel) throws IOException {
// The type of the value.
// Must be one of the `gguf_metadata_value_type` values.
MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type value_type;
// The value.
return readMetadataValueOfType(value_type, fileChannel); // gguf_metadata_value_t value;
}
void readHeader(FileChannel fileChannel) throws IOException {
// Magic number to announce that this is a GGUF file.
// Must be `GGUF` at the byte level: `0x47` `0x47` `0x55` `0x46`.
// Your executor might use little-endian byte order, so it may be checking
// for 0x46554747 and letting the endianness cancel out.
// Consider being *very* explicit about the byte order here.
this.magic = readInt(fileChannel); // uint32_t magic;
if (magic != GGUF_MAGIC) {
throw new IllegalArgumentException("unsupported header.magic " + magic);
}
// The version of the format implemented.
// Must be `3` for version described in this spec.
//
// This version should only be increased for structural changes to the format.
// Changes that do not affect the structure of the file should instead update the metadata
// to signify the change.
this.version = readInt(fileChannel); // uint32_t version;
if (!SUPPORTED_GGUF_VERSIONS.contains(version)) {
throw new IllegalArgumentException("unsupported header.version " + version);
}
// The number of tensors in the file.
// This is explicit, instead of being included in the metadata, to ensure it is always present
// for loading the tensors.
this.tensorCount = Math.toIntExact(readLong(fileChannel)); // uint64_t tensor_count;
// The number of metadata key-value pairs.
this.metadata_kv_count = Math.toIntExact(readLong(fileChannel)); // uint64_t metadata_kv_count;
// The metadata key-value pairs.
// gguf_metadata_kv_t metadata_kv[metadata_kv_count];
this.metadata = HashMap.newHashMap(metadata_kv_count);
for (int i = 0; i < metadata_kv_count; ++i) {
Pair<String, Object> keyValue = readKeyValuePair(fileChannel);
assert !metadata.containsKey(keyValue.first());
metadata.put(keyValue.first(), keyValue.second());
}
}
private Object readArray(FileChannel fileChannel) throws IOException {
// Any value type is valid, including arrays.
MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type type;
// Number of elements, not bytes
int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len;
// The array of values.
// gguf_metadata_value_t array[len];
switch (value_type) {
case UINT8, INT8 -> {
byte[] bytes = new byte[len];
for (int i = 0; i < len; ++i) {
bytes[i] = readByte(fileChannel);
}
return bytes;
}
case UINT16, INT16 -> {
short[] shorts = new short[len];
for (int i = 0; i < len; ++i) {
shorts[i] = readShort(fileChannel);
}
return shorts;
}
case UINT32, INT32 -> {
int[] ints = new int[len];
for (int i = 0; i < len; ++i) {
ints[i] = readInt(fileChannel);
}
return ints;
}
case FLOAT32 -> {
float[] floats = new float[len];
for (int i = 0; i < len; ++i) {
floats[i] = readFloat(fileChannel);
}
return floats;
}
case BOOL -> {
boolean[] booleans = new boolean[len];
for (int i = 0; i < len; ++i) {
booleans[i] = readBoolean(fileChannel);
}
return booleans;
}
case STRING -> {
String[] strings = new String[len];
for (int i = 0; i < len; ++i) {
strings[i] = readString(fileChannel);
}
return strings;
}
case ARRAY -> {
Object[] arrays = new Object[len];
for (int i = 0; i < len; ++i) {
arrays[i] = readArray(fileChannel);
}
return arrays;
}
default -> throw new UnsupportedOperationException("read array of " + value_type);
}
}
private Object readMetadataValueOfType(MetadataValueType valueType, FileChannel fileChannel) throws IOException {
return switch (valueType) {
case UINT8, INT8 -> readByte(fileChannel);
case UINT16, INT16 -> readShort(fileChannel);
case UINT32, INT32 -> readInt(fileChannel);
case FLOAT32 -> readFloat(fileChannel);
case UINT64, INT64 -> readLong(fileChannel);
case FLOAT64 -> readDouble(fileChannel);
case BOOL -> readBoolean(fileChannel);
case STRING -> readString(fileChannel);
case ARRAY -> readArray(fileChannel);
};
}
private byte readByte(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_1);
assert bytesRead == 1;
return BB_1.clear().get(0);
}
private boolean readBoolean(FileChannel fileChannel) throws IOException {
return readByte(fileChannel) != 0;
}
private short readShort(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_2);
assert bytesRead == 2;
return BB_2.clear().getShort(0);
}
private int readInt(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_4);
assert bytesRead == 4;
return BB_4.clear().getInt(0);
}
private long readLong(FileChannel fileChannel) throws IOException {
int bytesRead = fileChannel.read(BB_8);
assert bytesRead == 8;
return BB_8.clear().getLong(0);
}
private float readFloat(FileChannel fileChannel) throws IOException {
return Float.intBitsToFloat(readInt(fileChannel));
}
private double readDouble(FileChannel fileChannel) throws IOException {
return Double.longBitsToDouble(readLong(fileChannel));
}
private MetadataValueType readMetadataValueType(FileChannel fileChannel) throws IOException {
int index = readInt(fileChannel);
return MetadataValueType.fromIndex(index);
}
public int getAlignment() {
if (alignment != 0) {
return alignment;
}
alignment = (int) metadata.getOrDefault("general.alignment", DEFAULT_ALIGNMENT);
assert Integer.bitCount(alignment) == 1 : "alignment must be a power of two";
return alignment;
}
}
interface Timer extends AutoCloseable {
@Override
void close(); // no Exception
static Timer log(String label) {
return log(label, TimeUnit.MILLISECONDS);
}
static Timer log(String label, TimeUnit timeUnit) {
return new Timer() {
final long startNanos = System.nanoTime();
@Override
public void close() {
long elapsedNanos = System.nanoTime() - startNanos;
System.err.println(label + ": "
+ timeUnit.convert(elapsedNanos, TimeUnit.NANOSECONDS) + " "
+ timeUnit.toChronoUnit().name().toLowerCase());
}
};
}
}
/**
* Loads a model: reads the GGUF metadata, loads the vocabulary, creates the tokenizer and the configuration,
* and, if loadWeights is set, loads the tensors and weights; then creates a Llama from the config, tokenizer, and weights.
*/
final class ModelLoader {
static final String TOKENIZER_GPT2_MODEL = "gpt2"; // Llama3 uses gpt2!
static final String TOKENIZER_LLAMA_MODEL = "llama"; // non Llama uses llama!
public static String model = "gpt2"; // default for Llama models!
public static String name = null; // Derived solely from the model's general.name; each model family seems to have its own chat format, independent of the tokenizer model
private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
private static Vocabulary loadVocabulary(Map<String, Object> metadata) {
model = (String) metadata.get("tokenizer.ggml.model");
name = (String) metadata.get("general.name");
if(name.toLowerCase().contains("llama")) { // Meta Llama etc. etc.
name = "llama";
} else if(name.toLowerCase().contains("mistral")) { // models--mistralai etc. etc.
name = "mistral";
} else if(name.toLowerCase().contains("qwen")) {
name = "qwen";
}
String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
if(TOKENIZER_LLAMA_MODEL.equals(model)) {
float[] scores = (float[]) metadata.get("tokenizer.ggml.scores");
return new Vocabulary(tokens, scores);
} else {
if(TOKENIZER_GPT2_MODEL.equals(model)) {
return new Vocabulary(tokens, null);
} else {
throw new IllegalArgumentException("expected " + TOKENIZER_GPT2_MODEL + " or "+ TOKENIZER_LLAMA_MODEL+ " but found " + model);
}
}
}
public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
GGUF gguf = GGUF.loadModel(ggufPath);
FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
return loadModel(fileChannel, gguf, contextLength, loadWeights);
}
public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException {
try (var ignored = Timer.log("Load model")) {
Map<String, Object> metadata = gguf.getMetadata();
System.out.println("GGUF metadata:\r\n"+metadata);
Vocabulary vocabulary = loadVocabulary(metadata);
TokenizerInterface tokenizer;
Llama.Configuration config;
Llama.Weights weights = null;
String arch = (String) metadata.get("general.architecture");
if(ModelLoader.name.equals("mistral")) {
tokenizer = createLlamaTokenizer(metadata, vocabulary);
config = createConfig(arch, metadata, vocabulary, contextLength);
if (loadWeights) {
// loadTensors corresponds to getTensorEntries in old version
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadLlamaWeights(tensorEntries, config);
}
} else {
if(ModelLoader.name.equals("llama")) {
tokenizer = createGPT2Tokenizer(metadata, vocabulary);
config = createConfig(arch, metadata, vocabulary, contextLength);
if (loadWeights) {
// loadTensors corresponds to getTensorEntries in old version
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadGPT2Weights(tensorEntries, config);
}
} else {
if(ModelLoader.name.equals("qwen")) {
tokenizer = createQwen2Tokenizer(metadata, vocabulary);
config = createConfig(arch, metadata, vocabulary, contextLength);
if (loadWeights) {
// loadTensors corresponds to getTensorEntries in old version
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadQwenWeights(tensorEntries, config);
}
} else {
throw new IllegalArgumentException("expected metadata general.name containing mistral, llama, or qwen but found "+ModelLoader.name);
}
}
}
return new Llama(config, tokenizer, weights);
}
}
static Llama.Configuration createConfig(String arch, Map<String, Object> metadata, Vocabulary vocabulary, int contextLength) {
Llama.Configuration config = new Llama.Configuration(
(int) metadata.get(arch+".embedding_length"),
(int) metadata.get(arch+".feed_forward_length"),
(int) metadata.get(arch+".block_count"),
(int) metadata.get(arch+".attention.head_count"),
metadata.containsKey(arch+".attention.head_count_kv")
? (int) metadata.get(arch+".attention.head_count_kv")
: (int) metadata.get(arch+".attention.head_count"),
vocabulary.size(),
(int) metadata.get(arch+".context_length"),
(float) metadata.getOrDefault(arch+".attention.layer_norm_rms_epsilon", 1e-5f),
(float) metadata.getOrDefault(arch+".rope.freq_base", 10000f)
).withContextLength(contextLength);
return config;
}
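/**
* Loads the weights for models that use the GPT2 tokenizer (the "llama" branch in loadModel).
* Called from AOT.tryUsePreLoaded and ModelLoader.loadModel.
* @param tensorEntries tensor entries keyed by tensor name, as returned by GGUF.loadTensors
* @param config model configuration supplying the layer count, context length, head size, and RoPE theta
* @return the populated Llama.Weights
*/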
static Llama.Weights loadGPT2Weights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
float scaleFactor = 8;
float loFreqFactor = 1;
float hiFreqFactor = 3;
int oldContextLength = 8192;
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta,
ropeScaling, scaleFactor, loFreqFactor, hiFreqFactor, oldContextLength);
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
Llama.Weights qw = new Llama.Weights(
loadQuantized(tokenEmbeddings),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
toFloatBuffer(tensorEntries.get("output_norm.weight")),
FloatBuffer.wrap(ropeFreqsReal),
FloatBuffer.wrap(ropeFreqsImag),
// If "output.weight" is not present then the embedding weights are tied/shared with the decoder.
// This is commonly referred to as "tie word embeddings".
loadQuantized(tensorEntries.getOrDefault("output.weight", tokenEmbeddings))
);
return qw;
}
static Llama.Weights loadLlamaWeights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta);
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
Llama.Weights qw = new Llama.Weights(
loadQuantized(tensorEntries.get("token_embd.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
toFloatBuffer(tensorEntries.get("output_norm.weight")),
FloatBuffer.wrap(ropeFreqsReal),
FloatBuffer.wrap(ropeFreqsImag),
loadQuantized(tensorEntries.get("output.weight"))
);
return qw;
}
static Llama.Weights loadQwenWeights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta);
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
Llama.Weights qw = new Llama.Weights(
loadQuantized(tensorEntries.get("token_embd.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
toFloatBuffer(tensorEntries.get("output_norm.weight")),
FloatBuffer.wrap(ropeFreqsReal),
FloatBuffer.wrap(ropeFreqsImag),
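// Fall back to the token embeddings when "output.weight" is absent (tied word embeddings), as loadGPT2Weights does.
loadQuantized(tensorEntries.getOrDefault("output.weight", tensorEntries.get("token_embd.weight")))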
);
return qw;
}
private final static String QWEN2_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
private static Tokenizer createQwen2Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
.map(line -> line.split(" "))
.map(parts ->
new Pair<>(
vocabulary.getIndex(parts[0]).orElseThrow(),
vocabulary.getIndex(parts[1]).orElseThrow())
).toList();
int allTokens = vocabulary.size();
int baseTokens = vocabulary.getIndex("<|endoftext|>").orElseThrow(); // assume all tokens after the base ones are special.
int reservedSpecialTokens = allTokens - baseTokens;
List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
Map<String, Integer> specialTokens =
IntStream.range(0, specialTokensList.size())
.boxed()
.collect(Collectors.toMap(
i -> specialTokensList.get(i),
i -> baseTokens + i)
);
return new Tokenizer(vocabulary, merges, QWEN2_PATTERN, specialTokens, tokenTypes);
}
private static Tokenizer createGPT2Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
.map(line -> line.split(" "))
.map(parts ->
new Pair<>(
vocabulary.getIndex(parts[0]).orElseThrow(),
vocabulary.getIndex(parts[1]).orElseThrow())
).toList();
int allTokens = vocabulary.size();
int baseTokens = 128000; // assume all tokens after the base ones are special.
int reservedSpecialTokens = allTokens - baseTokens;
List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
Map<String, Integer> specialTokens =
IntStream.range(0, specialTokensList.size())
.boxed()
.collect(Collectors.toMap(
i -> specialTokensList.get(i),
i -> baseTokens + i)
);
return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
}
private static MistralTokenizer createLlamaTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList();
// Map each special token's text to its own token id (the ids collected above), not to its position in specialTokensList.
Map<String, Integer> specialTokens =
specialTokensList.stream()
.collect(Collectors.toMap(
t -> vocabulary.get(t),
t -> t)
);
return new MistralTokenizer(vocabulary, null, specialTokens, tokenTypes);
}
public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
GGMLType ggmlType = entry.ggmlType();
return switch (ggmlType) {
case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + ggmlType);
};
}
public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatTensor[] array = new FloatTensor[size];
for (int i = 0; i < size; i++) {
array[i] = loadQuantized(getTensorEntry.apply(i));
}
return array;
}
public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatBuffer[] array = new FloatBuffer[size];
for (int i = 0; i < size; i++) {
array[i] = toFloatBuffer(getTensorEntry.apply(i));
}
return array;
}
public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
GGMLType ggmlType = tensorEntry.ggmlType();
return switch (ggmlType) {
case F32 -> tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
default -> throw new UnsupportedOperationException("Conversion of " + ggmlType + " to FloatBuffer is not supported");
};
}
}
record Llama(Configuration configuration, TokenizerInterface tokenizer, Weights weights) {
public State createNewState(int batchsize, int beginOfText) {
State state = new State(configuration(), batchsize);
state.latestToken = beginOfText; // was tokenizer.getSpecialTokens().get("<|begin_of_text|>");, now we get from ChatFormat.beginOfText() which does the same
return state;
}
public static final class Configuration {
public final int dim; // transformer dimension
public final int hiddenDim; // for ffn layers
public final int numberOfLayers; // number of layers
public final int numberOfHeads; // number of query heads
public final int numberOfKeyValueHeads; // number of key/value heads (can be < query heads because of multiquery)
public final int vocabularySize; // vocabulary size (number of tokens known to the tokenizer)
public final int contextLength; // max sequence length
public final float rmsNormEps;
public final float ropeTheta;
public final int headSize;
Configuration withContextLength(int newContextLength) {
if (newContextLength < 0) {
return this; // no change
}
return new Configuration(this.dim, this.hiddenDim, this.numberOfLayers, this.numberOfHeads, this.numberOfKeyValueHeads, this.vocabularySize, newContextLength, this.rmsNormEps, this.ropeTheta);
}
public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) {
this.dim = dim;
this.hiddenDim = hiddenDim;
this.numberOfLayers = numberOfLayers;
this.numberOfHeads = numberOfHeads;
this.numberOfKeyValueHeads = numberOfKeyValueHeads;
this.vocabularySize = vocabularySize;
this.contextLength = contextLength;
this.rmsNormEps = rmsNormEps;
this.ropeTheta = ropeTheta;
this.headSize = dim / numberOfHeads;
}
}
public static final class Weights {
// token embedding table
public final FloatTensor token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights
// weights for matmuls
public final FloatTensor[] wq; // (layer, n_heads * head_size)
public final FloatTensor[] wk; // (layer, n_kv_heads * head_size)
public final FloatTensor[] wv; // (layer, n_kv_heads * head_size)
public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
// next 3: qwen - Groff from Qwen2.java
public FloatTensor[] q_bias = null; // (layer, dim)
public FloatTensor[] k_bias = null; // (layer, kv_dim)
public FloatTensor[] v_bias = null; // (layer, kv_dim)
public final FloatBuffer[] rms_ffn_weight; // (layer, dim)
// weights for ffn
public final FloatTensor[] w1; // (layer, hidden_dim, dim)
public final FloatTensor[] w2; // (layer, dim, hidden_dim)
public final FloatTensor[] w3; // (layer, hidden_dim, dim)
// final rmsnorm
public final FloatBuffer rms_final_weight; // (dim,)
// freq_cis for RoPE (rotary positional embeddings)
public final FloatBuffer freq_cis_real; // (seq_len, head_size/2)
public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
public final FloatTensor wcls; // (vocab_size, dim)
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls) {
this.token_embedding_table = token_embedding_table;
this.rms_att_weight = rms_att_weight;
this.wq = wq;
this.wk = wk;
this.wv = wv;
this.wo = wo;
this.rms_ffn_weight = rms_ffn_weight;
this.w1 = w1;
this.w2 = w2;
this.w3 = w3;
this.rms_final_weight = rms_final_weight;
this.freq_cis_real = freq_cis_real;
this.freq_cis_imag = freq_cis_imag;
this.wcls = wcls;
}
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatTensor[] q, FloatTensor[] k, FloatTensor[] v, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls) {
this.token_embedding_table = token_embedding_table;
this.rms_att_weight = rms_att_weight;
this.wq = wq;
this.wk = wk;
this.wv = wv;
this.wo = wo;
this.q_bias = q;
this.k_bias = k;
this.v_bias = v;
this.rms_ffn_weight = rms_ffn_weight;
this.w1 = w1;
this.w2 = w2;
this.w3 = w3;
this.rms_final_weight = rms_final_weight;
this.freq_cis_real = freq_cis_real;
this.freq_cis_imag = freq_cis_imag;
this.wcls = wcls;
}
}
public static final class State {
// current wave of activations
public final int batchsize;
public final FloatTensor[] x; // activation at current time step (dim,)
public final FloatTensor[] xb; // same, but inside a residual branch (dim,)
public final FloatTensor[] xb2; // an additional buffer just for convenience (dim,)
public final FloatTensor[] hb; // buffer for hidden dimension in the ffn (hidden_dim,)
public final FloatTensor[] hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
public final FloatTensor[] q; // query (dim,)
public final FloatTensor[] k; // key (dim,)
public final FloatTensor[] v; // value (dim,)
public final FloatTensor[] att; // buffer for scores/attention values (n_heads, seq_len)
public final FloatTensor logits; // output logits
// kv cache
public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)
/** last index in previous block */
int idxPrevBlock;
public int latestToken;
State(Configuration config, int batchsize) {
this.batchsize = batchsize;
this.x = allocate(batchsize, config.dim);
this.xb = allocate(batchsize, config.dim);
this.xb2 = allocate(batchsize, config.dim);
this.hb = allocate(batchsize, config.hiddenDim);
this.hb2 = allocate(batchsize, config.hiddenDim);
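// Keys and values hold only kvDim elements per token (fewer than dim with grouped-query attention).
// Sizing them to kvDim keeps them the same length as the (kv_dim,) k_bias/v_bias tensors added in forwardQwen.
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
this.q = allocate(batchsize, config.dim);
this.k = allocate(batchsize, kvDim);
this.v = allocate(batchsize, kvDim);
this.att = allocate(batchsize, config.numberOfHeads, config.contextLength);
idxPrevBlock = -1;
this.logits = ArrayFloatTensor.allocate(config.vocabularySize);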
this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
}
}
static FloatTensor[] allocate(int numTokens, int... dims) {
return IntStream.range(0, numTokens)
.mapToObj(i -> ArrayFloatTensor.allocate(dims))
.toArray(FloatTensor[]::new);
}
static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
// calculate sum of squares
float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
ss /= size;
ss += rmsNormEps;
ss = (float) (1.0 / Math.sqrt(ss));
// normalize and scale
final float finalss = ss; // for the lambda
out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
}
static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) {
// a few convenience variables
Configuration config = model.configuration();
Weights weights = model.weights();
int dim = config.dim;
int headSize = config.headSize;
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(headSize);
final int nTokens = tokens.length;
// copy the token embedding into x
Parallel.parallelFor(0, nTokens, t ->
weights.token_embedding_table.copyTo(tokens[t] * dim, state.x[t], 0, dim)
);
// forward all the layers
for (int l = 0; l < config.numberOfLayers; l++) {
// attention rmsnorm
// rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps);
final int curLayer = l;
Parallel.parallelFor(0, nTokens, t ->
rmsnorm(state.xb[t], state.x[t], weights.rms_att_weight[curLayer], dim, config.rmsNormEps)
);
// qkv matmuls for this position
weights.wq[l].matmul(nTokens, state.xb, state.q, dim, dim);
weights.wk[l].matmul(nTokens, state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(nTokens, state.xb, state.v, kvDim, dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
Parallel.parallelFor(0, nTokens, t -> {
for (int i = 0; i < dim; i += 2) {
int head_dim = i % headSize;
float fcr = weights.freq_cis_real.get((position + t) * (headSize / 2) + (head_dim / 2));
float fci = weights.freq_cis_imag.get((position + t) * (headSize / 2) + (head_dim / 2));
int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int vi = 0; vi < rotn; vi++) {
FloatTensor vec = vi == 0 ? state.q[t] : state.k[t]; // the vector to rotate (query or key)
float v0 = vec.getFloat(i);
float v1 = vec.getFloat(i + 1);
vec.setFloat(i, v0 * fcr - v1 * fci);
vec.setFloat(i + 1, v0 * fci + v1 * fcr);
}
}
});
// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
Parallel.parallelFor(0, nTokens, t -> {
state.k[t].copyTo(0, state.keyCache[curLayer], (position + t) * kvDim, kvDim);
state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim);
});
// If the logits are not required, the attention and FFN of the last layer can be skipped entirely.
if (!computeLogits && curLayer == config.numberOfLayers - 1) {
state.idxPrevBlock = nTokens - 1;
return null;
}
// multihead attention. iterate over all heads
Parallel.parallelForLong(0, (long) nTokens * (long) config.numberOfHeads, ht -> {
int token = (int) (ht / config.numberOfHeads);
int h = (int) (ht % config.numberOfHeads);
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;
// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength;
// iterate over all timesteps, including the current one
for (int t = 0; t <= position + token; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q[token].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att[token].setFloat(attOffset + t, score);
}
// softmax the scores to get attention weights, from 0..position inclusively
state.att[token].softmaxInPlace(attOffset, position + token + 1);
// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb[token].fillInPlace(xbOffset, headSize, 0f);
for (int t = 0; t <= position + token; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att[token].getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb[token].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});
// final matmul to get the output of the attention
weights.wo[l].matmul(nTokens, state.xb, state.xb2, dim, dim);
// residual connection back into x
Parallel.parallelFor(0, nTokens, t -> {
state.x[t].addInPlace(state.xb2[t]);
});
// ffn rmsnorm
Parallel.parallelFor(0, nTokens, t -> {
rmsnorm(state.xb[t], state.x[t], weights.rms_ffn_weight[curLayer], dim, config.rmsNormEps);
});
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(nTokens, state.xb, state.hb, config.hiddenDim, dim);
weights.w3[l].matmul(nTokens, state.xb, state.hb2, config.hiddenDim, dim);
// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
Parallel.parallelFor(0, nTokens, t -> {
state.hb[t].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
});
// elementwise multiply with w3(x)
Parallel.parallelFor(0, nTokens, t -> {
state.hb[t].multiplyInPlace(state.hb2[t]);
});
// final matmul to get the output of the ffn
weights.w2[l].matmul(nTokens, state.hb, state.xb, dim, config.hiddenDim);
// residual connection
Parallel.parallelFor(0, nTokens, t -> {
state.x[t].addInPlace(state.xb[t]);
});
}
// final rmsnorm
Parallel.parallelFor(0, nTokens, t -> {
rmsnorm(state.x[t], state.x[t], weights.rms_final_weight, dim, config.rmsNormEps);
});
if(false) {
SuperBit sb = null;
try (Timer timer = Timer.log("SuperBits:"+state.x[nTokens-1].size())) {
//sb = new SuperBit(state.x[nTokens-1].size());
sb = new SuperBit(100);
}
try (Timer timer = Timer.log("Signature")) {
sb.signature(state.x[nTokens-1]);
}
}
// classifier into logits
weights.wcls.matmul(state.x[nTokens - 1], state.logits, config.vocabularySize, dim);
state.idxPrevBlock = nTokens - 1;
return state.logits;
}
static FloatTensor forwardQwen(Llama model, State state, int token, int position) {
// a few convenience variables
Llama.Configuration config = model.configuration();
Llama.Weights weights = model.weights();
int dim = config.dim;
int headSize = config.headSize;
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(headSize);
// copy the token embedding into x
weights.token_embedding_table.copyTo(token * dim, state.x[0], 0, dim);
// forward all the layers
for (int l = 0; l < config.numberOfLayers; l++) {
// attention rmsnorm
rmsnorm(state.xb[0], state.x[0], weights.rms_att_weight[l], dim, config.rmsNormEps);
// qkv matmuls for this position
weights.wq[l].matmul(state.xb[0], state.q[0], dim, dim);
if (weights.q_bias != null && weights.q_bias[l] != null) {
//state.q[0].addInPlace(weights.q_bias[l]);
System.out.println("state:"+state.q[0].size());
state.q[0].verify();
System.out.println("weights:"+weights.q_bias[l].size());
weights.q_bias[l].verify();
state.q[0].addInPlace(weights.q_bias[l]);
}
weights.wk[l].matmul(state.xb[0], state.k[0], kvDim, dim);
if (weights.k_bias != null && weights.k_bias[l] != null) {
state.k[0].addInPlace(weights.k_bias[l]);
}
weights.wv[l].matmul(state.xb[0], state.v[0], kvDim, dim);
if (weights.v_bias != null && weights.v_bias[l] != null) {
state.v[0].addInPlace(weights.v_bias[l]);
}
// RoPE relative positional encoding: complex-valued rotate q and k in each head
// GPT-NeoX style RoPE, real/imaginary components are stored with a headSize/2 offset per head, instead of consecutive.
for (int h = 0; h < config.numberOfHeads; ++h) {
int rotn = h < config.numberOfKeyValueHeads ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
int poffset = h * headSize;
for (int i0 = 0; i0 < headSize; i0 += 2) {
int ic = i0 / 2;
float fcr = weights.freq_cis_real.get(position * (headSize / 2) + ic);
float fci = weights.freq_cis_imag.get(position * (headSize / 2) + ic);
for (int v = 0; v < rotn; v++) {
FloatTensor vec = v == 0 ? state.q[0] : state.k[0]; // the vector to rotate (query or key)
float v0 = vec.getFloat(poffset + ic);
float v1 = vec.getFloat(poffset + ic + headSize/2);
vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr);
}
}
}
// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
state.k[0].copyTo(0, state.keyCache[l], position * kvDim, kvDim);
state.v[0].copyTo(0, state.valueCache[l], position * kvDim, kvDim);
int curLayer = l;
// multihead attention. iterate over all heads
Parallel.parallelFor(0, config.numberOfHeads, h -> {
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;
// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength;
// iterate over all timesteps, including the current one
for (int t = 0; t <= position; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q[0].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att[0].setFloat(attOffset + t, score);
}
// softmax the scores to get attention weights, from 0..position inclusively
state.att[0].softmaxInPlace(attOffset, position + 1);
// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb[0].fillInPlace(xbOffset, headSize, 0f);
for (int t = 0; t <= position; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att[0].getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb[0].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});
// final matmul to get the output of the attention
weights.wo[l].matmul(state.xb[0], state.xb2[0], dim, dim);
// residual connection back into x
state.x[0].addInPlace(state.xb2[0]);
// ffn rmsnorm
rmsnorm(state.xb[0], state.x[0], weights.rms_ffn_weight[l], dim, config.rmsNormEps);
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(state.xb[0], state.hb[0], config.hiddenDim, dim);
weights.w3[l].matmul(state.xb[0], state.hb2[0], config.hiddenDim, dim);
// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
state.hb[0].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
// elementwise multiply with w3(x)
state.hb[0].multiplyInPlace(state.hb2[0]);
// final matmul to get the output of the ffn
weights.w2[l].matmul(state.hb[0], state.xb[0], dim, config.hiddenDim);
// residual connection
state.x[0].addInPlace(state.xb[0]);
}
// final rmsnorm
rmsnorm(state.x[0], state.x[0], weights.rms_final_weight, dim, config.rmsNormEps);
// classifier into logits
weights.wcls.matmul(state.x[0], state.logits, config.vocabularySize, dim);
return state.logits;
}
/**
* LLM generation entry point, ingests prompt tokens and generates new tokens.
*
* <p>
* All prompt tokens are ingested first, then inference starts, until a stop token is found.
* The returned tokens only include generated/inferred tokens.
*
* @param model model to run inference (including weights, configuration, tokenizer ...)
* @param state state of the model e.g. key/value caches ... this is mutated by this call
* @param startPosition start prompt ingestion + inference at this position in the context e.g. useful if state was kept across calls (chained generation). 0 implies run with no previous context.
* @param promptTokens prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left in the context
* @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion
* @param maxTokens maximum number of tokens; the {@link Configuration#contextLength context length} is used instead
* if this value is negative or greater than the {@link Configuration#contextLength context length}
* @param sampler {@link Sampler strategy} used to select tokens
* @param echo debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr}
* @param onTokenGenerated callback, if non-null, it's called every time a token is inferred, i.e. it's not called when ingesting prompt tokens
* @return list of generated/inferred tokens, including the stop token, if any; does not include any token from the prompt
*/
public static List<Integer> generateTokens(Llama model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
long startNanos = System.nanoTime();
long startGen = startNanos; // overwritten once prompt ingestion completes
if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
maxTokens = model.configuration().contextLength;
}
List<Integer> generatedTokens = new ArrayList<>(maxTokens);
int token = state.latestToken; // BOS?
int nextToken;
int promptIndex = 0;
for (int position = startPosition; position < maxTokens; ++position) {
if (promptIndex < promptTokens.size()) {
final int nTokens = Math.min(maxTokens - position, Math.min(promptTokens.size() - promptIndex, state.batchsize));
final int[] tokens = new int[nTokens];
for (int i = 0; i < nTokens; i++) {
tokens[i] = promptTokens.get(promptIndex + i);
if (echo) {
// log prompt token (different color?)
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(tokens[i]))));
}
}
if (echo) {
System.err.format("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n", position, promptIndex, promptTokens.size(), Arrays.toString(tokens));
}
// Only compute logits on the very last batch.
boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
forward(model, state, tokens, position, computeLogits);
position += nTokens - 1; // -1 -> incremented later in the for loop
promptIndex += nTokens;
if (promptIndex < promptTokens.size()) {
continue;
}
startGen = System.nanoTime();
} else {
forward(model, state, new int[]{token}, position, true);
}
nextToken = sampler.sampleToken(state.logits);
if (echo) {
// log inferred token
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
generatedTokens.add(nextToken);
if (onTokenGenerated != null) {
onTokenGenerated.accept(nextToken);
}
if (stopTokens.contains(nextToken)) {
break;
}
state.latestToken = token = nextToken;
}
long endNanos = System.nanoTime();
long promptNanos = startGen - startNanos; // time spent ingesting the prompt
long genNanos = endNanos - startGen; // time spent generating tokens
System.err.printf("%ncontext: %d/%d prompt: %.2f tokens/s (%d) generation: %.2f tokens/s (%d)%n",
startPosition + promptIndex + generatedTokens.size(), model.configuration().contextLength,
promptTokens.size() / (promptNanos / 1_000_000_000.0), promptTokens.size(),
generatedTokens.size() / (genNanos / 1_000_000_000.0), generatedTokens.size());
return generatedTokens;
}
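/**
* Qwen-specific generation loop: runs one token per step through forwardQwen instead of batching the prompt.
* @param model the loaded model
* @param state the inference state; state.latestToken seeds the first forward pass and is updated on every step
* @param startPosition the position at which processing starts
* @param promptTokens the prompt tokens, force-fed one per step before sampling begins
* @param stopTokens token ids that end generation when sampled
* @param maxTokens upper bound on the position; a negative value or one above the context length means the full context length
* @param sampler picks the next token from state.logits once the prompt is consumed
* @param echo whether to print prompt and generated tokens to System.err
* @param onTokenGenerated optional callback invoked for each generated token; may be null
* @return the generated tokens, excluding the prompt
*/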
public static List<Integer> generateTokensQwen(Llama model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
long startNanos = System.nanoTime();
if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
maxTokens = model.configuration().contextLength;
}
List<Integer> generatedTokens = new ArrayList<>(maxTokens);
int token = state.latestToken; // BOS?
int nextToken;
int promptIndex = 0;
for (int position = startPosition; position < maxTokens; ++position) {
forwardQwen(model, state, token, position);
if (promptIndex < promptTokens.size()) {
// Force-pick token from prompt.
nextToken = promptTokens.get(promptIndex++);
if (echo) {
// log prompt token (different color?)
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
} else {
nextToken = sampler.sampleToken(state.logits);
if (echo) {
// log inferred token
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
generatedTokens.add(nextToken);
if (onTokenGenerated != null) {
onTokenGenerated.accept(nextToken);
}
if (stopTokens.contains(nextToken)) {
break;
}
}
state.latestToken = token = nextToken;
}
long elapsedNanos = System.nanoTime() - startNanos;
int totalTokens = promptIndex + generatedTokens.size();
System.err.printf("%n%.2f tokens/s (%d)%n", totalTokens / (elapsedNanos / 1_000_000_000.0), totalTokens);
return generatedTokens;
}
}
interface TokenizerInterface {
public Map<String, Integer> getSpecialTokens();
public boolean isSpecialToken(int tokenIndex);
public String decode(List<Integer> tokens);
public List<Integer> encodeAsList(String text);
public int getTokenType(int tokenIndex);
}
/**
* Byte Pair Encoding tokenizer.
* <p>
* Based on <a href="https://github.com/karpathy/minbpe">minbpe</a>, algorithmically follows along the
* <a href="https://github.com/openai/gpt-2/blob/master/src/encoder.py">GPT 2 tokenizer</a>
*/
class Tokenizer implements TokenizerInterface {
private final Pattern compiledPattern;
private final Vocabulary vocabulary;
private final Map<Pair<Integer, Integer>, Integer> merges;
private final Map<String, Integer> specialTokens;
private int[] tokenTypes; // qwen2
public String regexPattern() {
if (compiledPattern == null) {
return null;
}
return compiledPattern.pattern();
}
@Override
public Map<String, Integer> getSpecialTokens() {
return specialTokens;
}
@Override
public boolean isSpecialToken(int tokenIndex) {
return specialTokens.containsValue(tokenIndex);
}
@Override
public int getTokenType(int tokenIndex) {
return tokenTypes[tokenIndex];
}
public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern, Map<String, Integer> specialTokens) {
this.vocabulary = vocabulary;
this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
this.specialTokens = new HashMap<>(specialTokens);
this.merges = new HashMap<>();
for (Pair<Integer, Integer> pair : merges) {
int firstIndex = pair.first();
int secondIndex = pair.second();
int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
this.merges.put(pair, mergeIndex);
}
}
public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern, Map<String, Integer> specialTokens, int[] tokenTypes) {
this(vocabulary, merges, regexPattern, specialTokens);
this.tokenTypes = tokenTypes;
}
private int[] encodeImpl(String text) {
return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
}
/**
* Unlike {@link #encodeOrdinary(String)}, this function handles special tokens.
* allowedSpecial: the set of special tokens that may be matched in the text;
* every entry must be a key of {@link #getSpecialTokens()}.
* If the set is empty, the text is encoded with {@link #encodeOrdinary(String)}
* and any special token in the text is treated as ordinary text.
*/
List<Integer> encode(String text, Set<String> allowedSpecial) {
// decode the user desire w.r.t. handling of special tokens
Set<String> special = allowedSpecial;
assert getSpecialTokens().keySet().containsAll(special);
if (special.isEmpty()) {
// shortcut: if no special tokens, just use the ordinary encoding
return encodeOrdinary(text);
}
// otherwise, we have to be careful with potential special tokens in text
// we handle special tokens by splitting the text
// based on the occurrence of any exact match with any of the special tokens
// note: unlike Python's re.split, String.split never returns the delimiters,
// even for a capturing group, so we walk the matches explicitly
Pattern specialPattern = Pattern.compile(special
.stream()
.map(Pattern::quote)
.collect(Collectors.joining("|")));
// all chunks of text are encoded separately, then results are joined
List<Integer> ids = new ArrayList<>();
Matcher matcher = specialPattern.matcher(text);
int last = 0;
while (matcher.find()) {
if (matcher.start() > last) {
// this is an ordinary sequence, encode it normally
ids.addAll(encodeOrdinary(text.substring(last, matcher.start())));
}
// this is a special token, encode it separately as a special case
ids.add(getSpecialTokens().get(matcher.group()));
last = matcher.end();
}
if (last < text.length()) {
// trailing ordinary text after the last special token
ids.addAll(encodeOrdinary(text.substring(last)));
}
return ids;
}
private static List<String> findAll(Pattern pattern, String text) {
List<String> allMatches = new ArrayList<>();
Matcher matcher = pattern.matcher(text);
while (matcher.find()) {
allMatches.add(matcher.group());
}
return allMatches;
}
/**
* Encoding that ignores any special tokens.
*/
public List<Integer> encodeOrdinary(String text) {
// split text into chunks of text by categories defined in regex pattern
List<String> textChunks = findAll(compiledPattern, text);
// all chunks of text are encoded separately, then results are joined
List<Integer> ids = new ArrayList<>();
for (String chunk : textChunks) {
List<Integer> chunkIds = encodeChunk(chunk);
ids.addAll(chunkIds);
}
return ids;
}
private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
for (int i = 0; i + 1 < ids.size(); i++) {
Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
map.put(key, map.getOrDefault(key, 0) + 1);
}
return map;
}
private List<Integer> encodeChunk(String chunk) {
// return the token ids
// let's begin. first, convert all bytes to integers in range 0..255
List<Integer> ids = new ArrayList<>();
for (int b : chunk.toCharArray()) {
int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
ids.add(tokenIndex);
}
while (ids.size() >= 2) {
// find the pair with the lowest merge index
Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
Pair<Integer, Integer> pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
// subtle: if there are no more merges available, the key will
// result in an inf for every single pair, and the min will be
// just the first pair in the list, arbitrarily
// we can detect this terminating case by a membership check
if (!this.merges.containsKey(pair)) {
break; // nothing else can be merged anymore
}
// otherwise let's merge the best pair (lowest merge index)
int idx = this.merges.get(pair);
ids = merge(ids, pair, idx);
}
return ids;
}
private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
List<Integer> newids = new ArrayList<>();
int i = 0;
while (i < ids.size()) {
// if not at the very last position AND the pair matches, replace it
if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
newids.add(idx);
i += 2;
} else {
newids.add(ids.get(i));
i += 1;
}
}
return newids;
}
public String decodeImpl(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
String tokenString = vocabulary.get(token);
sb.append(tokenString);
}
return sb.toString();
}
/**
* Returns a mapping from each byte value (0..255) to a printable Unicode code point.
* The reversible BPE codes work on Unicode strings.
* This means you need a large number of Unicode characters in your vocab if you want to avoid UNKs.
* When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
* This is a significant percentage of your normal, say, 32K BPE vocab.
* To avoid that, we want lookup tables between UTF-8 bytes and Unicode strings
* that avoid mapping to whitespace/control characters the BPE code barfs on.
*/
private static Map<Integer, Integer> bytesToUnicode() {
List<Integer> bs = new ArrayList<>();
IntStream.rangeClosed('!', '~').forEach(bs::add);
IntStream.rangeClosed('¡', '¬').forEach(bs::add);
IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
List<Integer> cs = new ArrayList<>(bs);
int n = 0;
for (int b = 0; b < 256; ++b) {
if (!bs.contains(b)) {
bs.add(b);
cs.add(256 + n);
n += 1;
}
}
// return dict(zip(bs, cs))
return IntStream.range(0, bs.size())
.boxed()
.collect(Collectors.toMap(bs::get, cs::get));
}
static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
public int[] encode(String text) {
StringBuilder sb = new StringBuilder();
byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
for (byte b : bytes) {
sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
}
return encodeImpl(sb.toString());
}
public static String replaceControlCharacters(int[] codePoints) {
// we don't want to print control characters
// which distort the output (e.g. \n or much worse)
// https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
// http://www.unicode.org/reports/tr44/#GC_Values_Table
StringBuilder chars = new StringBuilder();
for (int cp : codePoints) {
if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
} else {
chars.appendCodePoint(cp); // this character is ok
}
}
return chars.toString();
}
public static String replaceControlCharacters(String str) {
return replaceControlCharacters(str.codePoints().toArray());
}
@Override
public List<Integer> encodeAsList(String text) {
return Arrays.stream(encode(text)).boxed().toList();
}
@Override
public String decode(List<Integer> tokens) {
String decoded = decodeImpl(tokens);
int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
byte[] rawBytes = new byte[decodedBytesAsInts.length];
// iterate over code points, not UTF-16 chars: decoded.length() can exceed the code point count
for (int i = 0; i < rawBytes.length; i++) {
rawBytes[i] = (byte) decodedBytesAsInts[i];
}
return new String(rawBytes, StandardCharsets.UTF_8);
}
}
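/*
 * Editor's sketch, not part of the reproducer: one possible way to split text around special
 * tokens while keeping the tokens themselves, as Tokenizer.encode(String, Set) intends.
 * In Java, String.split never returns the delimiters (even for a capturing group), so this
 * sketch walks the matches with a Matcher. The class and method names are hypothetical.
 * Example: splitKeepingSpecials("a<|x|>b", Set.of("<|x|>")) yields [a, <|x|>, b].
 */
final class SpecialSplitSketch {
    static java.util.List<String> splitKeepingSpecials(String text, java.util.Set<String> specials) {
        java.util.List<String> parts = new java.util.ArrayList<>();
        if (specials.isEmpty()) {
            // nothing to match: the whole text is one ordinary chunk
            if (!text.isEmpty()) {
                parts.add(text);
            }
            return parts;
        }
        // quote every special token so it is matched literally
        String alternation = specials.stream()
                .map(java.util.regex.Pattern::quote)
                .collect(java.util.stream.Collectors.joining("|"));
        java.util.regex.Matcher m = java.util.regex.Pattern.compile(alternation).matcher(text);
        int last = 0;
        while (m.find()) {
            if (m.start() > last) {
                parts.add(text.substring(last, m.start())); // ordinary text before the match
            }
            parts.add(m.group()); // the special token itself
            last = m.end();
        }
        if (last < text.length()) {
            parts.add(text.substring(last)); // ordinary text after the last match
        }
        return parts;
    }
}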
/**
* Tokenizer for models whose GGUF metadata reports metadata.get("tokenizer.ggml.model") = llama, as Mistral does.
* Llama models report metadata.get("tokenizer.ggml.model") = gpt2 and use {@link Tokenizer} instead.
*/
class MistralTokenizer implements TokenizerInterface {
private final Pattern compiledPattern;
private final Vocabulary vocabulary;
private final Map<String, Integer> specialTokens;
private final int[] tokenType;
private final int byte0;
public String regexPattern() {
if (compiledPattern == null) {
return null;
}
return compiledPattern.pattern();
}
@Override
public Map<String, Integer> getSpecialTokens() {
return specialTokens;
}
@Override
public boolean isSpecialToken(int tokenIndex) {
return getTokenType(tokenIndex) != 1;
}
@Override
public int getTokenType(int tokenIndex) {
return tokenType[tokenIndex];
}
public MistralTokenizer(Vocabulary vocabulary, String regexPattern, Map<String, Integer> specialTokens, int[] tokenType) {
this.vocabulary = vocabulary;
this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
this.specialTokens = new HashMap<>(specialTokens);
this.tokenType = tokenType;
this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow();
}
List<Integer> encode(String text) {
return encodeImpl(text.replace(' ', '▁'));
}
private List<Integer> encodeImpl(String text) {
List<Integer> tokens = new ArrayList<>();
// first encode every individual codepoint in the input string
for (int i = 0, cpi; i < text.length(); i += Character.charCount(cpi)) {
cpi = text.codePointAt(i);
String singleCodepoint = Character.toString(cpi);
int id = vocabulary.getIndex(singleCodepoint).orElse(-1);
if (id != -1) {
// we found this codepoint in vocab, add it as a token
tokens.add(id);
} else {
// byte_fallback encoding: just encode each byte as a token
// +byte0 here to skip all the control and special tokens e.g. <unk>, <s>, </s>
// so the individual bytes only start at token <0x00>
for (byte b : singleCodepoint.getBytes(StandardCharsets.UTF_8)) {
tokens.add(Byte.toUnsignedInt(b) + byte0);
}
}
}
// merge the best consecutive pair each iteration, according to the scores in vocab_scores
while (true) {
float best_score = -1e10f;
int best_id = -1;
int best_idx = -1;
for (int i = 0; i < tokens.size() - 1; ++i) {
// check if we can merge the pair (tokens[i], tokens[i+1])
String str_buffer = vocabulary.get(tokens.get(i)) + vocabulary.get(tokens.get(i + 1));
int id = vocabulary.getIndex(str_buffer).orElse(-1);
if (id != -1 && vocabulary.getScore(id) > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = vocabulary.getScore(id);
best_id = id;
best_idx = i;
}
}
if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
}
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens.set(best_idx, best_id);
tokens.remove(best_idx + 1);
}
return tokens;
}
@Override
public String decode(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
String tokenString = vocabulary.get(token);
if (isSpecialToken(token)) {
// some tokens designate raw bytes e.g. '<0x10>'
String prefix = "<0x";
String suffix = ">";
if (tokenString.length() == 6 && tokenString.startsWith(prefix) && tokenString.endsWith(suffix)) {
String code = tokenString.substring(prefix.length(), tokenString.length() - suffix.length());
int cp = Integer.parseInt(code, 16);
tokenString = Character.toString(cp);
}
} else {
tokenString = tokenString.replace('▁', ' ');
}
sb.append(tokenString);
}
return sb.toString();
}
public static String replaceControlCharacters(int[] codePoints) {
// we don't want to print control characters
// which distort the output (e.g. \n or much worse)
// https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
// http://www.unicode.org/reports/tr44/#GC_Values_Table
StringBuilder chars = new StringBuilder();
for (int cp : codePoints) {
if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
} else {
chars.appendCodePoint(cp); // this character is ok
}
}
return chars.toString();
}
public static String replaceControlCharacters(String str) {
return replaceControlCharacters(str.codePoints().toArray());
}
@Override
public List<Integer> encodeAsList(String text) {
return encode(text);
}
}
final class Parallel {
public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) {
if (startInclusive == 0 && endExclusive == 1) {
action.accept(0);
return;
}
IntStream.range(startInclusive, endExclusive).parallel().forEach(action);
}
public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) {
if (startInclusive == 0 && endExclusive == 1) {
action.accept(0);
return;
}
LongStream.range(startInclusive, endExclusive).parallel().forEach(action);
}
}
record Pair<First, Second>(First first, Second second) {
}
record GGMLTensorEntry(MemorySegment mappedFile, String name, GGMLType ggmlType, int[] shape,
MemorySegment memorySegment) {
}
enum GGMLType {
F32(Float.BYTES),
F16(GGMLType.FLOAT16_BYTES),
Q4_0(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
Q4_1(2 * GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
UNSUPPORTED_Q4_2(Integer.MAX_VALUE), // support has been removed
UNSUPPORTED_Q4_3(Integer.MAX_VALUE), // support has been removed
Q5_0(Integer.MAX_VALUE),
Q5_1(Integer.MAX_VALUE),
Q8_0(GGMLType.FLOAT16_BYTES + 32 * Byte.BYTES, 32),
Q8_1(32 * Byte.BYTES + 2 * Float.BYTES, 32),
// k-quantizations
Q2_K(Integer.MAX_VALUE),
Q3_K(Integer.MAX_VALUE),
Q4_K(2 * GGMLType.FLOAT16_BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 2, GGMLType.QK_K),
Q5_K(2 * GGMLType.FLOAT16_BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 8 + GGMLType.QK_K / 2, GGMLType.QK_K),
Q6_K(GGMLType.QK_K / 2 + GGMLType.QK_K / 4 + GGMLType.QK_K / 16 + GGMLType.FLOAT16_BYTES, GGMLType.QK_K),
Q8_K(Integer.MAX_VALUE),
IQ2_XXS(Integer.MAX_VALUE),
IQ2_XS(Integer.MAX_VALUE),
IQ3_XXS(Integer.MAX_VALUE),
IQ1_S(Integer.MAX_VALUE),
IQ4_NL(Integer.MAX_VALUE),
IQ3_S(Integer.MAX_VALUE),
IQ2_S(Integer.MAX_VALUE),
IQ4_XS(Integer.MAX_VALUE),
I8(Byte.BYTES),
I16(Short.BYTES),
I32(Integer.BYTES),
I64(Long.BYTES),
F64(Double.BYTES),
IQ1_M(Integer.MAX_VALUE),
BF16(GGMLType.BFLOAT16_BYTES),
Q4_0_4_4(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
Q4_0_4_8(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
Q4_0_8_8(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
TQ1_0(Integer.MAX_VALUE),
TQ2_0(Integer.MAX_VALUE);
public static final int BFLOAT16_BYTES = 2;
public static final int FLOAT16_BYTES = 2;
private static final GGMLType[] VALUES = values();
private final int typeSize;
private final int blockSize;
public int getTypeSize() {
return typeSize;
}
public int getBlockSize() {
return blockSize;
}
public static GGMLType fromId(int id) {
return VALUES[id];
}
GGMLType(int typeSize) {
this(typeSize, 1);
}
public long byteSizeFor(int numberOfElements) {
long t = numberOfElements * (long) getTypeSize();
assert t % getBlockSize() == 0;
return t / getBlockSize(); // the method returns long, so sizes above Integer.MAX_VALUE need no int round-trip
}
public static final int QK_K = 256; // or 64?
GGMLType(int typeSize, int blockSize) {
assert blockSize > 0;
assert typeSize > 0;
assert isPowerOf2(blockSize);
this.typeSize = typeSize;
this.blockSize = blockSize;
}
private static boolean isPowerOf2(int n) {
return n > 0 && (n & (n - 1)) == 0;
}
}
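/*
 * Editor's sketch, not part of the reproducer: how the (typeSize, blockSize) pairs in GGMLType
 * translate into byte counts. The class name is hypothetical; the constants in the examples
 * mirror Q8_0 (2-byte float16 scale + 32 one-byte quants per block of 32 elements = 34 bytes)
 * and Q4_0 (2-byte scale + 16 packed bytes per block of 32 elements = 18 bytes).
 * Examples: byteSizeFor(3584, 34, 32) is 3808 and byteSizeFor(3584, 18, 32) is 2016.
 */
final class BlockSizeSketch {
    static long byteSizeFor(long numberOfElements, int typeSize, int blockSize) {
        // widen before multiplying so large tensors cannot overflow int
        long t = numberOfElements * typeSize;
        if (t % blockSize != 0) {
            throw new IllegalArgumentException("element count is not a whole number of blocks");
        }
        return t / blockSize;
    }
}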
/**
* Over-simplified, shapeless, float tensor.
* <p>
* Not a strict tensor, but rather just a sequence of floats, not required to be backed by memory
* e.g. can represent a sequence of quantized floats.
*/
abstract class FloatTensor implements Externalizable, Comparable {
static final int VECTOR_BIT_SIZE = Integer.getInteger("llama.VectorBitSize", VectorShape.preferredShape().vectorBitSize());
static final boolean USE_VECTOR_API = VECTOR_BIT_SIZE != 0;
static short readShort(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_SHORT, offset);
//return UNSAFE.getShort(memorySegment.address() + offset);
}
static int readInt(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_INT, offset);
//return UNSAFE.getInt(memorySegment.address() + offset);
}
static float readFloat(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_FLOAT, offset);
//return UNSAFE.getFloat(memorySegment.address() + offset);
}
static byte readByte(MemorySegment memorySegment, long offset) {
return memorySegment.get(ValueLayout.JAVA_BYTE, offset);
//return UNSAFE.getByte(memorySegment.address() + offset);
}
// Preferred vector size for the fast multiplication routines.
// (Apple Silicon) NEON only supports up-to 128bit vectors.
static final VectorSpecies<Float> F_SPECIES;
static final VectorSpecies<Integer> I_SPECIES;
static final VectorSpecies<Short> S_SPECIES_HALF;
static {
if (USE_VECTOR_API) {
F_SPECIES = VectorShape.forBitSize(VECTOR_BIT_SIZE).withLanes(float.class);
I_SPECIES = F_SPECIES.withLanes(int.class);
S_SPECIES_HALF = VectorShape.forBitSize(F_SPECIES.vectorBitSize() / 2).withLanes(short.class);
assert F_SPECIES.length() == S_SPECIES_HALF.length();
} else {
F_SPECIES = null;
I_SPECIES = null;
S_SPECIES_HALF = null;
}
}
abstract int size();
abstract float getFloat(int index);
abstract void setFloat(int index, float value);
abstract FloatVector getFloatVector(VectorSpecies<Float> species, int offset);
abstract GGMLType type();
public static int numberOfElements(int... dimensions) {
assert Arrays.stream(dimensions).allMatch(i -> i > 0);
return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow();
}
static float scalarDot(FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) {
float result = 0f;
for (int j = 0; j < size; j++) {
result += thiz.getFloat(thisOffset + j) * that.getFloat(thatOffset + j);
}
return result;
}
float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
return scalarDot(this, thisOffset, that, thatOffset, size);
}
void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) {
Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1)));
}
void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) {
if (that.length != out.length) {
throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length));
}
// widen before multiplying: dim0 * context as int arithmetic could overflow
Parallel.parallelForLong(0, (long) dim0 * context, ti -> {
int idxArr = (int) (ti / dim0);
int i = (int) (ti % dim0);
out[idxArr].setFloat(i, dot(i * dim1, that[idxArr], 0, dim1));
});
}
@FunctionalInterface
interface AggregateFunction {
float apply(float acc, float value);
}
float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) {
float result = seed;
for (int i = 0; i < size; ++i) {
result = reduce.apply(result, getFloat(thisOffset + i));
}
return result;
}
float sum(int thisOffset, int size) {
return reduce(thisOffset, size, 0f, Float::sum);
}
float max(int thisOffset, int size) {
return reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max);
}
void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) {
that.mapWithIndexInPlace(thatOffset, size, (value, index) -> this.getFloat(index - thatOffset + thisOffset));
}
int argmax(int thisOffset, int size) {
assert size > 0;
int maxIndex = thisOffset;
float maxValue = this.getFloat(maxIndex);
int endIndex = thisOffset + size;
for (int i = thisOffset; i < endIndex; ++i) {
float f = this.getFloat(i);
if (f > maxValue) {
maxValue = f;
maxIndex = i;
}
}
return maxIndex;
}
int argmax() {
return argmax(0, size());
}
@FunctionalInterface
interface MapFunction {
float apply(float value);
}
@FunctionalInterface
interface MapWithIndexFunction {
float apply(float value, int index);
}
FloatTensor mapInPlace(int thisOffset, int size, MapFunction mapFunction) {
int endIndex = thisOffset + size;
for (int i = thisOffset; i < endIndex; ++i) {
setFloat(i, mapFunction.apply(getFloat(i)));
}
return this;
}
FloatTensor mapInPlace(MapFunction mapFunction) {
return mapInPlace(0, size(), mapFunction);
}
FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) {
int endOffset = thisOffset + size;
for (int i = thisOffset; i < endOffset; ++i) {
System.out.println("setFloat:"+i+" of size:"+size);
setFloat(i, mapWithIndexFunction.apply(getFloat(i), i));
}
return this;
}
FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
return mapWithIndexInPlace(thisOffset, size, (value, index) -> value + that.getFloat(index - thisOffset + thatOffset));
}
FloatTensor addInPlace(FloatTensor that) {
return addInPlace(0, that, 0, size());
}
FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
return mapWithIndexInPlace(thisOffset, size, (value, index) -> value * that.getFloat(index - thisOffset + thatOffset));
}
FloatTensor multiplyInPlace(FloatTensor that) {
return multiplyInPlace(0, that, 0, size());
}
FloatTensor divideInPlace(int thisOffset, int size, float value) {
return mapInPlace(thisOffset, size, f -> f / value);
}
FloatTensor fillInPlace(int thisOffset, int size, float value) {
return mapInPlace(thisOffset, size, unused -> value);
}
FloatTensor softmaxInPlace(int thisOffset, int size) {
// find max value (for numerical stability)
float maxVal = max(thisOffset, size);
// exp and sum
mapInPlace(thisOffset, size, f -> (float) Math.exp(f - maxVal));
float sum = sum(thisOffset, size);
// normalize
return divideInPlace(thisOffset, size, sum);
}
FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) {
// this[thisOffset ... thisOffset + size) = a * that[thatOffset ... thatOffset + size) + this[thisOffset ... thisOffset + size)
for (int i = 0; i < size; ++i) {
setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i));
}
return this;
}
float cosineSimilarity(FloatTensor a, FloatTensor b) {
float dotProduct = a.dot(0, b, 0, a.size());
DoubleAdder aNormAdder = new DoubleAdder();
DoubleAdder bNormAdder = new DoubleAdder();
Parallel.parallelFor(0, a.size(), t -> {
aNormAdder.add(a.getFloat(t) * a.getFloat(t));
bNormAdder.add(b.getFloat(t) * b.getFloat(t));
});
float aNorm = (float) Math.sqrt(aNormAdder.sum());
float bNorm = (float) Math.sqrt(bNormAdder.sum());
return (dotProduct / (aNorm * bNorm));
}
public void verify() {
System.out.println("size:"+size());
System.out.println("Verified via String of length:"+toString().length());
}
public String toString() {
StringBuilder sb = new StringBuilder("[");
for(int i = 0; i < size(); i++) {
sb.append(getFloat(i));
if(i == (size()-1))
sb.append("]");
else
sb.append(",");
}
return sb.toString();
}
}
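/*
 * Editor's sketch, not part of the reproducer: the numerically stable softmax that
 * FloatTensor.softmaxInPlace performs, written against a plain float[] so it can be run in
 * isolation. Subtracting the maximum before exponentiating keeps Math.exp from overflowing
 * for large logits. The class and method names are hypothetical.
 * Example: softmax(new float[] {1000f, 1000f}) yields {0.5, 0.5} rather than NaN.
 */
final class SoftmaxSketch {
    static float[] softmax(float[] logits) {
        // find max value (for numerical stability)
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        // exp and sum
        float[] out = new float[logits.length];
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            out[i] = (float) Math.exp(logits[i] - max);
            sum += out[i];
        }
        // normalize
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}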
/**
* {@link FloatTensor} quantized in the {@link GGMLType#Q4_0} format.
* <p>
* This tensor is read-only and does not support {@link #getFloatVector(VectorSpecies, int)}, but
* {@link #dot(int, FloatTensor, int, int)} has a vectorized implementation that is used when
* the Vector API is enabled; that path requires the second argument to be an {@code ArrayFloatTensor}.
*/
final class Q4_0FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public Q4_0FloatTensor() {}
public Q4_0FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.Q4_0;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
int blockIndex = index / GGMLType.Q4_0.getBlockSize();
int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize();
float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
byte quant;
int modIndex = index % GGMLType.Q4_0.getBlockSize();
if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) {
quant = (byte) (readByte(memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + modIndex) & 0x0F);
} else {
quant = (byte) ((readByte(memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F);
}
quant -= 8;
return quant * scale;
}
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
if (FloatTensor.USE_VECTOR_API) {
return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
float result = 0f;
int j = 0;
// Align thisOffset + j to type().getBlockSize().
assert Integer.bitCount(GGMLType.Q4_0.getBlockSize()) == 1 : "power of 2";
int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q4_0.getBlockSize() - 1));
if (alignmentBound > 0) {
result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound);
j += alignmentBound;
}
assert (thisOffset + j) % GGMLType.Q4_0.getBlockSize() == 0;
FloatVector val = FloatVector.zero(F_SPECIES);
int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize();
// count whole blocks from the aligned start j, so the vector loop never runs past size
int upperBound = j + (size - j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize();
for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) {
float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset));
var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue);
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var loBytes = wBytes.and((byte) 0xF).sub((byte) 8);
var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8);
switch (F_SPECIES.vectorBitSize()) {
case 512 -> {
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 0));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 0));
val = sum0.add(sum2).fma(wScale, val);
}
case 256 -> {
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 0));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 1));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
case 128 -> {
// This loop cannot be unrolled, why?
for (int i = 0; i < 2; ++i) {
var tmp = i == 0 ? loBytes : hiBytes;
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
}
default -> throw new UnsupportedOperationException(F_SPECIES.toString());
}
}
result += val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (j < size) {
result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
memorySegment = Arena.ofAuto().allocate(bs, 1);
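for(long i = 0; i < bs; i++) // long index: bs is a long and may exceed Integer.MAX_VALUE
memorySegment.set(ValueLayout.JAVA_BYTE, i, in.readByte()); // readByte() throws EOFException on truncated input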
}
@Override
public int compareTo(Object o) {
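MemorySegment other = ((Q4_0FloatTensor)o).memorySegment;
// compare only the bytes both segments have, using a long index to match byteSize()
long common = Math.min(memorySegment.byteSize(), other.byteSize());
for(long i = 0; i < common; i++) {
int cmp = Byte.compare(memorySegment.get(ValueLayout.JAVA_BYTE, i), other.get(ValueLayout.JAVA_BYTE, i));
if(cmp != 0)
return cmp < 0 ? -1 : 1;
}
// equal common prefix: the shorter segment sorts first
return Long.compare(memorySegment.byteSize(), other.byteSize());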
}
}
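/*
 * Editor's sketch, not part of the reproducer: dequantizing one element of a single Q4_0 block
 * held in a plain byte[], following the same layout as Q4_0FloatTensor.getFloat: a
 * little-endian float16 scale, then 16 bytes whose low nibbles hold elements 0..15 and whose
 * high nibbles hold elements 16..31. The class and method names are hypothetical, and
 * Float.float16ToFloat assumes Java 20 or later.
 * Example: with scale 1.0 (bytes 0x00 0x3C) and first quant byte 0x9F,
 * element 0 is (15 - 8) * 1.0 = 7.0 and element 16 is (9 - 8) * 1.0 = 1.0.
 */
final class Q4_0BlockSketch {
    static final int BLOCK_SIZE = 32;
    static final int FLOAT16_BYTES = 2;

    static float dequantize(byte[] block, int indexInBlock) {
        // the block starts with its float16 scale, stored little-endian
        short scaleBits = java.nio.ByteBuffer.wrap(block)
                .order(java.nio.ByteOrder.LITTLE_ENDIAN)
                .getShort(0);
        float scale = Float.float16ToFloat(scaleBits);
        int half = BLOCK_SIZE / 2;
        // elements i and i + 16 share one byte: low nibble first, high nibble second
        int packed = block[FLOAT16_BYTES + (indexInBlock % half)] & 0xFF;
        int nibble = indexInBlock < half ? (packed & 0x0F) : (packed >>> 4);
        // quants are stored with a bias of 8
        return (nibble - 8) * scale;
    }
}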
final class Q8_0FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public Q8_0FloatTensor() {}
public Q8_0FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.Q8_0;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
int blockIndex = index / GGMLType.Q8_0.getBlockSize();
int withinBlockIndex = index % GGMLType.Q8_0.getBlockSize();
int blockOffset = blockIndex * GGMLType.Q8_0.getTypeSize();
byte quant = readByte(memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + withinBlockIndex);
float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
return quant * scale;
}
public static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN);
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
if (FloatTensor.USE_VECTOR_API) {
return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
float result = 0f;
int j = 0;
// Process leading elements one at a time until thisOffset + j is a multiple of the block size.
assert Integer.bitCount(GGMLType.Q8_0.getBlockSize()) == 1 : "power of 2";
int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q8_0.getBlockSize() - 1));
if (alignmentBound > 0) {
result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound);
j += alignmentBound;
}
assert (thisOffset + j) % GGMLType.Q8_0.getBlockSize() == 0;
FloatVector val = FloatVector.zero(F_SPECIES);
int blockOffset = (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize();
// Whole blocks are counted from the aligned start j, so the loop never runs past size.
int upperBound = j + (size - j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize();
for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) {
float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset));
var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue);
switch (F_SPECIES.vectorBitSize()) {
case 512 -> {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1));
val = sum0.add(sum1).fma(wScale, val);
}
case 256 -> {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
case 128 -> {
// This loop cannot be unrolled, why?
for (int i = 0; i < 2; ++i) {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + GGMLType.FLOAT16_BYTES + i * ByteVector.SPECIES_128.vectorByteSize(), ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
}
default -> throw new UnsupportedOperationException(F_SPECIES.toString());
}
}
result += val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (j < size) {
result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
memorySegment = Arena.ofAuto().allocate(bs, 1);
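// readByte() throws EOFException on truncated input, whereas read() returns -1, which would be stored as 0xFF.
for (long i = 0; i < bs; i++)
memorySegment.set(ValueLayout.JAVA_BYTE, i, in.readByte());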
}
@Override
public int compareTo(Object o) {
for(int i = 0; i < memorySegment.byteSize(); i++) {
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) > ((Q8_0FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return 1;
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) < ((Q8_0FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return -1;
}
return 0;
}
}
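The vectorDot routine above splits its range into an unaligned head, a run of whole blocks, and a scalar tail. The following is a minimal sketch of that split, not the program's own code. It assumes a power-of-two block size of 32 (the value GGML uses for Q8_0, taken here as an assumption), and the class and method names are hypothetical. The end of the whole-block run is computed relative to the aligned start.

```java
// Sketch only: BlockAlignmentSketch and its methods are hypothetical names.
final class BlockAlignmentSketch {
    /** Number of leading elements to handle one at a time so that offset + head is block-aligned. */
    static int head(int offset, int size, int blockSize) {
        // -offset & (blockSize - 1) is the distance to the next multiple of blockSize
        // (0 when already aligned). This only works for power-of-two block sizes.
        return Math.min(size, -offset & (blockSize - 1));
    }

    /** Exclusive end, relative to offset, of the run of whole blocks that follows the head. */
    static int wholeBlocksEnd(int offset, int size, int blockSize) {
        int h = head(offset, size, blockSize);
        return h + (size - h) / blockSize * blockSize;
    }

    public static void main(String[] args) {
        int blockSize = 32; // assumed Q8_0 block size
        int offset = 16, size = 64;
        int h = head(offset, size, blockSize);
        int end = wholeBlocksEnd(offset, size, blockSize);
        System.out.println("scalar head:  [0, " + h + ")");
        System.out.println("whole blocks: [" + h + ", " + end + ")");
        System.out.println("scalar tail:  [" + end + ", " + size + ")");
    }
}
```

Everything in the range [wholeBlocksEnd, size) is left for the scalar fallback, so no block read extends past size.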
final class BF16FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public BF16FloatTensor() {}
public BF16FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.BF16;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
return bfloat16ToFloat(readShort(memorySegment, index * GGMLType.BFLOAT16_BYTES));
}
private float bfloat16ToFloat(short bfloat16) {
return Float.intBitsToFloat(bfloat16 << 16);
}
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
if (FloatTensor.USE_VECTOR_API) {
return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(BF16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
assert S_SPECIES_HALF.length() == F_SPECIES.length();
FloatVector val = FloatVector.zero(F_SPECIES);
int upperBound = F_SPECIES.loopBound(size);
for (int i = 0; i < upperBound; i += F_SPECIES.length()) {
FloatVector thatVector = that.getFloatVector(F_SPECIES, thatOffset + i);
ShortVector bfloat16 = ShortVector.fromMemorySegment(S_SPECIES_HALF, thiz.memorySegment, (thisOffset + i) * (long) GGMLType.BFLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
// BFloat16 to Float32 Conversion:
//
// ┌─[15]─┬─[14]───····───[7]─┬─[6]────····────[0]─┐
// │ Sign │ Exponent (8 bits) │ Mantissa (7 bits) │ BFloat16 Layout (16 bits)
// └──────┴───────────────────┴────────────────────┘
// │ │ │
// ▼ ▼ ▼
// ┌─[31]─┬─[30]───···───[23]─┬─[22]────···────[0]─┐
// │ Sign │ Exponent (8 bits) │ Mantissa (23 bits) │ Float32 Layout (32 bits)
// └──────┴───────────────────┴────────────────────┘
FloatVector thizVector = bfloat16
.castShape(I_SPECIES, 0) // (int) vi
.lanewise(VectorOperators.LSHL, 16) // vi <<= 16
.reinterpretAsFloats(); // Float.intBitsToFloat(vi)
val = thizVector.fma(thatVector, val);
}
float result = val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (upperBound < size) {
result += scalarDot(thiz, thisOffset + upperBound, that, thatOffset + upperBound, size - upperBound);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
memorySegment = Arena.ofAuto().allocate(bs, 1);
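// readByte() throws EOFException on truncated input, whereas read() returns -1, which would be stored as 0xFF.
for (long i = 0; i < bs; i++)
memorySegment.set(ValueLayout.JAVA_BYTE, i, in.readByte());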
}
@Override
public int compareTo(Object o) {
for(int i = 0; i < memorySegment.byteSize(); i++) {
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) > ((BF16FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return 1;
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) < ((BF16FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return -1;
}
return 0;
}
}
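The BFloat16 diagram above amounts to placing the 16 stored bits in the upper half of a 32-bit float. A scalar sketch of that conversion, with a truncating inverse added only for illustration (the class and method names are hypothetical, and the inverse is not something the reported program does):

```java
// Sketch only: BFloat16Sketch is a hypothetical helper, not part of the reported program.
final class BFloat16Sketch {
    /** Widens bfloat16 bits to float32 by shifting them into the upper 16 bits. */
    static float toFloat(short bits) {
        // The short is sign-extended to int first; the shift discards the extension bits.
        return Float.intBitsToFloat(bits << 16);
    }

    /** Truncating inverse (drops the low 16 mantissa bits, no rounding). Illustration only. */
    static short fromFloatTruncating(float value) {
        return (short) (Float.floatToRawIntBits(value) >>> 16);
    }

    public static void main(String[] args) {
        short one = fromFloatTruncating(1.0f);
        System.out.println(Integer.toHexString(one & 0xFFFF) + " -> " + toFloat(one));
    }
}
```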
final class F16FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public F16FloatTensor() {}
public F16FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.F16;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
return Float.float16ToFloat(readShort(memorySegment, index * GGMLType.FLOAT16_BYTES));
}
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
if (FloatTensor.USE_VECTOR_API) {
return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(F16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
assert S_SPECIES_HALF.length() == F_SPECIES.length();
FloatVector val = FloatVector.zero(F_SPECIES);
int upperBound = F_SPECIES.loopBound(size);
for (int i = 0; i < upperBound; i += F_SPECIES.length()) {
FloatVector thatVector = that.getFloatVector(F_SPECIES, thatOffset + i);
ShortVector bits16 = ShortVector.fromMemorySegment(S_SPECIES_HALF, thiz.memorySegment, (thisOffset + i) * (long) GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN);
var bits32 = bits16.castShape(I_SPECIES, 0).reinterpretAsInts(); // (int) bits16
// Does not support infinities or NaNs; preserves the sign and emulates DAZ (denormals-are-zero).
// Expects well-formed float16 values only (e.g. model weights).
// Fast Float16 to Float32 Conversion:
//
// ┌─[15]─┬─[14]───···───[10]─┬─[9]────····────[0]─┐
// │ Sign │ Exponent (5 bits) │ Mantissa (10 bits) │ Float16 Layout (16 bits)
// └──────┴───────────────────┴────────────────────┘
// │ │ │
// ▼ ▼ ▼
// ┌─[31]─┬─[30]───···───[23]─┬─[22]────···────[0]─┐
// │ Sign │ Exponent (8 bits) │ Mantissa (23 bits) │ Float32 Layout (32 bits)
// └──────┴───────────────────┴────────────────────┘
//
// Shifts and adjustments:
// - Sign: float16[15] -> float32[31] (shift 16 bits up)
// - Exponent: float16[10-14] -> float32[23-30] (+ bias adjustment)
// - Mantissa: float16[0-9] -> float32[13-22] (shift 13 bits up)
//
// exp = bits32 & 0x7C00
// zeroExponentMask = exp == 0 ? 0 : ~0
var zeroExponentMask = bits32.and(0x7C00).neg().lanewise(VectorOperators.ASHR, 31); // = (-exp) >> 31
bits32 = bits32.and(0x8000).lanewise(VectorOperators.LSHL, 16) // sign
.or(
// exponent and mantissa combined
bits32.and(0x7FFF).add(0x1C000).lanewise(VectorOperators.LSHL, 13)
.and(zeroExponentMask) // -0, +0 and DAZ (denormals-are-zero)
);
FloatVector thizVector = bits32.reinterpretAsFloats(); // Float.intBitsToFloat(vi)
val = thizVector.fma(thatVector, val);
}
float result = val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (upperBound < size) {
result += scalarDot(thiz, thisOffset + upperBound, that, thatOffset + upperBound, size - upperBound);
}
return result;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
memorySegment = Arena.ofAuto().allocate(bs, 1);
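// readByte() throws EOFException on truncated input, whereas read() returns -1, which would be stored as 0xFF.
for (long i = 0; i < bs; i++)
memorySegment.set(ValueLayout.JAVA_BYTE, i, in.readByte());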
}
@Override
public int compareTo(Object o) {
for(int i = 0; i < memorySegment.byteSize(); i++) {
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) > ((F16FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return 1;
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) < ((F16FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return -1;
}
return 0;
}
}
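The fast Float16 conversion in vectorDot above can be followed more easily one lane at a time. This scalar sketch applies the same masks, bias adjustment and shifts to a single value. The class name is hypothetical, and as in the vector code, infinities and NaNs are not handled and denormals flush to zero.

```java
// Sketch only: FastFloat16Sketch is a hypothetical scalar analogue of the vector code.
final class FastFloat16Sketch {
    static float toFloat(short half) {
        int bits = half;                        // sign-extended widening, like the cast in the vector code
        int exp = bits & 0x7C00;                // 5 exponent bits
        int nonZeroExponentMask = (-exp) >> 31; // 0 when exp == 0, otherwise all ones
        int sign = (bits & 0x8000) << 16;       // float16[15] -> float32[31]
        // Add the bias difference (127 - 15) << 10 = 0x1C000, then move exponent and mantissa up 13 bits.
        int expAndMantissa = (((bits & 0x7FFF) + 0x1C000) << 13) & nonZeroExponentMask;
        return Float.intBitsToFloat(sign | expAndMantissa);
    }

    public static void main(String[] args) {
        System.out.println(toFloat((short) 0x3C00)); // bit pattern of float16 one
        System.out.println(toFloat((short) 0x0001)); // smallest denormal, flushed to zero
    }
}
```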
final class F32FloatTensor extends FloatTensor implements Externalizable, Comparable {
private static final long serialVersionUID = -1L;
int size;
transient MemorySegment memorySegment;
public F32FloatTensor() {}
public F32FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
float getFloat(int index) {
assert 0 <= index && index < size;
return readFloat(memorySegment, index * 4);
}
@Override
void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int offset) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
GGMLType type() {
return GGMLType.F32;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(size);
out.writeLong(memorySegment.byteSize());
out.write(memorySegment.toArray(ValueLayout.JAVA_BYTE));
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
size = in.readInt();
long bs = in.readLong();
memorySegment = Arena.ofAuto().allocate(bs, 1);
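// readByte() throws EOFException on truncated input, whereas read() returns -1, which would be stored as 0xFF.
for (long i = 0; i < bs; i++)
memorySegment.set(ValueLayout.JAVA_BYTE, i, in.readByte());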
}
@Override
public int compareTo(Object o) {
for(int i = 0; i < memorySegment.byteSize(); i++) {
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) > ((F32FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return 1;
if(memorySegment.get(ValueLayout.JAVA_BYTE,i) < ((F32FloatTensor)o).memorySegment.get(ValueLayout.JAVA_BYTE, i))
return -1;
}
return 0;
}
}
final class ArrayFloatTensor extends FloatTensor implements Externalizable, Comparable {
float[] values;
public ArrayFloatTensor() {}
ArrayFloatTensor(float[] values) {
this.values = values;
}
public static FloatTensor allocate(int... dims) {
int numberOfElements = FloatTensor.numberOfElements(dims);
return new ArrayFloatTensor(new float[numberOfElements]);
}
@Override
public int size() {
return values.length;
}
@Override
public float getFloat(int index) {
return values[index];
}
@Override
public void setFloat(int index, float value) {
values[index] = value;
}
@Override
public GGMLType type() {
return GGMLType.F32;
}
@Override
public FloatTensor fillInPlace(int thisOffset, int size, float value) {
Arrays.fill(values, thisOffset, thisOffset + size, value);
return this;
}
@Override
public FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
if (!USE_VECTOR_API) {
throw new UnsupportedOperationException();
}
return FloatVector.fromArray(species, values, index);
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(values.length);
for(float v: values)
out.writeFloat(v);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
int vsize = in.readInt();
values = new float[vsize];
for(int i = 0; i < vsize; i++)
values[i]= in.readFloat();
}
@Override
public int compareTo(Object o) {
return Arrays.compare(values,((ArrayFloatTensor)o).values);
}
}
final class RoPE {
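/**
* Precomputes the rotary position embedding (RoPE) tables, with optional Llama 3.1 frequency scaling.
* For GPT2 vocab.
* @param contextLength number of positions to precompute
* @param headSize attention head size; must be even
* @param theta RoPE base
* @param ropeScaling whether to apply the Llama 3.1 frequency scaling
* @param scaleFactor divisor applied to low-frequency components
* @param loFreqFactor wavelengths above oldContextLength / loFreqFactor are fully scaled
* @param hiFreqFactor wavelengths below oldContextLength / hiFreqFactor are left unscaled
* @param oldContextLength context length from which the wavelength thresholds are derived
* @return pair of (cosine, sine) tables, each holding contextLength * (headSize / 2) entries
*/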
public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta,
boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) {
assert headSize % 2 == 0;
float[] cr = new float[contextLength * (headSize / 2)];
float[] ci = new float[contextLength * (headSize / 2)];
int n = 0;
for (int pos = 0; pos < contextLength; ++pos) {
for (int i = 0; i < headSize; i += 2) {
float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
if (ropeScaling) {
// Llama 3.1 scaling
float loFreqWavelen = oldContextLength / loFreqFactor;
float hiFreqWavelen = oldContextLength / hiFreqFactor;
float wavelen = (float) (2.0 * Math.PI / freq);
if (wavelen < hiFreqWavelen) {
// High-frequency component: leave freq unscaled.
} else if (wavelen > loFreqWavelen) {
freq = freq / scaleFactor;
} else {
float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq;
}
}
float val = pos * freq;
cr[n] = (float) Math.cos(val);
ci[n] = (float) Math.sin(val);
n++;
}
}
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}
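/**
* Precomputes the rotary position embedding (RoPE) tables without frequency scaling.
* For Llama vocab.
* @param contextLength number of positions to precompute
* @param headSize attention head size; must be even
* @param theta RoPE base
* @return pair of (cosine, sine) tables, each holding contextLength * (headSize / 2) entries
*/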
public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta) {
assert headSize % 2 == 0;
float[] cr = new float[contextLength * (headSize / 2)];
float[] ci = new float[contextLength * (headSize / 2)];
int n = 0;
for (int pos = 0; pos < contextLength; ++pos) {
for (int i = 0; i < headSize; i += 2) {
float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
float val = pos * freq;
cr[n] = (float) Math.cos(val);
ci[n] = (float) Math.sin(val);
n++;
}
}
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}
}
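The Llama 3.1 scaling branch in precomputeFreqsCis above treats each frequency in one of three ways depending on its wavelength. A standalone sketch of just that step follows. The class name is hypothetical, and the parameter values used in main (scale factor 8, low/high factors 1 and 4, old context length 8192) are illustrative assumptions, not values taken from this report.

```java
// Sketch only: RopeScalingSketch is a hypothetical helper that isolates the scaling branch.
final class RopeScalingSketch {
    static float scale(float freq, float scaleFactor, float loFreqFactor,
                       float hiFreqFactor, float oldContextLength) {
        float loFreqWavelen = oldContextLength / loFreqFactor;
        float hiFreqWavelen = oldContextLength / hiFreqFactor;
        float wavelen = (float) (2.0 * Math.PI / freq);
        if (wavelen < hiFreqWavelen) {
            return freq;               // high frequency: unchanged
        }
        if (wavelen > loFreqWavelen) {
            return freq / scaleFactor; // low frequency: fully scaled
        }
        // In between: interpolate linearly between the scaled and unscaled frequency.
        float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
        return (1.0f - smooth) * freq / scaleFactor + smooth * freq;
    }

    public static void main(String[] args) {
        int headSize = 8;
        double theta = 10000.0;
        for (int i = 0; i < headSize; i += 2) {
            float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
            System.out.println(freq + " -> " + scale(freq, 8f, 1f, 4f, 8192f));
        }
    }
}
```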
record Vocabulary(String[] tokens, float[] scores, Map<String, Integer> tokenToIndex) {
public Vocabulary(String[] vocabulary, float[] scores) {
this(vocabulary, scores,
IntStream.range(0, vocabulary.length)
.boxed()
.collect(Collectors.toMap(i -> vocabulary[i], i -> i))
);
}
public String get(int tokenIndex) {
return tokens[tokenIndex];
}
public OptionalInt getIndex(String token) {
Integer value = tokenToIndex.get(token);
return value != null ? OptionalInt.of(value) : OptionalInt.empty();
}
public int size() {
return tokens.length;
}
/**
* Returns the score stored for a token. Added from Mistral Vocabulary - Groff
* @param tokenIndex index of the token in the vocabulary
* @return the score at that index
*/
public float getScore(int tokenIndex) {
return scores[tokenIndex];
}
public boolean scoresNull() {
return scores == null;
}
}
@FunctionalInterface
interface Sampler {
int sampleToken(FloatTensor logits);
Sampler ARGMAX = FloatTensor::argmax;
}
record CategoricalSampler(RandomGenerator rng) implements Sampler {
@Override
public int sampleToken(FloatTensor logits) {
// sample index from probabilities (they must sum to 1!)
float random0to1 = rng.nextFloat(1f);
float cdf = 0.0f;
for (int i = 0; i < logits.size(); i++) {
cdf += logits.getFloat(i);
if (random0to1 < cdf) {
return i;
}
}
return logits.size() - 1; // in case of rounding errors
}
}
final class ToppSampler implements Sampler {
final int[] indices;
final float topp;
final RandomGenerator rng;
public ToppSampler(int maxNumberOfElements, float topp, RandomGenerator rng) {
this.indices = new int[maxNumberOfElements];
this.topp = topp;
this.rng = rng;
}
static void swap(int[] array, int from, int to) {
int tmp = array[from];
array[from] = array[to];
array[to] = tmp;
}
static void siftDown(int[] array, int from, int n, Comparator<Integer> comparator) {
int prev = from, next;
while ((next = 2 * prev + 1) < n) {
int r = 2 * prev + 2;
if (r < n && comparator.compare(array[r], array[next]) < 0) {
next = r;
}
if (comparator.compare(array[next], array[prev]) < 0) {
swap(array, prev, next);
prev = next;
} else {
break;
}
}
}
@Override
public int sampleToken(FloatTensor logits) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
Comparator<Integer> comparator = Comparator.comparingDouble(logits::getFloat).reversed();
int n = logits.size();
int head = 0;
int tail = n - 1;
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
// so for efficiency we crop these out as candidates before sorting
float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < indices.length; i++) {
if (logits.getFloat(i) >= cutoff) {
indices[head++] = i;
} else {
indices[tail--] = i;
}
}
int n0 = head;
// build heap O(n0)
for (int i = n0 / 2 - 1; i >= 0; --i) {
siftDown(indices, i, n0, comparator);
}
// truncate the list where cumulative probability of the largest k elements exceeds topp
// O(k lg n0)
float cumulativeProb = 0.0f;
int lastIndex = 0;
for (int i = n0 - 1; i >= 0; i--) {
swap(indices, 0, i);
cumulativeProb += logits.getFloat(indices[i]);
if (cumulativeProb > topp) {
lastIndex = i;
break; // we've exceeded topp by including lastIndex
}
siftDown(indices, 0, i, comparator); // after the swap the heap spans indices[0 .. i - 1], i.e. i elements
}
// sample from the truncated list
float r = rng.nextFloat(1f) * cumulativeProb;
float cdf = 0.0f;
for (int i = n0 - 1; i >= lastIndex; i--) {
cdf += logits.getFloat(indices[i]);
if (r < cdf) {
return indices[i];
}
}
return indices[lastIndex]; // in case of rounding errors
}
}
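For comparison with the heap-based ToppSampler above, the nucleus can also be found with a full sort. The sketch below swaps in that simpler technique (O(n log n) rather than the partial heap extraction) and returns only the selected indices, leaving out the random draw. The class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

// Sketch only: TopPSketch uses a full sort instead of the heap used in the reported program.
final class TopPSketch {
    /** Returns the indices of the smallest top-probability set whose sum exceeds topp, highest first. */
    static int[] nucleus(float[] probs, float topp) {
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort indices by descending probability.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        float cumulative = 0f;
        int count = 0;
        while (count < order.length) {
            cumulative += probs[order[count++]];
            if (cumulative > topp) {
                break; // the element just added pushed the sum past topp
            }
        }
        int[] result = new int[count];
        for (int i = 0; i < count; i++) {
            result[i] = order[i];
        }
        return result;
    }

    public static void main(String[] args) {
        float[] probs = {0.05f, 0.5f, 0.15f, 0.3f};
        System.out.println(Arrays.toString(nucleus(probs, 0.7f)));
    }
}
```

Sampling then draws a uniform number in [0, cumulative) and walks the returned indices, as the final loop of sampleToken does.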
interface ChatFormatInterface {
public TokenizerInterface getTokenizer();
public Set<Integer> getStopTokens();
public List<Integer> encodeHeader(ChatFormat.Message message);
public List<Integer> encodeMessage(ChatFormat.Message message);
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog);
public int getBeginOfText();
}
/**
* Utility tailored for Llama 3 instruct prompt format.
*/
class ChatFormat implements ChatFormatInterface {
final Tokenizer tokenizer;
final int beginOfText;
final int endHeader;
final int startHeader;
final int endOfTurn;
final int endOfText;
final int endOfMessage;
final Set<Integer> stopTokens;
public ChatFormat(TokenizerInterface tokenizer) {
this.tokenizer = (Tokenizer)tokenizer;
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
this.beginOfText = specialTokens.get("<|begin_of_text|>");
this.startHeader = specialTokens.get("<|start_header_id|>");
this.endHeader = specialTokens.get("<|end_header_id|>");
this.endOfTurn = specialTokens.get("<|eot_id|>");
this.endOfText = specialTokens.get("<|end_of_text|>");
this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1
this.stopTokens = Set.of(endOfText, endOfTurn);
}
@Override
public TokenizerInterface getTokenizer() {
return tokenizer;
}
@Override
public Set<Integer> getStopTokens() {
return stopTokens;
}
@Override
public int getBeginOfText() {
return beginOfText;
}
@Override
public List<Integer> encodeHeader(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(startHeader);
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
tokens.add(endHeader);
tokens.addAll(this.tokenizer.encodeAsList("\n"));
return tokens;
}
@Override
public List<Integer> encodeMessage(ChatFormat.Message message) {
List<Integer> tokens = this.encodeHeader(message);
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
tokens.add(endOfTurn);
return tokens;
}
@Override
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
List<Integer> tokens = new ArrayList<>();
tokens.add(beginOfText);
for (ChatFormat.Message message : dialog) {
tokens.addAll(this.encodeMessage(message));
}
if (appendAssistantTurn) {
// Add the start of an assistant message for the model to complete.
tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
}
return tokens;
}
public record Message(ChatFormat.Role role, String content) {
}
public record Role(String name) {
public static final ChatFormat.Role SYSTEM = new ChatFormat.Role("system");
public static final ChatFormat.Role USER = new ChatFormat.Role("user");
public static final ChatFormat.Role ASSISTANT = new ChatFormat.Role("assistant");
@Override
public String toString() {
return name;
}
}
}
/**
* Utility tailored for Mistral v0.3 instruct prompt format.
*/
final class MistralChatFormat implements ChatFormatInterface {
protected final MistralTokenizer tokenizer;
protected final int unknownToken;
protected final int beginOfText;
protected final int endOfText;
protected final int beginOfInstruction;
protected final int endOfInstruction;
protected final int toolCalls;
protected final int beginOfAvailableTools;
protected final int endOfAvailableTools;
protected final int beginOfToolResults;
protected final int endOfToolResults;
protected final int prefix;
protected final int middle;
protected final int suffix;
public MistralChatFormat(TokenizerInterface tokenizer) {
this.tokenizer = (MistralTokenizer)tokenizer;
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
this.unknownToken = specialTokens.get("<unk>");
this.beginOfText = specialTokens.get("<s>");
this.endOfText = specialTokens.get("</s>");
this.beginOfInstruction = specialTokens.get("[INST]");
this.endOfInstruction = specialTokens.get("[/INST]");
this.toolCalls = specialTokens.get("[TOOL_CALLS]");
this.beginOfAvailableTools = specialTokens.get("[AVAILABLE_TOOLS]");
this.endOfAvailableTools = specialTokens.get("[/AVAILABLE_TOOLS]");
this.beginOfToolResults = specialTokens.get("[TOOL_RESULTS]");
this.endOfToolResults = specialTokens.get("[/TOOL_RESULTS]");
// Only Codestral supports FIM tokens.
this.prefix = specialTokens.getOrDefault("[PREFIX]", unknownToken);
this.suffix = specialTokens.getOrDefault("[SUFFIX]", unknownToken);
this.middle = specialTokens.getOrDefault("[MIDDLE]", unknownToken);
}
@Override
public TokenizerInterface getTokenizer() {
return tokenizer;
}
@Override
public Set<Integer> getStopTokens() {
return Set.of(endOfText);
}
@Override
public int getBeginOfText() {
return beginOfText;
}
public List<Integer> encodeMessage(String userMessage, boolean addHeader, boolean addFooter) {
List<Integer> tokens = new ArrayList<>();
if (addHeader) {
tokens.add(this.beginOfInstruction);
}
if (userMessage != null) {
tokens.addAll(this.tokenizer.encodeAsList(userMessage.strip()));
}
if (addFooter) {
tokens.add(endOfInstruction);
}
return tokens;
}
public List<Integer> encodeFillInTheMiddle(String prefix, String suffix) {
List<Integer> tokens = new ArrayList<>();
tokens.add(this.suffix);
tokens.addAll(tokenizer.encode(suffix));
tokens.add(this.prefix);
tokens.addAll(tokenizer.encode(prefix));
return tokens;
}
@Override
public List<Integer> encodeHeader(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(this.beginOfInstruction);
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
tokens.add(endOfInstruction);
return tokens;
}
@Override
public List<Integer> encodeMessage(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(this.beginOfInstruction);
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
tokens.add(endOfInstruction);
return tokens;
}
@Override
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
List<Integer> tokens = new ArrayList<>();
tokens.add(beginOfText);
for (ChatFormat.Message message : dialog) {
tokens.addAll(this.encodeMessage(message));
}
//if (appendAssistantTurn) {
// // Add the start of an assistant message for the model to complete.
// tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
//}
tokens.add(endOfText);
return tokens;
}
}
/**
* Utility tailored for the Chat Markup Language (ChatML) Qwen prompt format.
*/
class ChatMLFormat implements ChatFormatInterface {
protected final TokenizerInterface tokenizer;
protected final int imStart;
protected final int endOfText;
protected final int imEnd;
public ChatMLFormat(TokenizerInterface tokenizer) {
this.tokenizer = tokenizer;
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
this.imStart = specialTokens.get("<|im_start|>");
this.imEnd = specialTokens.get("<|im_end|>");
this.endOfText = specialTokens.get("<|endoftext|>");
}
public TokenizerInterface getTokenizer() {
return tokenizer;
}
public Set<Integer> getStopTokens() {
return Set.of(imEnd, endOfText);
}
@Override
public int getBeginOfText() {
return imStart;
}
public List<Integer> encodeHeader(ChatFormat.Message message) {
List<Integer> tokens = new ArrayList<>();
tokens.add(imStart);
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
tokens.addAll(this.tokenizer.encodeAsList("\n"));
return tokens;
}
public List<Integer> encodeMessage(ChatFormat.Message message) {
List<Integer> tokens = this.encodeHeader(message);
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
tokens.add(imEnd);
return tokens;
}
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
List<Integer> tokens = new ArrayList<>();
tokens.add(imStart);
for (ChatFormat.Message message : dialog) {
tokens.addAll(this.encodeMessage(message));
}
if (appendAssistantTurn) {
// Add the start of an assistant message for the model to complete.
tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
}
return tokens;
}
}
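At the text level, the token sequences built by encodeHeader and encodeMessage in ChatMLFormat above correspond to the layout sketched below. This is a string-only analogue with a hypothetical class name. It does not tokenize anything and only mirrors the order in which those two methods emit the start marker, role, newline, stripped content and end marker.

```java
// Sketch only: ChatMLTextSketch is a hypothetical string-level analogue of the token encoders.
final class ChatMLTextSketch {
    static final String IM_START = "<|im_start|>";
    static final String IM_END = "<|im_end|>";

    /** Start marker, role name, newline (same order as encodeHeader). */
    static String header(String role) {
        return IM_START + role + "\n";
    }

    /** Header, stripped content, end marker (same order as encodeMessage). */
    static String message(String role, String content) {
        return header(role) + content.strip() + IM_END;
    }

    public static void main(String[] args) {
        System.out.println(message("system", "You are a helpful assistant."));
        System.out.println(message("user", "  Hello  "));
        System.out.println(header("assistant")); // open turn for the model to complete
    }
}
```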
/**
* Support for AOT preloading of GGUF metadata with GraalVM's Native Image.
*
* <p>
* To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf}
* to the native-image builder command. At runtime, the preloaded model will be used
* iff the specified and preloaded file names (base name) match.
*/
final class AOT {
record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {}
private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
private static PartialModel preLoadGGUF(String modelPath) {
if (modelPath == null || modelPath.isEmpty()) {
return null;
}
try {
Path path = Path.of(modelPath);
if (!Files.exists(path) || !Files.isRegularFile(path)) {
throw new IllegalArgumentException("Cannot pre-load model: " + path);
}
GGUF gguf = GGUF.loadModel(path);
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
return new PartialModel(
path.getFileName().toString(),
ModelLoader.loadModel(fileChannel, gguf, Llama3.Options.DEFAULT_MAX_TOKENS, false),
gguf.getTensorDataOffset(),
gguf.getTensorInfos()
);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Tries to reuse a compatible AOT preloaded model.
* The file name (base name) must match the preloaded file name.
* No checksum/hash is checked for performance reasons.
*/
public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
if (preLoaded == null) {
return null; // no pre-loaded model stored
}
String optionsModel = modelPath.getFileName().toString();
String preLoadedModel = preLoaded.modelFileName();
if (!Objects.equals(optionsModel, preLoadedModel)) {
// Preloaded and specified model file names didn't match.
return null;
}
Llama baseModel = preLoaded.model();
try (var timer = Timer.log("Load tensors from pre-loaded model");
var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
// Load only the tensors (mmap slices).
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
Llama.Weights weights = ModelLoader.loadGPT2Weights(tensorEntries, baseModel.configuration());
return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights);
}
}
}
/**
* Implementation of Super-Bit Locality-Sensitive Hashing.
* Super-Bit is an improvement of Random Projection LSH.
* It computes an estimation of cosine similarity.
*
* Super-Bit Locality-Sensitive Hashing
* Jianqiu Ji, Jianmin Li, Shuicheng Yan, Bo Zhang, Qi Tian
* http://papers.nips.cc/paper/4847-super-bit-locality-sensitive-hashing.pdf
* Advances in Neural Information Processing Systems 25, 2012
*
* Supported input types:
* - double[]
* @author original:Thibault Debatty
* @author Groff
*/
final class SuperBit implements java.io.Serializable, Comparable {
private static final long serialVersionUID = -1L;
private boolean[] sig;
private transient double[][] hyperplanes;
private static final int DEFAULT_CODE_LENGTH = 10000;
/**
* Initialize SuperBit algorithm.
* Super-Bit depth n must be in [1 .. d] and the number of Super-Bits l must be >= 1.
* The resulting code length is k = n * l.
* The k vectors are orthogonalized in l batches of n vectors.
*
* @param d data space dimension
* @param n Super-Bit depth, in [1 .. d]
* @param l number of Super-Bits, >= 1
*/
public SuperBit(final int d, final int n, final int l) {
this(d, n, l, new Random());
}
/**
* Initialize SuperBit algorithm.
* Super-Bit depth n must be in [1 .. d] and the number of Super-Bits l must be >= 1.
* The resulting code length is k = n * l.
* The k vectors are orthogonalized in l batches of n vectors.
*
* @param d data space dimension
* @param n Super-Bit depth, in [1 .. d]
* @param l number of Super-Bits, >= 1
* @param seed seed for the random number generator
*/
public SuperBit(final int d, final int n, final int l, final long seed) {
this(d, n, l, new Random(seed));
}
private SuperBit(final int d, final int n, final int l, final Random rand) {
if (d <= 0) {
throw new IllegalArgumentException("Dimension d must be >= 1");
}
if (n < 1 || n > d) {
throw new IllegalArgumentException(
"Super-Bit depth N must be 1 <= N <= d");
}
if (l < 1) {
throw new IllegalArgumentException(
"Number of Super-Bit L must be >= 1");
}
// Input: data space dimension d, Super-Bit depth 1 <= N <= d,
// number of Super-Bits L >= 1, resulting code length K = N * L.
// Generate a random matrix H with each element sampled independently
// from the normal distribution N(0, 1), with each column normalized
// to unit length. Denote H = [v1, v2, ..., vK].
int code_length = n * l;
double[][] v = new double[code_length][d];
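// Draw the vectors sequentially: sharing one seeded Random across parallel
// workers makes the draw order depend on thread scheduling, so a fixed seed
// would no longer select the same Gaussian samples for each vector.
for (int t = 0; t < code_length; t++) {
double[] vector = new double[d];
for (int j = 0; j < d; j++) {
vector[j] = rand.nextGaussian();
}
normalize(vector);
v[t] = vector;
}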
double[][] w = new double[code_length][d];
for (int i = 0; i <= l - 1; i++) {
for (int j = 1; j <= n; j++) {
java.lang.System.arraycopy(
v[i * n + j - 1],
0,
w[i * n + j - 1],
0,
d);
for (int k = 1; k <= (j - 1); k++) {
w[i * n + j - 1] = sub(
w[i * n + j - 1],
product(
dotProduct(
w[i * n + k - 1],
v[ i * n + j - 1]),
w[i * n + k - 1]));
}
normalize(w[i * n + j - 1]);
}
}
this.hyperplanes = w;
}
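/**
* Initialize SuperBit algorithm with a fixed seed and Super-Bit depth n = d.
* The number of Super-Bits is l = 10000 / d (integer division), so the code
* length is K = d * (10000 / d), which is at most 10000.
* The K vectors are orthogonalized in 10000/d batches of d vectors.
* The resulting mean error is 0.01
* @param d The size of the vector we are operating on; must be in [1 .. 10000],
* since a larger d yields l = 0, which the delegated constructor rejects
*/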
public SuperBit(final int d) {
this(d, d, DEFAULT_CODE_LENGTH / d, 8675309);
}
/**
* Initialize SuperBit algorithm without parameters
* (used only for serialization).
*/
public SuperBit() {}
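/**
* Compute the signature of this vector.
* @param vector the input vector, with the same dimension d as the hyperplanes
* @return one bit per hyperplane, true where the dot product with that hyperplane is >= 0
*/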
public final boolean[] signature(final double[] vector) {
boolean[] sig = new boolean[this.hyperplanes.length];
for (int i = 0; i < this.hyperplanes.length; i++) {
sig[i] = (dotProduct(this.hyperplanes[i], vector) >= 0);
}
return sig;
}
/**
* Compute the signature of the given FloatTensor and store it in this
* instance, so that it is serialized with the object.
* @param vector The target FloatTensor
*/
public final void signature(final FloatTensor vector) {
sig = new boolean[this.hyperplanes.length];
for (int i = 0; i < this.hyperplanes.length; i++) {
sig[i] = (dotProduct(this.hyperplanes[i], vector) >= 0);
}
}
public final boolean[] getSignature() {
return sig;
}
/**
* Compute the similarity between two signatures, which is also an
* estimation of the cosine similarity between the two vectors.
* Both signatures are expected to have the same length.
* @param sig1 first signature
* @param sig2 second signature
* @return estimated cosine similarity
*/
public final double similarity(final boolean[] sig1, final boolean[] sig2) {
DoubleAdder agg = new DoubleAdder(); // counts agreeing bits; safe for concurrent updates
Parallel.parallelFor(0, sig1.length, t -> {
if (sig1[t] == sig2[t]) {
agg.add(1);
}
});
double sim = agg.sum() / sig1.length; // fraction of agreeing bits
return Math.cos((1 - sim) * Math.PI); // map the disagreement fraction to an angle, then to its cosine
}
/**
* Get the hyperplane coefficients used to compute signatures.
* @return the hyperplanes, one row per signature bit
*/
public final double[][] getHyperplanes() {
return this.hyperplanes;
}
/**
* Computes the cosine similarity, computed as v1 dot v2 / (|v1| * |v2|).
* Cosine similarity of two vectors is the cosine of the angle between them.
* It ranges between -1 and +1
*
* @param v1 first vector
* @param v2 second vector, of the same length as v1
* @return the cosine of the angle between v1 and v2
*/
public static double cosineSimilarity(final double[] v1, final double[] v2) {
return dotProduct(v1, v2) / (norm(v1) * norm(v2));
}
private static double[] product(final double x, final double[] v) {
double[] r = new double[v.length];
Parallel.parallelFor(0, v.length, t -> {
r[t] = x * v[t];
});
return r;
}
private static double[] sub(final double[] a, final double[] b) {
double[] r = new double[a.length];
Parallel.parallelFor(0, a.length, t -> {
r[t] = a[t] - b[t];
});
return r;
}
private static void normalize(final double[] vector) {
final double norm = norm(vector);
Parallel.parallelFor(0, vector.length, t -> vector[t] /= norm);
}
/**
* Returns the L2 norm: sqrt(sum_i(v_i^2))
* @param v the input vector
* @return the Euclidean length of v
*/
private static double norm(final double[] v) {
DoubleAdder agg = new DoubleAdder();
Parallel.parallelFor(0, v.length, t -> agg.add(v[t] * v[t]));
return Math.sqrt(agg.sum());
}
private static double dotProduct(final double[] v1, final double[] v2) {
if (v1.length < 10_000) { // Adjust threshold based on benchmarking
return IntStream.range(0, v1.length).mapToDouble(t -> v1[t] * v2[t]).sum();
} else {
DoubleAdder agg = new DoubleAdder();
Parallel.parallelFor(0, v1.length, t -> agg.add(v1[t] * v2[t]));
return agg.sum();
}
}
private static double dotProduct(final double[] v1, final FloatTensor v2) {
if (v1.length < 10_000) { // Adjust threshold based on benchmarking
return IntStream.range(0, v1.length).mapToDouble(t -> v1[t] * v2.getFloat(t)).sum();
} else {
DoubleAdder agg = new DoubleAdder();
Parallel.parallelFor(0, v1.length, t -> agg.add(v1[t] * v2.getFloat(t)));
return agg.sum();
}
}
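@Override
public int compareTo(Object o) {
// Note: for non-empty signatures the estimated similarity never exceeds 1.0, so this
// returns 0 when both signatures are identical and a negative value otherwise.
// It acts as an equality-style check rather than a symmetric total ordering.
return Double.compare(similarity(this.sig, ((SuperBit)o).sig), 1.0);
}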
}
---------- END SOURCE ----------
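For readers unfamiliar with the SuperBit class in the source above, the following is a minimal, sequential sketch of how a sign signature is formed and how two signatures are turned into a cosine estimate. It is not the reporter's code: the class name SignatureSketch, the method names sign and estimateCosine, and the hand-picked 2-D hyperplanes are all hypothetical, and the Parallel helper, FloatTensor input and orthogonalization step from the source are deliberately left out.

```java
import java.util.Arrays;

// Hypothetical stand-alone sketch; names and data are illustrative assumptions.
final class SignatureSketch {

    // One bit per hyperplane: true when the vector lies on the non-negative side.
    static boolean[] sign(double[][] hyperplanes, double[] vector) {
        boolean[] bits = new boolean[hyperplanes.length];
        for (int i = 0; i < hyperplanes.length; i++) {
            double dot = 0.0;
            for (int j = 0; j < vector.length; j++) {
                dot += hyperplanes[i][j] * vector[j];
            }
            bits[i] = dot >= 0;
        }
        return bits;
    }

    // Fraction of agreeing bits, mapped to an angle and then to its cosine,
    // mirroring the formula used by similarity() in the source.
    static double estimateCosine(boolean[] a, boolean[] b) {
        if (a.length == 0 || a.length != b.length) {
            throw new IllegalArgumentException("signatures must be non-empty and equal in length");
        }
        int agree = 0;
        for (int i = 0; i < a.length; i++) {
            if (a[i] == b[i]) {
                agree++;
            }
        }
        double fraction = (double) agree / a.length;
        return Math.cos((1.0 - fraction) * Math.PI);
    }

    public static void main(String[] args) {
        // Hand-picked hyperplanes (an assumption) so that every run gives the same bits.
        double[][] planes = { {1, 0}, {0, 1}, {1, 1}, {1, -1} };
        boolean[] x = sign(planes, new double[] {1, 2});
        boolean[] y = sign(planes, new double[] {2, 1});
        System.out.println(Arrays.toString(x));
        System.out.println(Arrays.toString(y));
        System.out.println(estimateCosine(x, x)); // identical signatures
        System.out.println(estimateCosine(x, y)); // three of four bits agree
    }
}
```

With only four hyperplanes the estimate is coarse; the source uses a code length of roughly 10000 bits, which is what makes the estimate usable in practice.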