Unconfigured TaskSchedule.getDevice() and TornadoCoreRuntime.getDefaultDevice() are different #170

Open
vsilaev opened this issue Dec 24, 2021 · 1 comment


@vsilaev
Contributor

vsilaev commented Dec 24, 2021

TornadoCoreRuntime.getDefaultDevice() always returns the default device of driver 0.

A newly created TaskSchedule without any per-task configuration returns from getDevice() the device configured in the bin/sdk/etc/tornado.properties file (properties tornado.driver and tornado.device). This result may differ from what TornadoCoreRuntime.getDefaultDevice() returns, which can be quite confusing.

Moreover, this is the source of the issue "Failed tests if PTX device is selected as default and both OpenCL and PTX drivers are available", which requires the additional command-line property -J"-Dtornado.unittests.device=2:0" as a workaround. However, that property only helps with the unit tests, not with the shipped examples.

Suggestion: TornadoCoreRuntime.getDefaultDevice() should return the same device as an unconfigured TaskSchedule, i.e., the one defined in the bin/sdk/etc/tornado.properties file.
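A minimal plain-Java sketch of the suggested behavior, assuming the configuration file uses ordinary `key=value` lines readable by `java.util.Properties`. `DefaultDeviceSketch`, `ToyRuntime`, and the device names are hypothetical stand-ins, not TornadoVM's API; the only point illustrated is that the default device is resolved from the same `tornado.driver`/`tornado.device` keys an unconfigured TaskSchedule uses, each falling back to `0` when absent.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Properties;

public class DefaultDeviceSketch {

    /** Hypothetical stand-in for the runtime: each driver is a list of device names. */
    static final class ToyRuntime {
        private final List<List<String>> drivers;
        private final Properties config;

        ToyRuntime(List<List<String>> drivers, Properties config) {
            this.drivers = drivers;
            this.config = config;
        }

        /** Behavior described in the issue: always driver 0's default (modelled as its first device). */
        String getDefaultDeviceIgnoringConfig() {
            return drivers.get(0).get(0);
        }

        /** Suggested behavior: honor tornado.driver / tornado.device, each defaulting to 0. */
        String getDefaultDevice() {
            int driver = Integer.parseInt(config.getProperty("tornado.driver", "0").trim());
            int device = Integer.parseInt(config.getProperty("tornado.device", "0").trim());
            return drivers.get(driver).get(device);
        }
    }

    /** Parses key=value text the same way a .properties file is read. */
    static Properties load(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e); // not expected for an in-memory reader
        }
        return props;
    }

    public static void main(String[] args) {
        // Hypothetical layout: driver 0 has one device, driver 1 has two
        List<List<String>> drivers = List.of(
                List.of("driver0-dev0"),
                List.of("driver1-dev0", "driver1-dev1"));
        ToyRuntime rt = new ToyRuntime(drivers, load("tornado.driver=1\ntornado.device=1\n"));
        System.out.println(rt.getDefaultDeviceIgnoringConfig()); // driver0-dev0
        System.out.println(rt.getDefaultDevice());               // driver1-dev1
    }
}
```

With both methods driven by the same configuration, the two answers can only diverge in the "ignoring config" variant, which is the mismatch this issue reports.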

@vsilaev
Contributor Author

vsilaev commented Jan 12, 2023

I've fixed the issue locally by modifying TornadoTestBase.
The changes are:

  1. Separate the @BeforeClass and @Before behavior (previously both were mixed under @Before).
  2. Changed the default driver/device to take into account the properties from etc/tornado.properties
    protected static Tuple2<Integer, Integer> getDriverAndDeviceIndex() {
        /** WAS:
        String driverAndDevice = System.getProperty("tornado.unittests.device", "0:0");
        **/
        /** NEW: default to the driver/device configured in etc/tornado.properties **/
        String defaultDriverAndDevice = TornadoRuntime.getProperty("tornado.driver", "0") + ":" + TornadoRuntime.getProperty("tornado.device", "0");
        String driverAndDevice = System.getProperty("tornado.unittests.device", defaultDriverAndDevice);

        String[] propertyValues = driverAndDevice.split(":");
        return new Tuple2<>(Integer.parseInt(propertyValues[0]), Integer.parseInt(propertyValues[1]));
    }
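  3. Fixed an error when applying the driver/device index (previously in @Before): setDefaultDevice is now called on the selected driver (driverIndex) rather than on driver 0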
    @Before
    public void before() {
        for (int i = 0; i < TornadoRuntime.getTornadoRuntime().getNumDrivers(); i++) {
            final TornadoDriver driver = TornadoRuntime.getTornadoRuntime().getDriver(i);
            for (int j = 0; j < driver.getDeviceCount(); j++) {
                driver.getDevice(j).reset();
            }
        }

        /*
         * Virtual Device execution assumes an environment with a single device.
         * Therefore, there is no need to change the device even if a different device
         * is set through the 'tornado.unittests.device' property
         */
        if (!wasDeviceInspected && !getVirtualDeviceEnabled()) {
            Tuple2<Integer, Integer> pairDriverDevice = getDriverAndDeviceIndex();
            int driverIndex = pairDriverDevice.f0();
            if (driverIndex != 0) {
                // We swap the default driver for the selected one
                TornadoRuntime.getTornadoRuntime().setDefaultDriver(driverIndex);
            }
            int deviceIndex = pairDriverDevice.f1();
            if (deviceIndex != 0) {
                // We swap the default device for the selected one
                // CHANGED: use the selected driver index instead of 0
                TornadoDriver driver = TornadoRuntime.getTornadoRuntime().getDriver(driverIndex /** WAS 0 **/ );
                driver.setDefaultDevice(deviceIndex);
            }
            wasDeviceInspected = true;
        }
    }
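The lookup order in `getDriverAndDeviceIndex()` can be tried in isolation with a plain-Java sketch, assuming the same precedence as the modified method: an explicit `tornado.unittests.device` value wins, otherwise `tornado.driver:tornado.device` from the configuration, otherwise `0:0`. The class and method names are hypothetical, the override is passed in as an `Optional` (e.g. `Optional.ofNullable(System.getProperty("tornado.unittests.device"))`) so it can be tested without system properties, and the format check is an addition the original method does not perform.

```java
import java.util.Map;
import java.util.Optional;

public final class DriverDeviceSelection {

    /** Driver/device index pair (plays the role of Tuple2<Integer, Integer>). */
    public static final class Selection {
        public final int driver;
        public final int device;

        Selection(int driver, int device) {
            this.driver = driver;
            this.device = device;
        }

        @Override
        public String toString() {
            return driver + ":" + device;
        }
    }

    /** Precedence: unit-test override, then tornado.driver:tornado.device from config, then "0:0". */
    public static Selection resolve(Optional<String> unitTestOverride, Map<String, String> config) {
        String fromConfig = config.getOrDefault("tornado.driver", "0") + ":" + config.getOrDefault("tornado.device", "0");
        return parse(unitTestOverride.orElse(fromConfig));
    }

    /** Parses "driver:device", failing fast with a descriptive message on malformed input. */
    public static Selection parse(String value) {
        String[] parts = value.trim().split(":");
        if (parts.length != 2) {
            throw new IllegalArgumentException("Expected <driver>:<device>, got: " + value);
        }
        int driver = Integer.parseInt(parts[0].trim());
        int device = Integer.parseInt(parts[1].trim());
        if (driver < 0 || device < 0) {
            throw new IllegalArgumentException("Indices must be non-negative: " + value);
        }
        return new Selection(driver, device);
    }

    public static void main(String[] args) {
        Map<String, String> config = Map.of("tornado.driver", "1", "tornado.device", "0");
        System.out.println(resolve(Optional.empty(), config));   // 1:0
        System.out.println(resolve(Optional.of("2:0"), config)); // 2:0
        System.out.println(resolve(Optional.empty(), Map.of())); // 0:0
    }
}
```

Rejecting a value such as `1` up front gives a clear message instead of the `ArrayIndexOutOfBoundsException` that `propertyValues[1]` would otherwise throw.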

The full modified source is below:

/*
 * Copyright (c) 2013-2020, 2022 APT Group, Department of Computer Science,
 * The University of Manchester.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package uk.ac.manchester.tornado.unittests.common;

import org.junit.Before;
import org.junit.BeforeClass;

import uk.ac.manchester.tornado.api.TornadoDriver;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntime;
import uk.ac.manchester.tornado.unittests.tools.TornadoHelper;

public abstract class TornadoTestBase {

    protected static boolean wasDeviceInspected = false;

    @BeforeClass
    public static void setup() {
        /*
         * Virtual Device execution assumes an environment with a single device.
         * Therefore, there is no need to change the device even if a different device
         * is set through the 'tornado.unittests.device' property
         */
        if (!wasDeviceInspected && !getVirtualDeviceEnabled()) {
            Tuple2<Integer, Integer> pairDriverDevice = getDriverAndDeviceIndex();
            int driverIndex = pairDriverDevice.f0();
            if (driverIndex != 0) {
                // We swap the default driver for the selected one
                TornadoRuntime.getTornadoRuntime().setDefaultDriver(driverIndex);
            }
            int deviceIndex = pairDriverDevice.f1();
            if (deviceIndex != 0) {
                // We swap the default device for the selected one
                TornadoDriver driver = TornadoRuntime.getTornadoRuntime().getDriver(driverIndex);
                driver.setDefaultDevice(deviceIndex);
            }
            wasDeviceInspected = true;
        }
    }

    @Before
    public void before() {
        for (int i = 0; i < TornadoRuntime.getTornadoRuntime().getNumDrivers(); i++) {
            final TornadoDriver driver = TornadoRuntime.getTornadoRuntime().getDriver(i);
            for (int j = 0; j < driver.getDeviceCount(); j++) {
                driver.getDevice(j).reset();
            }
        }
    }

    private static boolean getVirtualDeviceEnabled() {
        return Boolean.parseBoolean(System.getProperty("tornado.virtual.device", "False"));
    }

    protected static Tuple2<Integer, Integer> getDriverAndDeviceIndex() {
        String defaultDriverAndDevice = TornadoRuntime.getProperty("tornado.driver", "0") + ":" + TornadoRuntime.getProperty("tornado.device", "0");
        String driverAndDevice = System.getProperty("tornado.unittests.device", defaultDriverAndDevice);
        String[] propertyValues = driverAndDevice.split(":");
        return new Tuple2<>(Integer.parseInt(propertyValues[0]), Integer.parseInt(propertyValues[1]));
    }

    public void assertNotBackend(TornadoVMBackendType backend) {
        int driverIndex = TornadoRuntime.getTornadoRuntime().getDefaultDevice().getDriverIndex();
        if (TornadoRuntime.getTornadoRuntime().getBackendType(driverIndex) == backend) {
            switch (backend) {
                case PTX:
                    throw new TornadoVMPTXNotSupported("Test not supported for the PTX backend");
                case OPENCL:
                    throw new TornadoVMOpenCLNotSupported("Test not supported for the OpenCL backend");
                case SPIRV:
                    throw new TornadoVMSPIRVNotSupported("Test not supported for the SPIR-V backend");
                case JAVA:
                case VIRTUAL:
                default:
                    return;
            }
        }
    }

    public void assertNotBackendOptimization(TornadoVMBackendType backend) {
        if (!TornadoHelper.OPTIMIZE_LOAD_STORE_SPIRV) {
            return;
        }
        int driverIndex = TornadoRuntime.getTornadoRuntime().getDefaultDevice().getDriverIndex();
        if (TornadoRuntime.getTornadoRuntime().getBackendType(driverIndex) == backend) {
            if (backend == TornadoVMBackendType.SPIRV) {
                throw new SPIRVOptNotSupported("Test not supported for the optimized SPIR-V BACKEND");
            }
        }
    }

    private void assertIfNeeded(TornadoDevice device, int driverIndex) {
        TornadoVMBackendType backendType = TornadoRuntime.getTornadoRuntime().getDriver(driverIndex).getBackendType();
        if (backendType != TornadoVMBackendType.OPENCL || !device.isSPIRVSupported()) {
            assertNotBackend(TornadoVMBackendType.OPENCL);
        }
    }

    protected TornadoDevice checkSPIRVSupport() {
        TornadoDevice device = null;
        TornadoTestBase.Tuple2<Integer, Integer> driverAndDeviceIndex = getDriverAndDeviceIndex();
        if (driverAndDeviceIndex.f0() != 0) {
            // If another driver has been selected for testing, TornadoVM has swapped it
            // into position 0. In this case, we use the selected device instead of
            // searching for a device with SPIR-V support.
            device = TornadoRuntime.getTornadoRuntime().getDriver(0).getDevice(0);
            assertIfNeeded(device, 0);
        } else {
            // Check if SPIRV is supported. We search for a suitable device to run on
            int numDrivers = TornadoRuntime.getTornadoRuntime().getNumDrivers();
            for (int driverIndex = 0; driverIndex < numDrivers; driverIndex++) {
                if (TornadoRuntime.getTornadoRuntime().getDriver(driverIndex).getBackendType() != TornadoVMBackendType.PTX) {
                    int maxDevices = TornadoRuntime.getTornadoRuntime().getDriver(driverIndex).getDeviceCount();
                    for (int i = 0; i < maxDevices; i++) {
                        // Search for the device with SPIRV Support
                        device = TornadoRuntime.getTornadoRuntime().getDriver(driverIndex).getDevice(i);
                        if (device.isSPIRVSupported()) {
                            return device;
                        }
                    }
                }
            }
        }
        return device;
    }

    protected static class Tuple2<T0, T1> {
        T0 t0;
        T1 t1;

        public Tuple2(T0 first, T1 second) {
            this.t0 = first;
            this.t1 = second;
        }

        public T0 f0() {
            return t0;
        }

        public T1 f1() {
            return t1;
        }
    }

}
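In the listing, `setup()` runs once per test class, yet device selection should happen only once per JVM. Assuming `setDefaultDriver` swaps the selected driver into position 0 (as the comment in `checkSPIRVSupport()` suggests), a second swap by the next test class would move it back; the static `wasDeviceInspected` flag prevents that. The plain-Java sketch below models this with a hypothetical driver list. Using `AtomicBoolean.compareAndSet` instead of a plain `boolean` is an assumption that only matters if test classes run concurrently, which the original does not state.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

public final class OneTimeDriverSwap {

    /** Hypothetical driver table; position 0 acts as the default driver. */
    static final List<String> DRIVERS = new ArrayList<>(List.of("driver0", "driver1", "driver2"));

    /** Plays the role of wasDeviceInspected; compareAndSet keeps it correct under concurrent setup. */
    private static final AtomicBoolean SELECTED = new AtomicBoolean(false);

    /** Swap-based selection: moves drivers[index] to position 0 (applying it twice undoes it). */
    static void swapToFront(List<String> drivers, int index) {
        String tmp = drivers.get(0);
        drivers.set(0, drivers.get(index));
        drivers.set(index, tmp);
    }

    /** Called from every test class's one-time setup; only the first call may swap. */
    static void selectOnce(int index) {
        if (SELECTED.compareAndSet(false, true) && index != 0) {
            swapToFront(DRIVERS, index);
        }
    }

    public static void main(String[] args) {
        selectOnce(2); // setup of the first test class
        selectOnce(2); // setup of a second test class: guarded, no second swap
        System.out.println(DRIVERS); // [driver2, driver1, driver0]

        List<String> unguarded = new ArrayList<>(List.of("driver0", "driver1", "driver2"));
        swapToFront(unguarded, 2);
        swapToFront(unguarded, 2);
        System.out.println(unguarded); // [driver0, driver1, driver2]
    }
}
```

The unguarded case shows why the flag has to be static: it must survive across test classes, not just across the test methods of one class.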
