Cannot use other start method for multiprocessing #204

Open
Guriido opened this issue Feb 14, 2018 · 11 comments
Guriido commented Feb 14, 2018

This issue is not inherent to ChainerMN, so I was not sure where to submit it.
In the ImageNet training example, I cannot run the test without removing the multiprocessing.set_start_method('forkserver') call before the MultiprocessIterator creation (an addition suggested in this pull request).
When I tried it, this kind of error occurred:

[c9e2450dbf72:01311] *** Process received signal ***
[c9e2450dbf72:01311] Signal: Segmentation fault (11)
[c9e2450dbf72:01311] Signal code: Address not mapped (1)
[c9e2450dbf72:01311] Failing at address: 0x28
[c9e2450dbf72:01311] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7f4ae2fae390]
[c9e2450dbf72:01311] [ 1] /usr/local/lib/openmpi/mca_pmix_pmix112.so(+0x2cfaa)[0x7f4addfeffaa]
[c9e2450dbf72:01311] [ 2] /usr/local/lib/libopen-pal.so.20(opal_libevent2022_event_base_loop+0x7f3)[0x7f4ae2949c73]
[c9e2450dbf72:01311] [ 3] /usr/local/lib/openmpi/mca_pmix_pmix112.so(+0x2abdd)[0x7f4addfedbdd]
[c9e2450dbf72:01311] [ 4] /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba)[0x7f4ae2fa46ba]
[c9e2450dbf72:01311] [ 5] /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d)[0x7f4ae2cda3dd]
[c9e2450dbf72:01311] *** End of error message ***
[ea1ac9392fd3:00263] 1 more process has sent help message help-opal-runtime.txt / opal_init:warn-fork
[ea1ac9392fd3:00263] Set MCA parameter "orte_base_help_aggregate" to 0 to see all help / error messages
Segmentation fault (core dumped)

After further investigation, the problem seems to be that any start method other than fork leads to this error in an Open MPI environment. I also tried using get_context() instead of set_start_method(), as suggested in the Python docs, but this leads to the same result.
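
For reference, the get_context() variant I tried was essentially the following (a minimal sketch; the details may differ slightly from my actual test):

import multiprocessing as mp


def nothing():
    print('I do nothing')


if __name__ == '__main__':
    # Use a dedicated context instead of setting the global start method;
    # in my environment this crashed in the same way as set_start_method().
    ctx = mp.get_context('forkserver')
    p = ctx.Process(target=nothing)
    p.start()
    p.join()
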
I provide the following sample code to reproduce the error:

import multiprocessing as mp
import argparse


def nothing():
    print('I do nothing')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='fork',
                        help='start methods: fork forkserver spawn')
    args = parser.parse_args()

    mp.set_start_method(args.method)

    for i in range(2):
        # start a child process with the selected start method
        p = mp.Process(target=nothing, args=())
        p.start()

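(To reproduce the Open MPI context, the script is meant to be launched with the Open MPI launcher, e.g. something like mpiexec -np 2 python repro.py --method forkserver; the script name and process count here are just placeholders.)
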
The environment is as follows:

  • python 3.5.2
  • Open MPI 2.1.2 with infiniband
keisukefukuda self-assigned this Feb 14, 2018

keisukefukuda commented Feb 14, 2018

Thanks for reporting the issue.
Could you provide the version of Open MPI?
(I guess it's Open MPI, not OpenMP.)


Guriido commented Feb 14, 2018

Sorry for the typo; I will edit my message.
My version of Open MPI is 2.1.2.


keisukefukuda commented Feb 16, 2018

Hi @Guriido, thanks for the info.

I tried your test script in several environments.
My environment has IB FDR on Ubuntu 14.04
(I will try 16.04 later)
I ran the script via mpiexec on 2 physical nodes to make sure IB is used.

| Python | Open MPI | method | Result |
| --- | --- | --- | --- |
| 3.6.1 | 2.1.2 | fork | OK |
| 3.6.1 | 2.1.2 | forkserver | OK |
| 3.6.1 | 2.1.2 | spawn | OK |
| 3.5.2 | 2.1.2 | fork | OK |
| 3.5.2 | 2.1.2 | forkserver | OK |
| 3.5.2 | 2.1.2 | spawn | OK |
| anaconda3-4.3.1 (3.6.0) | 2.1.2 | fork | OK |
| anaconda3-4.3.1 (3.6.0) | 2.1.2 | forkserver | OK |
| anaconda3-4.3.1 (3.6.0) | 2.1.2 | spawn | OK |
| 3.6.1 | 3.0.0 | fork | OK |
| 3.6.1 | 3.0.0 | forkserver | OK |
| 3.6.1 | 3.0.0 | spawn | OK |

However, we are sure that the multiprocessing module sometimes causes trouble, and
that is why we put the set_start_method call in the ImageNet example.

Could you provide more information about your environment?

Thanks!


Guriido commented Feb 19, 2018

I am quite surprised that all the tests passed; maybe the problem is related to the Linux distribution?
My environment also has IB FDR, but on Ubuntu 16.04.
I am not sure what else would be relevant about my environment; do you have anything else in mind?

About IB configuration:
CA 'mlx4_0'
CA type: MT4099
Firmware version: 2.40.7000

I understand the motivation behind the use of set_start_method; it may be troublesome for me too to keep using the fork option.

keisukefukuda commented:

CA 'mlx4_0'
CA type: MT4099
Firmware version: 2.40.7000

It looks exactly the same as yours. 🤔
We had this issue in the past, so there must be some underlying problem.
(I don't know why it didn't reproduce this time in our environment.)

BTW, what is your mpi4py version?
Maybe this is relevant.
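
If it helps, a quick way to print both the mpi4py and MPI library versions (plain mpi4py, nothing ChainerMN-specific):

import mpi4py
from mpi4py import MPI

print('mpi4py:', mpi4py.__version__)
print('MPI library:', MPI.Get_library_version().strip())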


Guriido commented Feb 20, 2018

Since you could pass the test in quite a similar environment, I wanted to retest my script.
It is quite embarrassing, but I cannot reproduce the error anymore.
I work in a virtual environment, and I didn't update my Docker image, so I have no clue what is going on...

However, I still get errors when using MultiprocessIterator with the MNIST training script. (chainermn/examples/mnist only has a SerialIterator implementation, so I provide sample code below.)

What concerns me is that, in this case, the call to set_start_method throws errors.
That is, even a call to set_start_method('fork') gives errors, with slightly different output:

[8aa66b4816ed:04152] *** Process received signal ***
[8aa66b4816ed:04152] Signal: Segmentation fault (11)
[8aa66b4816ed:04152] Signal code: Address not mapped (1)
[8aa66b4816ed:04152] Failing at address: 0x7fd6c8ce9970
[8aa66b4816ed:04152] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7fd799d51390]
[8aa66b4816ed:04152] [ 1] python3(PyDict_GetItem+0x7e)[0x58fa3e]
[8aa66b4816ed:04152] [ 2] python3(_PyObject_GenericGetAttrWithDict+0xb8)[0x57f688]
[8aa66b4816ed:04152] [ 3] python3(PyEval_EvalFrameEx+0x44a)[0x523eea]
[8aa66b4816ed:04152] [ 4] python3(PyEval_EvalFrameEx+0x49c4)[0x528464]
[8aa66b4816ed:04152] [ 5] python3(PyEval_EvalCodeEx+0x13b)[0x52dd1b]
[8aa66b4816ed:04152] [ 6] python3[0x4e3153]
[8aa66b4816ed:04152] [ 7] python3(PyObject_Call+0x47)[0x5b5da7]
[8aa66b4816ed:04152] [ 8] python3[0x4f40de]
[8aa66b4816ed:04152] [ 9] python3(PyObject_Call+0x47)[0x5b5da7]
[8aa66b4816ed:04152] [10] python3[0x54e829]
[8aa66b4816ed:04152] [11] python3[0x55835c]
[8aa66b4816ed:04152] [12] python3(PyObject_Call+0x47)[0x5b5da7]
[8aa66b4816ed:04152] [13] python3(PyEval_EvalFrameEx+0x4eb6)[0x528956]
[8aa66b4816ed:04152] [14] python3(PyEval_EvalFrameEx+0x49c4)[0x528464]
[8aa66b4816ed:04152] [15] python3(PyEval_EvalFrameEx+0x49c4)[0x528464]
[8aa66b4816ed:04152] [16] python3(PyEval_EvalFrameEx+0x49c4)[0x528464]
[8aa66b4816ed:04152] [17] python3(PyEval_EvalCodeEx+0x13b)[0x52dd1b]
[8aa66b4816ed:04152] [18] python3[0x4e3267]
[8aa66b4816ed:04152] [19] python3(PyObject_Call+0x47)[0x5b5da7]
[8aa66b4816ed:04152] [20] python3[0x4f40de]
[8aa66b4816ed:04152] [21] python3(PyObject_Call+0x47)[0x5b5da7]
[8aa66b4816ed:04152] [22] python3[0x54e829]
[8aa66b4816ed:04152] [23] python3[0x55835c]
[8aa66b4816ed:04152] [24] python3(PyObject_Call+0x47)[0x5b5da7]
[8aa66b4816ed:04152] [25] python3(PyEval_EvalFrameEx+0x4eb6)[0x528956]
[8aa66b4816ed:04152] [26] python3[0x52cf19]
[8aa66b4816ed:04152] [27] python3(PyEval_EvalFrameEx+0x509f)[0x528b3f]
[8aa66b4816ed:04152] [28] python3(PyEval_EvalFrameEx+0x49c4)[0x528464]
[8aa66b4816ed:04152] [29] python3(PyEval_EvalCodeEx+0x13b)[0x52dd1b]
[8aa66b4816ed:04152] *** End of error message ***

About module versions:

  • mpi4py 3.0.0
  • chainermn 1.2.0
  • chainer 4.0.0b3

MNIST + MultiprocessIterator sample code:

#!/usr/bin/env python
from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from mpi4py import MPI

import chainermn
import multiprocessing


class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            # the size of the inputs to each layer will be inferred
            l1=L.Linear(784, n_units),  # n_in -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),  # n_units -> n_out
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPU')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--loaderjob', '-j', type=int, default=0,
                        help='Number of parallel data loading processes')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.gpu:
        if args.communicator == 'naive':
            print("Error: 'naive' communicator does not support GPU.\n")
            exit(-1)
        comm = chainermn.create_communicator(args.communicator)
        device = comm.intra_rank
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        comm = chainermn.create_communicator('naive')
        device = -1

    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train, test = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)


    if args.loaderjob == 0:
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
        test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                     repeat=False, shuffle=False)
    else:
        multiprocessing.set_start_method('forkserver')
        train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize, n_processes=args.loaderjob)
        test_iter = chainer.iterators.MultiprocessIterator(test, args.batchsize, repeat=False, n_processes=args.loaderjob)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.dump_graph('main/loss'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()


if __name__ == '__main__':
    main()

keisukefukuda commented:

Thanks for the code.

I tried your MNIST code with the following environment and start_method, and they were all OK.

(screenshot of the test result table — all combinations OK)

They all use mpi4py 3.0.0, with 2 processes on physically different nodes (so IB is used).

However, again, we know that there is (or at least was) an issue.
We still occasionally see a crash in a Docker environment.

Hmm...


Guriido commented Feb 21, 2018

Thanks again for all the tests.
From the beginning, all my tests were done in a Docker environment. From your remark, I suppose that using Docker may lead to these issues.

Also, the crash occurs even when all the processes run on the same node (so IB is presumably not the cause).

What kind of crash do you experience in your Docker environment?
Also, what image do you use? (the one from #71?)


keisukefukuda commented Feb 21, 2018

We use in-house Docker images, so I'm afraid I cannot give you information on that, unfortunately.

However, we have experienced the "start_method" issue since the early stages of ChainerMN development, and we were not using Docker back then. So the issue can happen in non-Docker environments anyway.

Let me investigate the issue a bit more with ImageNet and other code.


Guriido commented Apr 26, 2018

I have made further tests concerning the issues above.
First, a partial workaround I found is to avoid MPI communication when scattering datasets. This is possible if all workers have a local copy of the dataset (or use an NFS setup) and share the same seed. Hereafter, I refer to this procedure as "no_comm" (see scatter_dataset_no_comm in the test script below).
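
As a quick sanity check of that assumption: the permutation depends only on the seed, so every worker computes the same order independently. A minimal sketch:

import numpy

seed = 0
order_a = numpy.random.RandomState(seed).permutation(10)
order_b = numpy.random.RandomState(seed).permutation(10)
# Same seed -> identical permutation, so each worker can slice out its own
# subset without any MPI communication.
assert (order_a == order_b).all()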

I made all the tests in the same environment as above, with an altered version of the MNIST training script.

Here is the report of the tests:

| scatter | train_iterator | test_iterator | start_method | result | error_type |
| --- | --- | --- | --- | --- | --- |
| classic | Serial | Serial | / | OK | / |
| classic | MultiProcess | Serial | fork | failure | A |
| classic | MultiProcess | Serial | forkserver | failure | B |
| classic | MultiProcess | Serial | spawn | failure | B |
| no_comm | Serial | Serial | / | OK | / |
| no_comm | MultiProcess | Serial | fork | OK | / |
| no_comm | MultiProcess | Serial | forkserver | failure | B |
| no_comm | MultiProcess | Serial | spawn | failure | B |
| classic | MultiProcess | MultiProcess | fork | failure | A |
| classic | MultiProcess | MultiProcess | forkserver | failure | B |
| classic | MultiProcess | MultiProcess | spawn | failure | B |
| no_comm | MultiProcess | MultiProcess | fork | failure | A |
| no_comm | MultiProcess | MultiProcess | forkserver | failure | B |
| no_comm | MultiProcess | MultiProcess | spawn | failure | B |

I give the test script, and samples of what I called errors A and B, below.
As you can see, my workaround only allows using MultiprocessIterator for the train iterator (and it relies on the fork start method, so, as you said before, it may be unsafe). I hope this will help the investigation.
(By the way, the number of sub-processes per worker for MultiprocessIterator was 3.)

NB: Even with a serial iterator for the test dataset, I get pretty decent speed when training ImageNet on a cluster, but this doesn't solve the problem.
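
For reference, each row of the table corresponds to a run along the lines of mpiexec -np <N> python test_script.py --loaderjob 3 --method forkserver --scatter_no_comm --test_mp, with the options toggled per row; the script name and process count are placeholders, but the options are the ones defined in the test script below.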

test_script

#!/usr/bin/env python
from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from mpi4py import MPI
import chainermn
import multiprocessing
import numpy


class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            # the size of the inputs to each layer will be inferred
            l1=L.Linear(784, n_units),  # n_in -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),  # n_units -> n_out
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=60,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--method', type=str, default='',
                        help='start methods: fork forkserver spawn')
    parser.add_argument('--scatter_no_comm', action='store_true', help='do not use chainermn builtin scatter function')
    parser.add_argument('--loaderjob', '-j', type=int, default=0, help='Number of parallel data loading processes')
    parser.add_argument('--double_buffering', action='store_true', help='improves speed')
    parser.add_argument('--shuffle_seed', type=int, default=0, help='Seed used to shuffle dataset during scattering')
    parser.add_argument('--test_mp', action='store_true', help='use MultiProcess iterator for test set when also used for train')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.double_buffering:
        args.communicator = 'pure_nccl'

    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank

    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    initial_lr = 0.1

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=initial_lr, momentum=0.9), comm, double_buffering=args.double_buffering)
    optimizer.setup(model)

    if args.scatter_no_comm:
        train, test = chainer.datasets.get_mnist()

        train = scatter_dataset_no_comm(train, comm, shuffle=True, seed=args.shuffle_seed)
        test = scatter_dataset_no_comm(test, comm, shuffle=True, seed=args.shuffle_seed)

    else:
        # Split and distribute the dataset. Only worker 0 loads the whole dataset.
        # Datasets of worker 0 are evenly split and distributed to all workers.
        if comm.rank == 0:
            train, test = chainer.datasets.get_mnist()
        else:
            train = None
            test = None

        train = chainermn.scatter_dataset(train, comm, shuffle=True)
        test = chainermn.scatter_dataset(test, comm)

    if args.loaderjob == 0:
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize, shuffle=False)
        test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                     repeat=False, shuffle=False)

    else:
        if args.method != '':
            multiprocessing.set_start_method(args.method) 
        train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize, shuffle=False, n_processes=args.loaderjob)
        if args.test_mp:
            test_iter = chainer.iterators.MultiprocessIterator(test, args.batchsize, repeat=False, shuffle=False,
                                                               n_processes=args.loaderjob)
        else:
            test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                         repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        # trainer.extend(extensions.dump_graph('main/loss'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.observe_lr(), trigger=(1, 'epoch'))
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time', 'lr']))
        trainer.extend(extensions.ProgressBar())

    trainer.run()


def scatter_dataset_no_comm(dataset, comm, shuffle=False, seed=0):
    """Scatter the given dataset to the workers in the communicator.

    This function does not use MPI communication.

    The dataset on every worker has to be the same (e.g., shared via a file system like NFS),
    and the seed has to be the same across all processes.

    The dataset is split to sub datasets of almost equal sizes and scattered
    to workers. To create a sub dataset, ``chainer.datasets.SubDataset`` is
    used.
    Args:
        dataset: A dataset (e.g., ``list``, ``numpy.ndarray``,
            ``chainer.datasets.TupleDataset``, ...).
        comm: ChainerMN communicator or MPI4py communicator.
        shuffle (bool): If ``True``, the order of examples is shuffled
            before being scattered.
        seed (int): Seed the generator used for the permutation of indexes.
            If an integer being convertible to 32 bit unsigned integers is
            specified, it is guaranteed that each sample
            in the given dataset always belongs to a specific subset.
            If ``None``, the permutation is changed randomly.
    Returns:
        Scattered dataset.
    """

    if hasattr(comm, 'mpi_comm'):
        comm = comm.mpi_comm
    assert hasattr(comm, 'send')
    assert hasattr(comm, 'recv')

    order = None
    n_total_samples = len(dataset)
    if shuffle:
        order = numpy.random.RandomState(seed).permutation(
            n_total_samples)

    n_sub_samples = (n_total_samples + comm.size - 1) // comm.size

    b = n_total_samples * comm.rank // comm.size
    e = b + n_sub_samples

    return chainer.datasets.SubDataset(dataset, b, e, order)


if __name__ == '__main__':
    main()

error A

[abd391b31812:00880] *** Process received signal ***
[abd391b31812:00880] Signal: Segmentation fault (11)
[abd391b31812:00880] Signal code: Address not mapped (1)
[abd391b31812:00880] Failing at address: 0x2e7b468
[abd391b31812:00880] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7f9b9e123390]
[abd391b31812:00880] [ 1] python3(PyDict_GetItem+0x7e)[0x5a38de]
[abd391b31812:00880] [ 2] python3(_PyObject_GenericGetAttrWithDict+0xb8)[0x593a98]
[abd391b31812:00880] [ 3] python3(PyEval_EvalFrameEx+0x44d)[0x53712d]
[abd391b31812:00880] [ 4] python3(PyEval_EvalCodeEx+0x88a)[0x5416ea]
[abd391b31812:00880] [ 5] python3[0x4ebd23]
[abd391b31812:00880] [ 6] python3(PyObject_Call+0x47)[0x5c1797]
[abd391b31812:00880] [ 7] python3[0x4fb9ce]
[abd391b31812:00880] [ 8] python3(PyObject_Call+0x47)[0x5c1797]
[abd391b31812:00880] [ 9] python3[0x584716]
[abd391b31812:00880] [10] python3(PyEval_EvalFrameEx+0xc36)[0x537916]
[abd391b31812:00880] [11] python3(PyEval_EvalFrameEx+0x4b04)[0x53b7e4]
[abd391b31812:00880] [12] python3(PyEval_EvalCodeEx+0x88a)[0x5416ea]
[abd391b31812:00880] [13] python3[0x4ebd23]
[abd391b31812:00880] [14] python3(PyObject_Call+0x47)[0x5c1797]
[abd391b31812:00880] [15] python3[0x4fb9ce]
[abd391b31812:00880] [16] python3(PyObject_Call+0x47)[0x5c1797]
[abd391b31812:00880] [17] python3[0x584716]
[abd391b31812:00880] [18] python3(PyEval_EvalFrameEx+0xc36)[0x537916]
[abd391b31812:00880] [19] python3(PyEval_EvalCodeEx+0x13b)[0x540f9b]
[abd391b31812:00880] [20] python3[0x4ebd23]
[abd391b31812:00880] [21] python3(PyObject_Call+0x47)[0x5c1797]
[abd391b31812:00880] [22] python3[0x53645f]
[abd391b31812:00880] [23] python3[0x5b7994]
[abd391b31812:00880] [24] python3[0x5b7fbc]
[abd391b31812:00880] [25] python3[0x57f03c]
[abd391b31812:00880] [26] python3(PyObject_Call+0x47)[0x5c1797]
[abd391b31812:00880] [27] python3(PyEval_EvalFrameEx+0x4ec6)[0x53bba6]
[abd391b31812:00880] [28] python3(PyEval_EvalCodeEx+0x13b)[0x540f9b]
[abd391b31812:00880] [29] python3[0x4ebe37]
[abd391b31812:00880] *** End of error message ***

error B

[97857bc4a063:01156] *** Process received signal ***
[97857bc4a063:01156] Signal: Segmentation fault (11)
[97857bc4a063:01156] Signal code: Address not mapped (1)
[97857bc4a063:01156] Failing at address: 0x28
[97857bc4a063:01156] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7fcaf906b390]
[97857bc4a063:01156] [ 1] /usr/local/lib/openmpi/mca_pmix_pmix112.so(+0x2cfaa)[0x7fcaf40aefaa]
[97857bc4a063:01156] [ 2] /usr/local/lib/libopen-pal.so.20(opal_libevent2022_event_base_loop+0x7f3)[0x7fcaf8a06c73]
[97857bc4a063:01156] [ 3] /usr/local/lib/openmpi/mca_pmix_pmix112.so(+0x2abdd)[0x7fcaf40acbdd]
[97857bc4a063:01156] [ 4] /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba)[0x7fcaf90616ba]
[97857bc4a063:01156] [ 5] /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d)[0x7fcaf8d9741d]
[97857bc4a063:01156] *** End of error message ***
[3ab9b565c75b:01170] [[27224,0],8] usock_peer_send_blocking: send() to socket 30 failed: Broken pipe (32)
[3ab9b565c75b:01170] [[27224,0],8] ORTE_ERROR_LOG: Unreachable in file oob_usock_connection.c at line 316
[3ab9b565c75b:01170] [[27224,0],8]-[[27224,1],7] usock_peer_accept: usock_peer_send_connect_ack failed
/usr/lib/python3.5/multiprocessing/semaphore_tracker.py:129: UserWarning: semaphore_tracker: There appear to be 4 leaked semaphores to clean up at shutdown
len(cache))
/usr/lib/python3.5/multiprocessing/semaphore_tracker.py:129: UserWarning: semaphore_tracker: There appear to be 4 leaked semaphores to clean up at shutdown
len(cache))
[4fadac19fc27:01182] [[27224,0],2] usock_peer_send_blocking: send() to socket 30 failed: Broken pipe (32)
[4fadac19fc27:01182] [[27224,0],2] ORTE_ERROR_LOG: Unreachable in file oob_usock_connection.c at line 316
[4fadac19fc27:01182] [[27224,0],2]-[[27224,1],1] usock_peer_accept: usock_peer_send_connect_ack failed
[f73e3c82020e:01194] [[27224,0],3] usock_peer_send_blocking: send() to socket 30 failed: Broken pipe (32)
[f73e3c82020e:01194] [[27224,0],3] ORTE_ERROR_LOG: Unreachable in file oob_usock_connection.c at line 316
[f73e3c82020e:01194] [[27224,0],3]-[[27224,1],2] usock_peer_accept: usock_peer_send_connect_ack failed
/usr/lib/python3.5/multiprocessing/semaphore_tracker.py:129: UserWarning: semaphore_tracker: There appear to be 4 leaked semaphores to clean up at shutdown
len(cache))

keisukefukuda commented:

Thanks for the detailed report! Now I've got a Docker environment internally and will try to reproduce your problem.
