|
|
@@ -69,7 +69,8 @@ class TrainJittorConfig: |
|
|
|
shuffle: bool = True |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("driver,device", [("jittor", None)]) |
|
|
|
@pytest.mark.parametrize("driver", ["jittor"]) |
|
|
|
@pytest.mark.parametrize("device", ["cpu", 1]) |
|
|
|
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) |
|
|
|
@pytest.mark.jittor |
|
|
|
def test_trainer_jittor( |
|
|
@@ -134,4 +135,5 @@ def test_trainer_jittor( |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
# test_trainer_jittor("jittor", None, [RichCallback(100)]) |
|
|
|
# test_trainer_jittor("jittor", 1, [RichCallback(100)]) |
|
|
|
pytest.main(['test_trainer_jittor.py']) # 只运行此模块 |