diff --git a/pkg/controller/mpi_job_controller.go b/pkg/controller/mpi_job_controller.go index 0958668a..cc09eaa8 100644 --- a/pkg/controller/mpi_job_controller.go +++ b/pkg/controller/mpi_job_controller.go @@ -1534,12 +1534,13 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev case kubeflow.MPIImplementationMPICH: container.Env = append(container.Env, mpichEnvVars...) } - - container.Env = append(container.Env, - // We overwrite these environment variables so that users will not - // be mistakenly using GPU resources for launcher due to potential - // issues with scheduler/container technologies. - nvidiaDisableEnvVars...) + if !ptr.Deref(mpiJob.Spec.RunLauncherAsWorker, false) { + container.Env = append(container.Env, + // We overwrite these environment variables so that users will not + // be mistakenly using GPU resources for launcher due to potential + // issues with scheduler/container technologies. + nvidiaDisableEnvVars...) + } c.setupSSHOnPod(&podTemplate.Spec, mpiJob) // Submit a warning event if the user specifies restart policy for diff --git a/pkg/controller/mpi_job_controller_test.go b/pkg/controller/mpi_job_controller_test.go index e4d9e1a3..4e057b8d 100644 --- a/pkg/controller/mpi_job_controller_test.go +++ b/pkg/controller/mpi_job_controller_test.go @@ -1327,6 +1327,135 @@ func TestNewLauncherAndWorker(t *testing.T) { }, }, }, + "launcher-as-worker": { + job: kubeflow.MPIJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "bar", + }, + Spec: kubeflow.MPIJobSpec{ + RunLauncherAsWorker: ptr.To(true), + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*kubeflow.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{}}, + }, + }, + }, + kubeflow.MPIReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{}}, + }, + }, + }, + }, + }, + }, + wantLauncher: batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo-launcher", + Namespace: "bar", + Labels: map[string]string{ + "app": "foo", + }, + }, + Spec: batchv1.JobSpec{ + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + kubeflow.OperatorNameLabel: kubeflow.OperatorName, + kubeflow.JobNameLabel: "foo", + kubeflow.JobRoleLabel: "launcher", + }, + }, + Spec: corev1.PodSpec{ + Hostname: "foo-launcher", + Subdomain: "foo", + RestartPolicy: corev1.RestartPolicyOnFailure, + Containers: []corev1.Container{ + { + Env: joinEnvVars( + launcherEnvVars, + ompiEnvVars, + corev1.EnvVar{Name: openMPISlotsEnv, Value: "1"}, + ), + VolumeMounts: []corev1.VolumeMount{ + {Name: "ssh-auth", MountPath: "/root/.ssh"}, + {Name: "mpi-job-config", MountPath: "/etc/mpi"}, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: "ssh-auth", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + DefaultMode: newInt32(0600), + SecretName: "foo-ssh", + Items: sshVolumeItems, + }, + }, + }, + { + Name: "mpi-job-config", + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: "foo-config", + }, + Items: configVolumeItems, + }, + }, + }, + }, + }, + }, + }, + }, + wantWorker: corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo-worker-0", + Namespace: "bar", + Labels: map[string]string{ + kubeflow.OperatorNameLabel: kubeflow.OperatorName, + kubeflow.JobNameLabel: "foo", + kubeflow.JobRoleLabel: "worker", + kubeflow.ReplicaIndexLabel: "0", + }, + }, + Spec: corev1.PodSpec{ + Hostname: "foo-worker-0", + Subdomain: "foo", + DNSConfig: &corev1.PodDNSConfig{ + Searches: []string{"foo.bar.svc.cluster.local"}, + }, + RestartPolicy: corev1.RestartPolicyNever, + Containers: []corev1.Container{ + { + Command: []string{"/usr/sbin/sshd", "-De"}, + VolumeMounts: []corev1.VolumeMount{ + {Name: "ssh-auth", MountPath: "/root/.ssh"}, + }, + Env: workerEnvVars, + }, + }, + Volumes: []corev1.Volume{ + { + Name: "ssh-auth", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + DefaultMode: newInt32(0600), + SecretName: "foo-ssh", + Items: sshVolumeItems, + }, + }, + }, + }, + }, + }, + }, "overrides": { job: kubeflow.MPIJob{ ObjectMeta: metav1.ObjectMeta{