diff --git a/pyproject.toml b/pyproject.toml index ed0b5b6..63185be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ ml = [ "torch>=2.0.0", "transformers>=4.30.0", "datasets>=2.14.0", + "accelerate>=1.1.0", ] dev = [ "mypy>=1.8.0",