diff --git a/demo/image_demo.py b/demo/image_demo.py index 015873506ce..32f969fb787 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -27,7 +27,7 @@ def main(): try: pretrained = args.checkpoint or True inferencer = ImageClassificationInferencer( - args.model, pretrained=pretrained) + args.model, device=args.device, pretrained=pretrained) except ValueError: raise ValueError( f'Unavailable model "{args.model}", you can specify find a model '