jrudolph/llama2.scala: Inference Llama 2 in Scala with AVX2 kernels in C

A Scala 2 port of Andrej Karpathy's llama2.c.
This is a Scala port of Andrej Karpathy's llama2.c, a bare-bones implementation for running inference of models with a Llama-like transformer-based LLM architecture.

The code expects `tokenizer.bin` and `stories15M.bin` in the current directory.
This started as a port of the original code to pure Scala. Later, more high-level abstractions and low-level C kernels with AVX2 intrinsics were added to speed up matrix multiplication.
Features:
- Two implementations of the model architecture are available:
  - `Llama2SimpleTransformer`, a direct port of the original C code
  - `Llama2TensorTransformer`, which uses a `Tensor` abstraction to make the code more readable (see the sketch after this list)
- Different matrix multiplication kernels:
  - `ScalaMathImplementation` (direct port of llama2.c in pure Scala)
  - `AVX2MathImplementation` (using JNI to call kernels written with C SIMD intrinsics; see the binding sketch after this list)
- Models are memory-mapped to avoid loading them into the JVM heap (see the mapping sketch after this list)
- Quantization modes (see the quantization sketch after this list):
  - use the weights as given in the model
  - Q8: quantize weights after loading to 8 bits (all but rmsnorm)
  - Q4: quantize weights after loading to 4 bits (all but rmsnorm)
- Multi-threading: `AVX2MathImplementation` uses OpenMP
- Support for loading ggml models
  - only weights in Q4_0 and FP32 are supported (see the dequantization sketch after this list)
- scala-native support (mostly broken right now)
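For illustration, here is a minimal sketch of the kind of `Tensor`-style abstraction the second transformer implementation is built around. The class and method names are hypothetical, not the project's actual API:

```scala
import java.nio.FloatBuffer

// Hypothetical, simplified 2D tensor over a FloatBuffer (e.g. a memory-mapped weight matrix).
// The real Tensor abstraction in this repo is richer; this only shows the idea.
final class Tensor2D(data: FloatBuffer, val rows: Int, val cols: Int) {
  def apply(i: Int, j: Int): Float = data.get(i * cols + j)

  // out(i) = sum_j this(i, j) * x(j) -- the matrix-vector product that dominates inference time
  def matMulInto(x: Array[Float], out: Array[Float]): Unit = {
    var i = 0
    while (i < rows) {
      var sum = 0f
      var j = 0
      while (j < cols) {
        sum += apply(i, j) * x(j)
        j += 1
      }
      out(i) = sum
      i += 1
    }
  }
}
```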
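The AVX2 path calls into C via JNI. A rough sketch of what a Scala-side binding could look like; the library and method names are invented for illustration:

```scala
import java.nio.ByteBuffer

// Hypothetical JNI binding; the real project wires this up differently.
object NativeMath {
  System.loadLibrary("llama2math") // loads the shared library built from the C AVX2 kernels

  // Implemented in C with AVX2 intrinsics. Passing a direct (memory-mapped) ByteBuffer
  // lets the kernel read the weights without copying them onto the JVM heap.
  @native def matMulFp32(weights: ByteBuffer, x: Array[Float], out: Array[Float],
                         rows: Int, cols: Int): Unit
}
```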
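Memory-mapping the checkpoint keeps multi-gigabyte weight files off the JVM heap. A sketch using plain `java.nio` (the actual file layout and chunking logic are omitted):

```scala
import java.nio.{ByteOrder, FloatBuffer}
import java.nio.channels.FileChannel
import java.nio.file.{Path, StandardOpenOption}

object MappedWeights {
  // Maps a checkpoint file read-only; pages are faulted in lazily by the OS instead of
  // being copied onto the heap. Note: a single MappedByteBuffer is limited to 2 GB,
  // so larger checkpoints (e.g. the 7B model) have to be mapped in several chunks.
  def map(path: Path): FloatBuffer = {
    val channel = FileChannel.open(path, StandardOpenOption.READ)
    try {
      val buffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size())
      buffer.order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer()
    } finally channel.close()
  }
}
```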
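Q8 quantization stores weights as signed bytes plus a per-block scale. A simplified sketch of the idea; the block size and layout are illustrative, not necessarily what the project uses:

```scala
object Q8Quantization {
  // One block of weights stored as int8 plus a shared scale; dequantized value = q * scale.
  final case class Q8Block(scale: Float, quantized: Array[Byte])

  // Illustrative only; the real code keeps rmsnorm weights unquantized.
  def quantize(values: Array[Float], blockSize: Int = 32): Array[Q8Block] =
    values.grouped(blockSize).map { block =>
      val maxAbs = block.map(v => math.abs(v)).max
      val scale  = if (maxAbs == 0f) 1f else maxAbs / 127f
      Q8Block(scale, block.map(v => math.round(v / scale).toByte))
    }.toArray
}
```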
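For the ggml loader, Q4_0 weights come in blocks of 32 values: a half-precision scale followed by 16 bytes that hold two 4-bit quants each. A sketch of dequantizing one such block, assuming the common ggml nibble layout:

```scala
object Q4_0 {
  // One ggml Q4_0 block covers 32 weights: an fp16 scale plus 16 bytes of packed 4-bit quants.
  // `scale` is assumed to already be converted from the stored fp16 value; low nibbles fill
  // the first half of the block, high nibbles the second half.
  def dequantizeBlock(qs: Array[Byte], scale: Float): Array[Float] = {
    require(qs.length == 16, "a Q4_0 block carries 16 bytes of quants")
    val out = new Array[Float](32)
    var j = 0
    while (j < 16) {
      out(j)      = ((qs(j) & 0x0f) - 8) * scale
      out(j + 16) = (((qs(j) & 0xff) >>> 4) - 8) * scale
      j += 1
    }
    out
  }
}
```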
Current numbers were run with version 08c65d04 on my AMD Ryzen 7 4800H laptop with GraalVM JDK 17.
Implementations:
- Scala = `Llama2TensorTransformer` with `ScalaMathImplementation`
- native-avx2 = `Llama2TensorTransformer` with `AVX2MathImplementation` (using JNI to call kernels written with C SIMD intrinsics)
- llama2.c = as of 94a3a5e0
- llama.cpp = as of d783f798
- scala-native = using scala-native 0.4.14 with the `Llama2SimpleTransformer` implementation
Notes:
- Approximate speedups are:
  - pure Scala -> AVX2: > 10x
  - FP32 -> Q8/Q4 (in Scala): same speed
  - FP32 -> Q8 (AVX2): ~ 2x
  - Q8 -> Q4 (AVX2) on one thread: same speed
  - Q4 1 thread -> 6 threads on small models: ~ 2x
  - Q4 1 thread -> 6 threads on large models: ~ 3x
- The pure Scala mode on GraalVM JDK 17 is only competitive with a llama2.c version compiled with `-O3`. Using `-Ofast` on C already makes a huge difference. It would be interesting to see the exact differences between JIT-compiled code and gcc output with `-Ofast`. Not sure if something like `-Ofast` (using less strict FP math) is possible on the JVM.
- Using C kernels with SIMD intrinsics (mostly adapted from llama.cpp) and calling them via JNI from Scala makes a huge difference. It is easy to do locally but, of course, much harder to do in a portable way.
- As expected, quantization gives another boost. Interesting that it is more pronounced when multi-threading is enabled.
- OMP-based multithreading is simple to use from C and helps a lot. Scaling is not perfect, with benefits diminishing sharply after using more than 6 (of 8) threads.
- Multithreading is interesting, as the task units are quite small (one matrix multiplication) and overheads can be significant.
- Quantization only helps together with SIMD optimization: only SIMD gives access to byte-wise (int8) operations, and decreasing the data type size increases the number of lanes per vector by the same factor. It is unclear why going from 32-bit to 8-bit gives only a 2x speedup even though 4x more operations can run in parallel (see the small arithmetic sketch below). One explanation could be that more instructions are needed because of the added complexity of quantization.
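As a back-of-the-envelope illustration of the lane argument (AVX2 registers are 256 bits wide):

```scala
object LaneMath extends App {
  val avx2RegisterBits = 256
  val fp32Lanes = avx2RegisterBits / 32 //  8 float operations per vector instruction
  val int8Lanes = avx2RegisterBits / 8  // 32 byte-wise operations per vector instruction
  // 4x more lanes, yet the measured FP32 -> Q8 speedup is ~2x, so the extra
  // unpacking/scaling work that quantized kernels need seems to eat part of the gain.
  println(s"fp32 lanes: $fp32Lanes, int8 lanes: $int8Lanes")
}
```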
| Model | Quantization | Implementation | Threads | tok / s |
|---|---|---|---|---|
| stories15M.bin | Q4 | native-avx2 | 1 | 494 |
| stories15M.bin | Q4 | native-avx2 | 6 | 931 |
| stories15M.bin | Q4 | Scala | 1 | 65 |
| stories15M.bin | Q8 | native-avx2 | 1 | 533 |
| stories15M.bin | Q8 | native-avx2 | 6 | 800 |
| stories15M.bin | Q8 | Scala | 1 | 57 |
| stories15M.bin | none | native-avx2 | 1 | 374 |
| stories15M.bin | none | native-avx2 | 6 | 677 |
| stories15M.bin | none | Scala | 1 | 66 |
| stories15M.bin | none | scala-native vanilla | 1 | 14 |
| stories15M.bin | none | scala-native (native mmaps) | 1 | 50 |
| stories42M.bin | Q4 | native-avx2 | 1 | 223 |
| stories42M.bin | Q4 | native-avx2 | 6 | 497 |
| stories42M.bin | Q4 | Scala | 1 | 24 |
| stories42M.bin | Q8 | native-avx2 | 1 | 229 |
| stories42M.bin | Q8 | native-avx2 | 6 | 407 |
| stories42M.bin | Q8 | Scala | 1 | 22 |
| stories42M.bin | none | native-avx2 | 1 | 137 |
| stories42M.bin | none | native-avx2 | 6 | 243 |
| stories42M.bin | none | Scala | 1 | 24 |
| stories42M.bin | none | llama2.c / run | 1 | 21 |
| stories42M.bin | none | llama2.c / runfast | 1 | 69 |
| stories42M.bin | none | llama2.c / runomp | 1 | 98 |
| stories42M.bin | none | llama2.c / runomp | 6 | 195 |
| stories110M.bin | Q4 | native-avx2 | 1 | 95 |
| stories110M.bin | Q4 | native-avx2 | 6 | 239 |
| stories110M.bin | Q4 | Scala | 1 | 9.6 |
| stories110M.bin | Q8 | native-avx2 | 1 | 99 |
| stories110M.bin | Q8 | native-avx2 | 6 | 183 |
| stories110M.bin | Q8 | Scala | 1 | 8.4 |
| stories110M.bin | none | native-avx2 | 1 | 50 |
| stories110M.bin | none | native-avx2 | 6 | 85 |
| stories110M.bin | none | Scala | 1 | 8.9 |
| stories110M.bin | none | llama2.c / runomp | 6 | 77 |
| llama2_7b.bin | Q4 | native-avx2 | 1 | 2.0 |
| llama2_7b.bin | Q4 | native-avx2 | 6 | 6.5 |
| llama2_7b.bin | Q4 | Scala | 1 | 0.16 |
| llama2_7b.bin | Q8 | native-avx2 | 1 | 1.9 |
| llama2_7b.bin | Q8 | native-avx2 | 6 | 4.46 |
| llama2_7b.bin | Q8 | Scala | 1 | 0.14 |
| llama-2-7b.ggmlv3.q4_0.bin | as provided | native-avx2 | 1 | 1.66 |
| llama-2-7b.ggmlv3.q4_0.bin | as provided | native-avx2 | 6 | 6.71 |
| llama-2-7b.ggmlv3.q4_0.bin | as provided | Scala | 1 | 0.13 |
| llama-2-7b.ggmlv3.q4_0.bin | as provided | llama.cpp | 1 | 2.0 |
| llama-2-7b.ggmlv3.q4_0.bin | as provided | llama.cpp | 6 | 8.1 |
License: MIT