How do I modify the PyTorch version to test on my own dataset? #43

Open
chajnoven opened this issue Nov 1, 2019 · 2 comments
Labels
Custom data (how to use custom data), Tuning (how to improve model accuracy)

Comments

@chajnoven

For example, my test set is in the `test` folder under the current directory.

@ypwhs
Owner

ypwhs commented Nov 1, 2019

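import random

import torch
from captcha.image import ImageCaptcha
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor, to_pil_image

class CaptchaDataset(Dataset):
    def __init__(self, characters, length, width, height, input_length, label_length):
        super(CaptchaDataset, self).__init__()
        self.characters = characters
        self.length = length  # number of samples reported by __len__
        self.width = width
        self.height = height
        self.input_length = input_length  # sequence length of the model output
        self.label_length = label_length  # number of characters per captcha
        self.n_class = len(characters)
        self.generator = ImageCaptcha(width=width, height=height)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # characters[0] is never sampled (presumably reserved for the CTC blank)
        random_str = ''.join([random.choice(self.characters[1:]) for j in range(self.label_length)])
        # Render the string as a captcha image and convert it to a CxHxW float tensor
        image = to_tensor(self.generator.generate_image(random_str))
        # Encode each character as its index in `characters`
        target = torch.tensor([self.characters.find(x) for x in random_str], dtype=torch.long)
        input_length = torch.full(size=(1, ), fill_value=self.input_length, dtype=torch.long)
        target_length = torch.full(size=(1, ), fill_value=self.label_length, dtype=torch.long)
        return image, target, input_length, target_length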

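Implement a custom Dataset class like the one above: change `__init__` to take a folder plus whatever other information is needed, then read the image in `__getitem__`. As long as the code below runs, you can use that dataset for testing.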

dataset = CaptchaDataset(characters, 1, width, height, n_input_length, n_len)
image, target, input_length, label_length = dataset[0]
print(''.join([characters[x] for x in target]), input_length, label_length)
to_pil_image(image)

@chajnoven
Author

chajnoven commented Nov 1, 2019 via email
