This commit is contained in:
slashtechno 2023-10-13 18:16:55 -05:00
parent 000692a7af
commit 556a90da3c
Signed by: slashtechno
GPG Key ID: 8EC1D9D9286C2B17
3 changed files with 46 additions and 26 deletions

View File

@ -57,3 +57,7 @@ ipykernel = "^6.25.2"
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.ruff]
# More than the default (88) of `black` to make comments less of a headache
line-length = 120

View File

@ -50,9 +50,11 @@ def main():
help="The scale to run the detection at, default is 0.25", help="The scale to run the detection at, default is 0.25",
) )
argparser.add_argument( argparser.add_argument(
'--view-scale', "--view-scale",
# Set it to the env VIEW_SCALE if it isn't blank, otherwise set it to 0.75 # Set it to the env VIEW_SCALE if it isn't blank, otherwise set it to 0.75
default=os.environ['VIEW_SCALE'] if 'VIEW_SCALE' in os.environ and os.environ['VIEW_SCALE'] != '' else 0.75, # noqa: E501 default=os.environ["VIEW_SCALE"]
if "VIEW_SCALE" in os.environ and os.environ["VIEW_SCALE"] != ""
else 0.75, # noqa: E501
type=float, type=float,
help="The scale to view the detection at, default is 0.75", help="The scale to view the detection at, default is 0.75",
) )
@ -188,7 +190,9 @@ def main():
# If it isn't, print a warning # If it isn't, print a warning
for obj in args.detect_object: for obj in args.detect_object:
if obj not in object_names: if obj not in object_names:
print(f"Warning: {obj} is not in the list of objects the model can detect!") print(
f"Warning: {obj} is not in the list of objects the model can detect!"
)
for box in r.boxes: for box in r.boxes:
# Get the name of the object # Get the name of the object
@ -205,13 +209,14 @@ def main():
# print("---") # print("---")
# Now do stuff (if conf > 0.5) # Now do stuff (if conf > 0.5)
if conf < args.confidence_threshold or (class_id not in args.detect_object and args.detect_object != []): if conf < args.confidence_threshold or (
class_id not in args.detect_object and args.detect_object != []
):
# If the confidence is too low # If the confidence is too low
# or if the object is not in the list of objects to detect and the list of objects to detect is not empty # or if the object is not in the list of objects to detect and the list of objects to detect is not empty
# then skip this iteration # then skip this iteration
continue continue
# Add the object to the list of objects to plot # Add the object to the list of objects to plot
plot_boxes.append( plot_boxes.append(
{ {

View File

@ -1,5 +1,7 @@
import cv2 import cv2
import numpy as np import numpy as np
def plot_label( def plot_label(
# list of dicts with each dict containing a label, x1, y1, x2, y2 # list of dicts with each dict containing a label, x1, y1, x2, y2
boxes: list = None, boxes: list = None,
@ -19,9 +21,15 @@ def plot_label(
# Image # Image
view_frame, view_frame,
# Start point # Start point
(int(thing["x1"] * (run_scale/view_scale)), int(thing["y1"] * (run_scale/view_scale))), (
int(thing["x1"] * (run_scale / view_scale)),
int(thing["y1"] * (run_scale / view_scale)),
),
# End point # End point
(int(thing["x2"] * (run_scale/view_scale)), int(thing["y2"] * (run_scale/view_scale))), (
int(thing["x2"] * (run_scale / view_scale)),
int(thing["y2"] * (run_scale / view_scale)),
),
# Color # Color
(0, 255, 0), (0, 255, 0),
# Thickness # Thickness
@ -33,7 +41,10 @@ def plot_label(
# Text # Text
thing["label"], thing["label"],
# Origin # Origin
(int(thing["x1"] * (run_scale/view_scale)), int(thing["y1"] * (run_scale/view_scale))), (
int(thing["x1"] * (run_scale / view_scale)),
int(thing["y1"] * (run_scale / view_scale)),
),
# Font # Font
font, font,
# Font Scale # Font Scale
@ -41,6 +52,6 @@ def plot_label(
# Color # Color
(0, 255, 0), (0, 255, 0),
# Thickness # Thickness
1 1,
) )
return view_frame return view_frame