diff --git a/cmd/importer/main.go b/cmd/importer/main.go new file mode 100644 index 0000000000..a1baaab03c --- /dev/null +++ b/cmd/importer/main.go @@ -0,0 +1,179 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + "fmt" + "os" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "go.uber.org/zap/zapcore" + "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/cmd/importer/pod" + "sigs.k8s.io/kueue/cmd/importer/util" + "sigs.k8s.io/kueue/pkg/util/useragent" +) + +const ( + NamespaceFlag = "namespace" + NamespaceFlagShort = "n" + QueueMappingFlag = "queuemapping" + QueueLabelFlag = "queuelabel" + QPSFlag = "qps" + BurstFlag = "burst" + VerbosityFlag = "verbose" + VerboseFlagShort = "v" +) + +var ( + k8sClient client.Client + rootCmd = &cobra.Command{ + Use: "importer", + Short: "Import existing (running) objects into Kueue", + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + v, _ := cmd.Flags().GetCount(VerbosityFlag) + level := (v + 1) * -1 + ctrl.SetLogger(zap.New( + zap.UseDevMode(true), + zap.ConsoleEncoder(), + zap.Level(zapcore.Level(level)), + )) + + kubeConfig, err := ctrl.GetConfig() + if err != nil { + return err + } + if kubeConfig.UserAgent == "" { + kubeConfig.UserAgent = useragent.Default() + } + qps, err := cmd.Flags().GetFloat32(QPSFlag) + if err != nil { + return err + } + kubeConfig.QPS = qps + bust, err := cmd.Flags().GetInt(BurstFlag) + if err != nil { + return err + } + kubeConfig.Burst = bust + + if err := kueue.AddToScheme(scheme.Scheme); err != nil { + return err + } + + c, err := client.New(kubeConfig, client.Options{Scheme: scheme.Scheme}) + if err != nil { + return err + } + k8sClient = c + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + cmd.Flags().Visit(func(f *pflag.Flag) { + fmt.Println(f.Name, f.Value.String()) + }) + err := checkCmd(cmd, args) + if err != nil { + return err + } + return nil + }, + } +) + +func init() { + rootCmd.PersistentFlags().StringSliceP(NamespaceFlag, NamespaceFlagShort, nil, "target namespaces (at least one should be provided)") + rootCmd.MarkPersistentFlagRequired(NamespaceFlag) + + rootCmd.PersistentFlags().String(QueueLabelFlag, "", "label used to identify the target local queue") + rootCmd.PersistentFlags().StringToString(QueueMappingFlag, nil, "mapping from \""+QueueLabelFlag+"\" label values to local queue names") + rootCmd.MarkPersistentFlagRequired(QueueLabelFlag) + + rootCmd.PersistentFlags().Float32(QPSFlag, 50, "client QPS, as described in https://kubernetes.io/docs/reference/config-api/apiserver-eventratelimit.v1alpha1/#eventratelimit-admission-k8s-io-v1alpha1-Limit") + rootCmd.PersistentFlags().Int(BurstFlag, 50, "client Burst, https://kubernetes.io/docs/reference/config-api/apiserver-eventratelimit.v1alpha1/#eventratelimit-admission-k8s-io-v1alpha1-Limit") + + rootCmd.PersistentFlags().CountP(VerbosityFlag, VerboseFlagShort, "verbosity (specify multiple times to increase the log level)") + + rootCmd.AddCommand( + &cobra.Command{ + Use: "check", + RunE: checkCmd, + }, + &cobra.Command{ + Use: "import", + RunE: importCmd, + }, + ) +} + +func main() { + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} + +func loadMappingCache(ctx context.Context, cmd *cobra.Command) (*util.MappingCache, error) { + flags := cmd.Flags() + namespaces, err := flags.GetStringSlice(NamespaceFlag) + if err != nil { + return nil, err + } + queueLabel, err := flags.GetString(QueueLabelFlag) + if err != nil { + return nil, err + } + mapping, err := flags.GetStringToString(QueueMappingFlag) + if err != nil { + return nil, err + } + return util.LoadMappingCache(ctx, k8sClient, namespaces, queueLabel, mapping) +} + +func checkCmd(cmd *cobra.Command, _ []string) error { + log := ctrl.Log.WithName("check") + ctx := ctrl.LoggerInto(context.Background(), log) + mappingCache, err := loadMappingCache(ctx, cmd) + if err != nil { + return err + } + + return pod.Check(ctx, k8sClient, mappingCache) +} + +func importCmd(cmd *cobra.Command, _ []string) error { + log := ctrl.Log.WithName("import") + ctx := ctrl.LoggerInto(context.Background(), log) + + mappingCache, err := loadMappingCache(ctx, cmd) + if err != nil { + return err + } + + if err = pod.Check(ctx, k8sClient, mappingCache); err != nil { + return err + } + + return pod.Import(ctx, k8sClient, mappingCache) +} diff --git a/cmd/importer/pod/check.go b/cmd/importer/pod/check.go new file mode 100644 index 0000000000..68be51b89b --- /dev/null +++ b/cmd/importer/pod/check.go @@ -0,0 +1,62 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pod + +import ( + "context" + "errors" + "fmt" + + corev1 "k8s.io/api/core/v1" + apimeta "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/cmd/importer/util" +) + +func Check(ctx context.Context, c client.Client, mappingCache *util.MappingCache) error { + ch := make(chan corev1.Pod) + go util.PushPods(ctx, c, mappingCache.Namespaces, mappingCache.QueueLabel, ch) + summary := util.ConcurrentProcessPod(ch, func(p *corev1.Pod) error { + log := ctrl.LoggerFrom(ctx).WithValues("pod", klog.KObj(p)) + log.V(3).Info("Checking") + + cq, err := mappingCache.ClusterQueue(p) + if err != nil { + return err + } + + if !apimeta.IsStatusConditionTrue(cq.Status.Conditions, kueue.ClusterQueueActive) { + return fmt.Errorf("%q: %w", cq.Name, util.ErrCQNotFound) + } + + // do some additional checks like: + // - (maybe) the resources managed by the queues + + return nil + }) + + fmt.Printf("Checked %d pods\n", summary.TotalPods) + fmt.Printf("Failed %d pods\n", summary.FailedPods) + for e, pods := range summary.ErrorsForPods { + fmt.Printf("%dx: %s\n\t Observed first for pod %q\n", len(pods), e, pods[0]) + } + return errors.Join(summary.Errors...) +} diff --git a/cmd/importer/pod/check_test.go b/cmd/importer/pod/check_test.go new file mode 100644 index 0000000000..f0f77a4275 --- /dev/null +++ b/cmd/importer/pod/check_test.go @@ -0,0 +1,117 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pod + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/cmd/importer/util" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + testingpod "sigs.k8s.io/kueue/pkg/util/testingjobs/pod" +) + +const ( + testingNamespace = "ns" + testingQueueLabel = "testing.lbl" +) + +func TestCheckNamespace(t *testing.T) { + basePodWrapper := testingpod.MakePod("pod", testingNamespace). + Label(testingQueueLabel, "q1") + + baseLocalQueue := utiltesting.MakeLocalQueue("lq1", testingNamespace).ClusterQueue("cq1") + baseClusterQueue := utiltesting.MakeClusterQueue("cq1") + + cases := map[string]struct { + pods []corev1.Pod + clusteQueues []kueue.ClusterQueue + localQueues []kueue.LocalQueue + mapping map[string]string + + wantError error + }{ + "empty cluster": {}, + "no mapping": { + pods: []corev1.Pod{ + *basePodWrapper.Clone().Obj(), + }, + wantError: util.ErrNoMapping, + }, + "no local queue": { + pods: []corev1.Pod{ + *basePodWrapper.Clone().Obj(), + }, + mapping: map[string]string{ + "q1": "lq1", + }, + wantError: util.ErrLQNotFound, + }, + "no cluster queue": { + pods: []corev1.Pod{ + *basePodWrapper.Clone().Obj(), + }, + mapping: map[string]string{ + "q1": "lq1", + }, + localQueues: []kueue.LocalQueue{ + *baseLocalQueue.Obj(), + }, + wantError: util.ErrCQNotFound, + }, + "all found": { + pods: []corev1.Pod{ + *basePodWrapper.Clone().Obj(), + }, + mapping: map[string]string{ + "q1": "lq1", + }, + localQueues: []kueue.LocalQueue{ + *baseLocalQueue.Obj(), + }, + clusteQueues: []kueue.ClusterQueue{ + *baseClusterQueue.Obj(), + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + podsList := corev1.PodList{Items: tc.pods} + cqList := kueue.ClusterQueueList{Items: tc.clusteQueues} + lqList := kueue.LocalQueueList{Items: tc.localQueues} + + builder := utiltesting.NewClientBuilder() + builder = builder.WithLists(&podsList, &cqList, &lqList) + + client := builder.Build() + ctx := context.Background() + + mpc, _ := util.LoadMappingCache(ctx, client, []string{testingNamespace}, testingQueueLabel, tc.mapping) + gotErr := Check(ctx, client, mpc) + + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); diff != "" { + t.Errorf("Unexpected error (-want/+got)\n%s", diff) + } + }) + } +} diff --git a/cmd/importer/pod/import.go b/cmd/importer/pod/import.go new file mode 100644 index 0000000000..0e5714fa29 --- /dev/null +++ b/cmd/importer/pod/import.go @@ -0,0 +1,214 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pod + +import ( + "context" + "errors" + "fmt" + "time" + + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + apimeta "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/cmd/importer/util" + "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/pod" + "sigs.k8s.io/kueue/pkg/workload" +) + +func Import(ctx context.Context, c client.Client, mappingCache *util.MappingCache) error { + ch := make(chan corev1.Pod) + go util.PushPods(ctx, c, mappingCache.Namespaces, mappingCache.QueueLabel, ch) + summary := util.ConcurrentProcessPod(ch, func(p *corev1.Pod) error { + log := ctrl.LoggerFrom(ctx).WithValues("pod", klog.KObj(p)) + log.V(3).Info("Importing") + + lq, err := mappingCache.LocalQueue(p) + if err != nil { + return err + } + + oldLq, found := p.Labels[constants.QueueLabel] + if !found { + if err := addLabels(ctx, c, p, lq.Name); err != nil { + return fmt.Errorf("cannot add queue label: %w", err) + } + } else if oldLq != lq.Name { + return fmt.Errorf("another local queue name is set %q expecting %q", oldLq, lq.Name) + } + + kp := pod.FromObject(p) + // Note: the recorder is not used for single pods, we can just pass nil for now. + wl, err := kp.ConstructComposableWorkload(ctx, c, nil) + if err != nil { + return fmt.Errorf("construct workload: %w", err) + } + + priorityClassName, source, priority, err := jobframework.ExtractPriority(ctx, c, wl.Spec.PodSets, kp) + if err != nil { + return fmt.Errorf("getting priority info: %w", err) + } + + wl.Spec.PriorityClassName = priorityClassName + wl.Spec.Priority = &priority + wl.Spec.PriorityClassSource = source + + if err := createWorkload(ctx, c, wl); err != nil { + return fmt.Errorf("creating workload: %w", err) + + } + + //make its admission and update its status + + info := workload.NewInfo(wl) + cq := mappingCache.ClusterQueues[string(lq.Spec.ClusterQueue)] + admission := kueue.Admission{ + ClusterQueue: kueue.ClusterQueueReference(cq.Name), + PodSetAssignments: []kueue.PodSetAssignment{ + { + Name: info.TotalRequests[0].Name, + Flavors: make(map[corev1.ResourceName]kueue.ResourceFlavorReference), + ResourceUsage: info.TotalRequests[0].Requests.ToResourceList(), + Count: ptr.To[int32](1), + }, + }, + } + flv := cq.Spec.ResourceGroups[0].Flavors[0].Name + for r := range info.TotalRequests[0].Requests { + admission.PodSetAssignments[0].Flavors[r] = flv + } + + wl.Status.Admission = &admission + reservedCond := metav1.Condition{ + Type: kueue.WorkloadQuotaReserved, + Status: metav1.ConditionTrue, + Reason: "Imported", + Message: fmt.Sprintf("Imported into ClusterQueue %s", cq.Name), + } + apimeta.SetStatusCondition(&wl.Status.Conditions, reservedCond) + admittedCond := metav1.Condition{ + Type: kueue.WorkloadAdmitted, + Status: metav1.ConditionTrue, + Reason: "Imported", + Message: fmt.Sprintf("Imported into ClusterQueue %s", cq.Name), + } + apimeta.SetStatusCondition(&wl.Status.Conditions, admittedCond) + return admitWorkload(ctx, c, wl) + }) + + fmt.Printf("Checked %d pods\n", summary.TotalPods) + fmt.Printf("Failed %d pods\n", summary.FailedPods) + for e, pods := range summary.ErrorsForPods { + fmt.Printf("%dx: %s\n\t%v", len(pods), e, pods) + } + var errs []error + for _, e := range summary.Errors { + errs = append(errs, e) + } + return errors.Join(errs...) +} + +func checkError(err error) (retry, reload bool, timeout time.Duration) { + retry_seconds, retry := apierrors.SuggestsClientDelay(err) + if retry { + return true, false, time.Duration(retry_seconds) * time.Second + } + + if apierrors.IsConflict(err) { + return true, true, 0 + } + return false, false, 0 +} + +func addLabels(ctx context.Context, c client.Client, p *corev1.Pod, queue string) error { + p.Labels[constants.QueueLabel] = queue + p.Labels[pod.ManagedLabelKey] = pod.ManagedLabelValue + + err := c.Update(ctx, p) + retry, reload, timeouut := checkError(err) + + for retry { + if timeouut >= 0 { + select { + case <-ctx.Done(): + return errors.New("context canceled") + case <-time.After(timeouut): + } + } + if reload { + err = c.Get(ctx, client.ObjectKeyFromObject(p), p) + if err != nil { + retry, reload, timeouut = checkError(err) + continue + } + p.Labels[constants.QueueLabel] = queue + p.Labels[pod.ManagedLabelKey] = pod.ManagedLabelValue + } + err = c.Update(ctx, p) + retry, reload, timeouut = checkError(err) + } + return err +} + +func createWorkload(ctx context.Context, c client.Client, wl *kueue.Workload) error { + err := c.Create(ctx, wl) + if apierrors.IsAlreadyExists(err) { + return nil + } + retry, _, timeouut := checkError(err) + for retry { + if timeouut >= 0 { + select { + case <-ctx.Done(): + return errors.New("context canceled") + case <-time.After(timeouut): + } + } + err = c.Create(ctx, wl) + retry, _, timeouut = checkError(err) + } + return err +} + +func admitWorkload(ctx context.Context, c client.Client, wl *kueue.Workload) error { + err := workload.ApplyAdmissionStatus(ctx, c, wl, false) + if apierrors.IsAlreadyExists(err) { + return nil + } + retry, _, timeouut := checkError(err) + for retry { + if timeouut >= 0 { + select { + case <-ctx.Done(): + return errors.New("context canceled") + case <-time.After(timeouut): + } + } + err = workload.ApplyAdmissionStatus(ctx, c, wl, false) + retry, _, timeouut = checkError(err) + } + return err +} diff --git a/cmd/importer/pod/import_test.go b/cmd/importer/pod/import_test.go new file mode 100644 index 0000000000..c4a83a618e --- /dev/null +++ b/cmd/importer/pod/import_test.go @@ -0,0 +1,162 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pod + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/cmd/importer/util" + "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobs/pod" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + testingpod "sigs.k8s.io/kueue/pkg/util/testingjobs/pod" +) + +func TestImportNamespace(t *testing.T) { + basePodWrapper := testingpod.MakePod("pod", testingNamespace). + UID("pod"). + Label(testingQueueLabel, "q1"). + Image("img", nil). + Request(corev1.ResourceCPU, "1") + + baseWlWrapper := utiltesting.MakeWorkload("pod-pod-b17ab", testingNamespace). + ControllerReference(corev1.SchemeGroupVersion.WithKind("Pod"), "pod", "pod"). + Label(constants.JobUIDLabel, "pod"). + Finalizers(kueue.ResourceInUseFinalizerName). + Queue("lq1"). + PodSets(*utiltesting.MakePodSet("main", 1). + Image("img"). + Request(corev1.ResourceCPU, "1"). + Obj()). + ReserveQuota(utiltesting.MakeAdmission("cq1"). + Assignment(corev1.ResourceCPU, "f1", "1"). + Obj()). + Condition(metav1.Condition{ + Type: kueue.WorkloadQuotaReserved, + Status: metav1.ConditionTrue, + Reason: "Imported", + Message: "Imported into ClusterQueue cq1", + }). + Condition(metav1.Condition{ + Type: kueue.WorkloadAdmitted, + Status: metav1.ConditionTrue, + Reason: "Imported", + Message: "Imported into ClusterQueue cq1", + }) + + baseLocalQueue := utiltesting.MakeLocalQueue("lq1", testingNamespace).ClusterQueue("cq1") + baseClusterQueue := utiltesting.MakeClusterQueue("cq1"). + ResourceGroup( + *utiltesting.MakeFlavorQuotas("f1").Resource(corev1.ResourceCPU, "1", "0").Obj()) + + podCmpOpts := []cmp.Option{ + cmpopts.EquateEmpty(), + cmpopts.IgnoreFields(metav1.ObjectMeta{}, "ResourceVersion"), + } + + wlCmpOpts := []cmp.Option{ + cmpopts.EquateEmpty(), + cmpopts.IgnoreFields(metav1.ObjectMeta{}, "ResourceVersion"), + cmpopts.IgnoreFields(metav1.Condition{}, "ObservedGeneration", "LastTransitionTime"), + } + + cases := map[string]struct { + pods []corev1.Pod + clusteQueues []kueue.ClusterQueue + localQueues []kueue.LocalQueue + mapping map[string]string + + wantPods []corev1.Pod + wantWorkloads []kueue.Workload + wantError error + }{ + + "create one": { + pods: []corev1.Pod{ + *basePodWrapper.Clone().Obj(), + }, + mapping: map[string]string{ + "q1": "lq1", + }, + localQueues: []kueue.LocalQueue{ + *baseLocalQueue.Obj(), + }, + clusteQueues: []kueue.ClusterQueue{ + *baseClusterQueue.Obj(), + }, + + wantPods: []corev1.Pod{ + *basePodWrapper.Clone(). + Label(constants.QueueLabel, "lq1"). + Label(pod.ManagedLabelKey, pod.ManagedLabelValue). + Obj(), + }, + + wantWorkloads: []kueue.Workload{ + *baseWlWrapper.Clone().Obj(), + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + podsList := corev1.PodList{Items: tc.pods} + cqList := kueue.ClusterQueueList{Items: tc.clusteQueues} + lqList := kueue.LocalQueueList{Items: tc.localQueues} + + builder := utiltesting.NewClientBuilder(). + WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge}).WithStatusSubresource(&kueue.Workload{}). + WithLists(&podsList, &cqList, &lqList) + + client := builder.Build() + ctx := context.Background() + + mpc, _ := util.LoadMappingCache(ctx, client, []string{testingNamespace}, testingQueueLabel, tc.mapping) + gotErr := Import(ctx, client, mpc) + + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); diff != "" { + t.Errorf("Unexpected error (-want/+got)\n%s", diff) + } + + err := client.List(ctx, &podsList) + if err != nil { + t.Errorf("Unexpected list pod error: %s", err) + } + if diff := cmp.Diff(tc.wantPods, podsList.Items, podCmpOpts...); diff != "" { + t.Errorf("Unexpected pods (-want/+got)\n%s", diff) + } + + wlList := kueue.WorkloadList{} + err = client.List(ctx, &wlList) + if err != nil { + t.Errorf("Unexpected list workloads error: %s", err) + } + if diff := cmp.Diff(tc.wantWorkloads, wlList.Items, wlCmpOpts...); diff != "" { + t.Errorf("Unexpected workloads (-want/+got)\n%s", diff) + } + + }) + } +} diff --git a/cmd/importer/util/util.go b/cmd/importer/util/util.go new file mode 100644 index 0000000000..4cccc74713 --- /dev/null +++ b/cmd/importer/util/util.go @@ -0,0 +1,205 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "context" + "errors" + "fmt" + "maps" + "slices" + "sync" + + corev1 "k8s.io/api/core/v1" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + utilslices "sigs.k8s.io/kueue/pkg/util/slices" +) + +const ( + ListLength = 100 + + ConcurentProcessors = 8 +) + +var ( + ErrNoMappingKey = errors.New("no mapping key found") + ErrNoMapping = errors.New("no mapping found") + ErrLQNotFound = errors.New("localqueue not found") + ErrCQNotFound = errors.New("clusterqueue not found") + ErrCQNotActive = errors.New("clusterqueue not Active") +) + +func ListOptions(namespace, queueLabel, continueToken string) []client.ListOption { + opts := []client.ListOption{ + client.InNamespace(namespace), + client.HasLabels{queueLabel}, + client.Limit(ListLength), + client.Continue(continueToken), + } + return opts +} + +type MappingCache struct { + Namespaces []string + QueueLabel string + Mapping map[string]string + LocalQueues map[string]map[string]*kueue.LocalQueue + ClusterQueues map[string]*kueue.ClusterQueue +} + +func LoadMappingCache(ctx context.Context, c client.Client, namespaces []string, queueLabel string, mapping map[string]string) (*MappingCache, error) { + ret := MappingCache{ + Namespaces: slices.Clone(namespaces), + QueueLabel: queueLabel, + Mapping: maps.Clone(mapping), + LocalQueues: make(map[string]map[string]*kueue.LocalQueue), + } + + // get the cluster queues + cqList := &kueue.ClusterQueueList{} + if err := c.List(ctx, cqList); err != nil { + return nil, fmt.Errorf("loading cluster queues: %w", err) + } + ret.ClusterQueues = utilslices.ToRefMap(cqList.Items, func(cq *kueue.ClusterQueue) string { return cq.Name }) + + // get the local queues + for _, ns := range namespaces { + lqList := &kueue.LocalQueueList{} + if err := c.List(ctx, lqList); err != nil { + return nil, fmt.Errorf("loading local queues in namespace %s: %w", ns, err) + } + ret.LocalQueues[ns] = utilslices.ToRefMap(lqList.Items, func(lq *kueue.LocalQueue) string { return lq.Name }) + } + return &ret, nil +} + +func (mappingCache *MappingCache) LocalQueue(p *corev1.Pod) (*kueue.LocalQueue, error) { + mappingKey, found := p.Labels[mappingCache.QueueLabel] + if !found { + return nil, ErrNoMappingKey + } + + queueName, found := mappingCache.Mapping[mappingKey] + if !found { + return nil, fmt.Errorf("%s: %w", mappingKey, ErrNoMapping) + } + + nqQueues, found := mappingCache.LocalQueues[p.Namespace] + if !found { + return nil, fmt.Errorf("%s: %w", queueName, ErrLQNotFound) + } + + lq, found := nqQueues[queueName] + if !found { + return nil, fmt.Errorf("%s: %w", queueName, ErrLQNotFound) + } + return lq, nil +} + +func (mappingCache *MappingCache) ClusterQueue(p *corev1.Pod) (*kueue.ClusterQueue, error) { + lq, err := mappingCache.LocalQueue(p) + if err != nil { + return nil, err + } + queueName := string(lq.Spec.ClusterQueue) + cq, found := mappingCache.ClusterQueues[queueName] + if !found { + return nil, fmt.Errorf("cluster queue: %s: %w", queueName, ErrCQNotFound) + } + return cq, nil +} + +func PushPods(ctx context.Context, c client.Client, namespaces []string, queueLabel string, ch chan<- corev1.Pod) error { + defer close(ch) + for _, ns := range namespaces { + lst := &corev1.PodList{} + log := ctrl.LoggerFrom(ctx).WithValues("namespace", ns) + log.V(3).Info("Begin pods list") + defer log.V(3).Info("End pods list") + page := 0 + for { + err := c.List(ctx, lst, ListOptions(ns, queueLabel, lst.Continue)...) + if err != nil { + log.Error(err, "list") + return fmt.Errorf("listing pods in %s, page %d: %w", ns, page, err) + } + + for _, p := range lst.Items { + // ignore deleted or in final Phase? + ch <- p + } + + page++ + if lst.Continue == "" { + log.V(2).Info("No more pods", "pages", page) + break + } + } + } + return nil +} + +type Result struct { + Pod string + Err error +} + +type ProcessSummary struct { + TotalPods int + FailedPods int + ErrorsForPods map[string][]string + Errors []error +} + +func ConcurrentProcessPod(ch <-chan corev1.Pod, f func(p *corev1.Pod) error) ProcessSummary { + wg := sync.WaitGroup{} + resultCh := make(chan Result) + + wg.Add(ConcurentProcessors) + for i := 0; i < ConcurentProcessors; i++ { + go func() { + defer wg.Done() + for pod := range ch { + err := f(&pod) + resultCh <- Result{Pod: client.ObjectKeyFromObject(&pod).String(), Err: err} + } + }() + } + go func() { + wg.Wait() + close(resultCh) + }() + + ps := ProcessSummary{ + ErrorsForPods: make(map[string][]string), + } + for result := range resultCh { + ps.TotalPods++ + if result.Err != nil { + ps.FailedPods++ + estr := result.Err.Error() + if _, found := ps.ErrorsForPods[estr]; !found { + ps.Errors = append(ps.Errors, result.Err) + } + ps.ErrorsForPods[estr] = append(ps.ErrorsForPods[estr], result.Pod) + } + } + return ps +}