Skip to content

Commit

Permalink
Fix X3D input shape based on version of the model
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Jan 12, 2024
1 parent c90f18d commit 93fda7a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def initialize_wrapper(dataset_name, model_name,
model_wrapper = TimmModelWrapper(model_name)
elif dataset_name == "coco":
os.environ['COCO_PATH'] = dataset_path
if model_name in ["yolov8n"]:
if model_name in ["yolov8n", "yolov8s"]:
from models.detection.coco import UltralyticsModelWrapper
model_wrapper = UltralyticsModelWrapper(model_name)
elif dataset_name == "camvid":
Expand Down
14 changes: 7 additions & 7 deletions models/action_recognition/ucf101.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import onnx
import os
import torch

import onnx
import torch
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmengine.runner import load_checkpoint
from models.base import TorchModelWrapper
from mmengine.runner import Runner, load_checkpoint
from onnxsim import simplify

from models.base import TorchModelWrapper


class MmactionModelWrapper(TorchModelWrapper):
# https://github.com/open-mmlab/mmaction

def __init__(self, model_name, input_size=(1, 1, 3, 16, 256, 256), num_classes=101):
self.input_size = input_size
def __init__(self, model_name, num_classes=101):
self.input_size = (1, 1, 3, 16, 256, 256) if model_name == "x3d_m" else (1, 1, 3, 13, 182, 182)
self.num_classes = num_classes
super().__init__(model_name)

Expand Down

0 comments on commit 93fda7a

Please sign in to comment.