The result of replacing the background with U-2-Net is not as expected; can anyone help? #377

Open
weiweiwang opened this issue Feb 2, 2024 · 0 comments


Env

model: u2net, downloaded from the link in the repository README: https://pan.baidu.com/s/1WjwyEwDiaUjBbx_QxcXBwQ

Test input

image: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/gJT2UhwWmcHgM6ep.jpg
background: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/DCAqTPHo7Advhmrv.jpg
current result: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/RFD5jxgohUXX4fAX.jpg

Test Method

  1. I modified u2net_test.py (as shown below) and placed the input image in the folder test_data/test_images.
  2. The script writes the images with the replaced background to the folder test_data/rbg.
def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp
    # model_name = 'u2netp'  # u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)

        output_filename = save_output(img_name_list[i_test], pred, prediction_dir)
        
        ########## modification comes here ##########
        predict = pred
        predict = predict.squeeze()
        predict_np = predict.cpu().data.numpy()
        input_image_file_path = img_name_list[i_test]
        image = cv2.imread(input_image_file_path)
        background = cv2.imread("test_data/bg-05.jpg")
        background = cv2.resize(background, (image.shape[1], image.shape[0]))
        im = Image.fromarray(predict_np * 255).convert('RGB')
        imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
        data = np.asarray(imo, dtype="int32")
        condition = data > 0.98 * 255
        output_image = np.where(condition, image, background)
        cv2.imwrite(f"test_data/rbg/{os.path.basename(input_image_file_path)}", output_image)

        del d1, d2, d3, d4, d5, d6, d7

I'm new to this field. Could anyone offer some help? Thanks a lot!
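
One thing that may be worth trying, offered as a guess rather than a confirmed diagnosis: the modification above keeps a pixel only when the predicted saliency exceeds 0.98, so soft regions such as hair and object edges are replaced entirely by the background. A common alternative is to use the prediction as a per-pixel alpha and blend the two images. The sketch below uses only NumPy; the function name `composite_with_soft_mask` and its parameters are hypothetical and are not part of the U-2-Net repository.

```python
import numpy as np


def composite_with_soft_mask(image, background, mask):
    """Blend image over background, using mask as a per-pixel alpha in [0, 1].

    image, background: uint8 arrays of shape (H, W, 3) with identical shapes.
    mask: float array of shape (H, W), for example the normalized prediction
          resized to the image size (an assumption about how it is prepared).
    """
    if image.shape != background.shape:
        raise ValueError("image and background must have the same shape")
    if mask.shape != image.shape[:2]:
        raise ValueError("mask must match the image height and width")

    # Clamp to [0, 1] and add a channel axis so the mask broadcasts over BGR.
    alpha = np.clip(mask.astype(np.float32), 0.0, 1.0)[..., np.newaxis]

    # Weighted sum: alpha = 1 keeps the image, alpha = 0 keeps the background.
    blended = alpha * image.astype(np.float32) \
        + (1.0 - alpha) * background.astype(np.float32)
    return np.rint(blended).astype(np.uint8)
```

In the script above, this would presumably replace the `np.where` step: resize `predict_np` to `(image.shape[1], image.shape[0])` with `cv2.resize`, then call `composite_with_soft_mask(image, background, mask)` and write the result with `cv2.imwrite`. Whether this fixes the posted result depends on the quality of the predicted mask itself, which cannot be judged from the code alone.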
