
T-Net_Pytorch

Quick Start

1. Available Architectures

Specify the model architecture with '--model ARCH_NAME' and set the output stride using '--output_stride OUTPUT_STRIDE'.

2. Load the pretrained model:

model.load_state_dict(torch.load(CKPT_PATH)['model_state'])
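The checkpoint files used here store the weights under a 'model_state' key. A minimal round-trip sketch of that format, using a small placeholder network instead of the actual T-Net model (which this repo constructs for you):

```python
import torch
import torch.nn as nn

# Placeholder network standing in for the T-Net model object.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

# Save a checkpoint in the same {'model_state': ...} format the repo uses.
ckpt_path = 'demo_ckpt.pth'
torch.save({'model_state': model.state_dict()}, ckpt_path)

# map_location='cpu' avoids device errors when the checkpoint was saved on GPU.
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state'])
model.eval()  # inference mode: disables dropout, freezes batch-norm statistics
```

Calling model.eval() after loading matters for segmentation models, since batch-norm statistics differ between training and inference.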

3. Visualize segmentation outputs:

from PIL import Image

outputs = model(images)
preds = outputs.max(1)[1].detach().cpu().numpy()  # argmax over classes -> (N, H, W)
colorized_preds = val_dst.decode_target(preds).astype('uint8')  # RGB images, (N, H, W, 3), values 0-255, numpy array
# Process the colorized segmentation maps as needed
colorized_preds = Image.fromarray(colorized_preds[0])  # first image, as a PIL Image
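The decode_target call above maps integer class predictions to RGB colors. A minimal sketch of what such a function does, using numpy fancy indexing and a hypothetical 3-entry palette (the real Cityscapes train palette has 19 classes):

```python
import numpy as np

# Hypothetical 3-class palette; the actual Cityscapes palette has 19 train classes.
PALETTE = np.array([
    [128,  64, 128],  # e.g. road
    [244,  35, 232],  # e.g. sidewalk
    [ 70,  70,  70],  # e.g. building
], dtype=np.uint8)

def decode_target(preds):
    """Map integer class predictions (N, H, W) to RGB images (N, H, W, 3)."""
    return PALETTE[preds]  # fancy indexing broadcasts the palette lookup per pixel

preds = np.array([[[0, 1], [2, 0]]])  # one 2x2 prediction map
rgb = decode_target(preds)
```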

4. Prediction

Single image:

python predict.py --input datasets/data/cityscapes/leftImg8bit/train/bremen/bremen_000000_000019_leftImg8bit.png  --dataset cityscapes --model TNet --ckpt checkpoints/best_TNet_cityscapes_os16.pth --save_val_results_to test_results

Image folder:

python predict.py --input datasets/data/cityscapes/leftImg8bit/train/bremen  --dataset cityscapes --model TNet --ckpt checkpoints/best_TNet_cityscapes_os16.pth --save_val_results_to test_results

Dependencies for the commands above can be installed with:

pip install -r requirements.txt

5. Prepare Datasets

Cityscapes

1. Download Cityscapes and extract it to 'datasets/data/cityscapes':

/datasets
    /data
        /cityscapes
            /gtFine
            /leftImg8bit
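Cityscapes pairs each image under leftImg8bit with its annotations under gtFine by filename. A sketch of that path mapping, assuming the standard Cityscapes naming convention (*_leftImg8bit.png paired with *_gtFine_labelIds.png):

```python
from pathlib import Path

def label_path_for(image_path: str) -> str:
    """Map a leftImg8bit image path to its gtFine labelIds path.

    Relies on the standard Cityscapes layout:
    .../leftImg8bit/<split>/<city>/<stem>_leftImg8bit.png
    .../gtFine/<split>/<city>/<stem>_gtFine_labelIds.png
    """
    p = Path(image_path)
    name = p.name.replace('_leftImg8bit.png', '_gtFine_labelIds.png')
    # Swap the leftImg8bit directory component for gtFine, keep everything else.
    parts = [('gtFine' if part == 'leftImg8bit' else part) for part in p.parts]
    return str(Path(*parts).with_name(name))

img = 'datasets/data/cityscapes/leftImg8bit/train/bremen/bremen_000000_000019_leftImg8bit.png'
lbl = label_path_for(img)
```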

2. Train your model on Cityscapes

python main.py --model TNet --dataset cityscapes  --gpu_id 0  --lr 0.1  --crop_size 640 --batch_size 4 --output_stride 16 --data_root 
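DeepLab-style segmentation training commonly decays the base learning rate (set above with --lr 0.1) using a "poly" schedule; whether this repo does so internally is an assumption, but the formula is a useful reference when tuning:

```python
def poly_lr(base_lr: float, cur_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Polynomial learning-rate decay: base_lr * (1 - iter/max_iter) ** power."""
    return base_lr * (1 - cur_iter / max_iter) ** power

# With --lr 0.1, the rate decays smoothly from 0.1 at iteration 0 toward 0.
schedule = [poly_lr(0.1, it, 30000) for it in (0, 15000, 30000)]
```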
