How to implement ResNet? #1076

Open
voidwings opened this issue Jun 27, 2023 · 2 comments

I tried to write ResNet-18 like this:

program.options_from_args()

from Compiler import ml

try:
    ml.set_n_threads(int(program.args[2]))
except:
    pass

import torchvision, numpy

get_data = lambda train, transform=None: torchvision.datasets.CIFAR10(
    root='/tmp', train=train, download=True, transform=transform)

data = []
for train in True, False:
    ds = get_data(train)
    # normalize to [-1,1] before input
    samples = sfix.input_tensor_via(0, ds.data / 255 * 2 - 1, binary=True)
    labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
    data += [(labels, samples)]

(training_labels, training_samples), (test_labels, test_samples) = data


import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

class RestNetBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        return F.relu(x + output)


class RestNetDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)
        output = self.conv1(x)
        out = F.relu(self.bn1(output))

        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)


class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out


net = ResNet18()

# train for a bit
transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])
ds = get_data(train=True, transform=transform)
optimizer = torch.optim.Adam(net.parameters(), amsgrad=True)
criterion = nn.CrossEntropyLoss()

for i, data in enumerate(torch.utils.data.DataLoader(ds, batch_size=128)):
    inputs, labels = data
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    ds = get_data(False, transform)
    total = correct_classified = 0
    for data in torch.utils.data.DataLoader(ds, batch_size=128):
        inputs, labels = data
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct_classified += (predicted == labels).sum().item()
    test_acc = (100 * correct_classified / total)
    print('Cleartext test accuracy of the network: %.2f %%' % test_acc)

layers = ml.layers_from_torch(net, training_samples.shape, 128, input_via=0)

optimizer = ml.SGD(layers)

optimizer.fit(
    training_samples,
    training_labels,
    epochs=int(1),
    batch_size=128,
    validation_data=(test_samples, test_labels),
    program=program,
    reset=False
)

The error is CompilerError: unknown PyTorch module: ResNet18.
It seems I can't pass a self-defined module to the Compiler. Is there any example of ResNet18 inference in MP-SPDZ?

@mkskeller (Member) commented:

The PyTorch interface only supports sequential networks, but ResNet contains an addition and thus isn't sequential. We have implemented ResNet-50 inference, which you can run as follows from the MP-SPDZ root directory:

git clone https://github.com/mkskeller/EzPC
cd EzPC/Athos/Networks/ResNet
axel -a -n 5 -c --output ./PreTrainedModel http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
cd PreTrainedModel && tar -xvzf resnet_v2_fp32_savedmodel_NHWC.tar.gz && cd ..
python3 ResNet_main.py --runPrediction True --scalingFac 12 --saveImgAndWtData True
cd ../../../..
Scripts/fixed-rep-to-float.py EzPC/Athos/Networks/ResNet/ResNet_img_input.inp
Scripts/compile-emulate.py tf EzPC/Athos/Networks/ResNet/graphDef.bin 8

You can change the last line to Scripts/compile-run.py -E <protocol> to run with an actual protocol instead of the emulator.
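
For contrast, here is a minimal sketch (not from this thread) of the kind of model the PyTorch interface does accept: a purely sequential network with no residual addition. The specific layers are an illustrative assumption; the layers_from_torch call mirrors the one in the original post.

import torch.nn as nn
from Compiler import ml

# purely sequential: every module feeds directly into the next one
seq_net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),               # 32x32 -> 16x16
    nn.Flatten(),
    nn.Linear(64 * 16 * 16, 10),
)

# same call as in the original post; this works because nothing in seq_net
# adds two intermediate results together
layers = ml.layers_from_torch(seq_net, training_samples.shape, 128, input_via=0)
optimizer = ml.SGD(layers)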


AliceNCsyuk commented Jul 4, 2023

I have implemented simple training code for residual blocks in ml.py, and I hope it may provide some inspiration. If anyone has implemented complete ResNet training, I am really looking forward to seeing it open-sourced.

class SimpleRes_Linear(DenseBase):
    # Dense layer with an additive skip connection:
    # Y = activation(X*W + b + X), which requires d_in == d_out.

    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):

        if activation == 'id':
            self.activation_layer = None
        elif activation == 'relu':
            self.activation_layer = Relu([N, d, d_out])
        elif activation == 'square':
            self.activation_layer = Square([N, d, d_out])
        else:
            raise CompilerError('activation not supported: %s', activation)

        self.N = N
        self.d_in = d_in
        self.d_out = d_out
        self.d = d
        self.activation = activation
        self.X = MultiArray([N, d, d_in], sfix)
        self.Y = MultiArray([N, d, d_out], sfix)
        self.W = Tensor([d_in, d_out], sfix)
        self.b = sfix.Array(d_out)
        back_N = min(N, self.back_batch_size)
        self.nabla_Y = MultiArray([back_N, d, d_out], sfix)
        self.nabla_X = MultiArray([back_N, d, d_in], sfix)
        self.nabla_W = sfix.Matrix(d_in, d_out)
        self.nabla_b = sfix.Array(d_out)
        self.debug = debug
        l = self.activation_layer

        if l:
            self.f_input = l.X
            l.Y = self.Y
            l.nabla_Y = self.nabla_Y
        else:
            self.f_input = self.Y

    def __repr__(self):
        return '%s(%s, %s, %s, activation=%s)' % \
            (type(self).__name__, self.N, self.d_in,
             self.d_out, repr(self.activation))

    def reset(self):
        d_in = self.d_in
        d_out = self.d_out
        r = math.sqrt(6.0 / (d_in + d_out))
        print('Initializing dense weights in [%f,%f]' % (-r, r))
        self.W.randomize(-r, r)
        self.b.assign_all(0)

    def input_from(self, player, raw=False):
        self.W.input_from(player, raw=raw)
        if self.input_bias:
            self.b.input_from(player, raw=raw)

    def compute_f_input(self, batch):
        N = len(batch)
        assert self.d == 1
        if self.input_bias:
            prod = MultiArray([N, self.d, self.d_out], sfix)
        else:
            prod = self.f_input
        max_size = program.Program.prog.budget // self.d_out

        @multithread(self.n_threads, N, max_size)
        def _(base, size):
            X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
            prod.assign_part_vector(
                X_sub.direct_mul(self.W, indices=(
                    batch.get_vector(base, size), regint.inc(self.d_in),
                    regint.inc(self.d_in), regint.inc(self.d_out))), base)

        if self.input_bias:
            if self.d_out == 1:
                @multithread(self.n_threads, N)
                def _(base, size):
                    # skip connection: add the input on top of X*W + b
                    v = prod.get_vector(base, size) \
                        + self.b.expand_to_vector(0, size) \
                        + self.X.expand_to_vector(0, size)
                    self.f_input.assign_vector(v, base)
            else:
                @for_range_multithread(self.n_threads, 100, N)
                def _(i):
                    # skip connection: add the input on top of X*W + b
                    v = prod[i].get_vector() + self.b.get_vector() \
                        + self.X[i].get_vector()
                    self.f_input[i].assign_vector(v)
        progress('f input')

    def _forward(self, batch=None):
        if batch is None:
            batch = regint.Array(self.N)
            batch.assign(regint.inc(self.N))
        self.compute_f_input(batch=batch)
        if self.activation_layer:
            self.activation_layer.forward(batch)
        if self.debug_output:
            print_ln('dense X %s', self.X.reveal_nested())
            print_ln('dense W %s', self.W.reveal_nested())
            print_ln('dense b %s', self.b.reveal_nested())
            print_ln('dense Y %s', self.Y.reveal_nested())
        if self.debug:
            limit = self.debug
            @for_range_opt(len(batch))
            def _(i):
                @for_range_opt(self.d_out)
                def _(j):
                    to_check = self.Y[i][0][j].reveal()
                    check = to_check > limit

                    @if_(check)
                    def _():
                        print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
                        print_ln('X %s', self.X[i].reveal_nested())
                        print_ln('W %s',
                                 [self.W[k][j].reveal() for k in range(self.d_in)])

    def backward(self, compute_nabla_X=True, batch=None):
        N = len(batch)
        d = self.d
        d_out = self.d_out
        X = self.X
        Y = self.Y
        W = self.W
        b = self.b
        nabla_X = self.nabla_X
        nabla_Y = self.nabla_Y
        nabla_W = self.nabla_W
        nabla_b = self.nabla_b

        if self.activation_layer:
            self.activation_layer.backward(batch)
            f_schur_Y = self.activation_layer.nabla_X
        else:
            f_schur_Y = nabla_Y

        if compute_nabla_X:
            @multithread(self.n_threads, N)
            def _(base, size):
                B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
                nabla_X.assign_part_vector(
                    B.direct_mul_trans(W, indices=(regint.inc(size, base),
                                                   regint.inc(self.d_out),
                                                   regint.inc(self.d_out),
                                                   regint.inc(self.d_in))),
                    base)
                # extra gradient term for the skip connection
                nabla_X[:] += sfix.from_sint(1)
                print('res')

            if self.print_random_update:
                print_ln('backward %s', self)
                index = regint.get_random(64) % self.nabla_X.total_size()
                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
                         index, self.nabla_X.to_array()[index].reveal())

            progress('nabla X')

        self.backward_params(f_schur_Y, batch=batch)
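
For context, a rough sketch of how a layer like this might be wired into the training loop from the first post, assuming the class above has been added to Compiler/ml.py alongside Dense. The layer list, feature sizes, and the flattened input shape are assumptions for illustration; a single residual dense layer is of course not a full ResNet.

from Compiler import ml

n_examples = 50000        # CIFAR-10 training set size
n_features = 32 * 32 * 3  # assumes training_samples are flattened to [n_examples, n_features]

layers = [
    ml.Dense(n_examples, n_features, 512, activation='relu'),
    ml.SimpleRes_Linear(n_examples, 512, 512, activation='relu'),  # the layer defined above
    ml.Dense(n_examples, 512, 10),
    ml.MultiOutput(n_examples, 10),
]

optimizer = ml.SGD(layers)
optimizer.fit(training_samples, training_labels, epochs=1, batch_size=128,
              validation_data=(test_samples, test_labels), program=program)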
