diff --git a/set-detect-notify/__main__.py b/set-detect-notify/__main__.py index b611370..9f166ec 100644 --- a/set-detect-notify/__main__.py +++ b/set-detect-notify/__main__.py @@ -67,6 +67,14 @@ def main(): help="The confidence threshold to use", ) + argparser.add_argument( + "--detect-object", + nargs="*", + default=[], + type=str, + help="The object(s) to detect. Must be something the model is trained to detect", + ) + stream_source = argparser.add_mutually_exclusive_group() stream_source.add_argument( "--url", @@ -176,8 +184,13 @@ def main(): # "first_detection_time": None, "last_notification_time": None, } + # Also, make sure that the objects to detect are in the list of object_names + # If it isn't, print a warning + for obj in args.detect_object: + if obj not in object_names: + print(f"Warning: {obj} is not in the list of objects the model can detect!") + for box in r.boxes: - # Get the name of the object class_id = r.names[box.cls[0].item()] # Get the coordinates of the object @@ -192,8 +205,10 @@ def main(): # print("---") # Now do stuff (if conf > 0.5) - if conf < args.confidence_threshold: - # If the confidence is less than 0.5, then SKIP!!!! + if conf < args.confidence_threshold or (class_id not in args.detect_object and args.detect_object != []): + # If the confidence is too low + # or if the object is not in the list of objects to detect and the list of objects to detect is not empty + # then skip this iteration continue