From 1fecb128f501881dbe7036ef6735a1e05ecabe94 Mon Sep 17 00:00:00 2001 From: Traian Schiau <55734665+trasc@users.noreply.github.com> Date: Thu, 21 Mar 2024 18:11:37 +0200 Subject: [PATCH] [importer] Extend mapping capabilities. (#1878) * [cmd/importer] Advanced mapping. * [cmd/importer] Add skip. * Review Remarks --- cmd/importer/README.md | 61 ++++++++- cmd/importer/main.go | 25 ++-- cmd/importer/pod/check.go | 31 +++-- cmd/importer/pod/check_test.go | 52 ++++++-- cmd/importer/pod/import.go | 25 ++-- cmd/importer/pod/import_test.go | 28 +++- cmd/importer/util/util.go | 142 ++++++++++++++------ cmd/importer/util/util_test.go | 143 +++++++++++++++++++++ go.mod | 4 +- test/integration/importer/importer_test.go | 2 +- 10 files changed, 415 insertions(+), 98 deletions(-) create mode 100644 cmd/importer/util/util_test.go diff --git a/cmd/importer/README.md b/cmd/importer/README.md index 5667392e5e..4e984fc0bc 100644 --- a/cmd/importer/README.md +++ b/cmd/importer/README.md @@ -20,14 +20,47 @@ go build -C cmd/importer/ -o $(pwd)/bin/importer The command runs against the systems default kubectl configuration. Check the [kubectl documentation](https://kubernetes.io/docs/tasks/access-application-cluster/configure-access-multiple-clusters/) to learn more about how to Configure Access to Multiple Clusters. +### Check The importer will perform following checks: - At least one `namespace` is provided. -- The label key (`queuelabel`) providing the queue mapping is provided. -- A mapping from one of the encountered `queuelabel` values to an existing LocalQueue exists. +- For every Pod a mapping to a LocalQueue is available. +- The target LocalQueue exists. - The LocalQueues involved in the import are using an existing ClusterQueue. - The ClusterQueues involved have at least one ResourceGroup using an existing ResourceFlavor. This ResourceFlavor is used when the importer creates the admission for the created workloads. +The are two ways the mapping from a pod to a LocalQueue can be specified: + +#### Simple mapping + +It's done by specifying a label name and any number of = a command line arguments eg. `--queuelabel=src.lbl --queuemapping=src-val=user-queue,src-val2=user-queue2`. + +#### Advanced mapping + +It's done providing an yaml mapping file name as `--queuemapping-file` argument, it's expected content being: + +```yaml +- match: + labels: + src.lbl: src-val + toLocalQueue: user-queue +- match: + priorityClassName: p-class + labels: + src.lbl: src-val2 + src2.lbl: src2-val + toLocalQueue: user-queue2 +- match: + labels: + src.lbl: src-val3 + skip: true +``` + +- During the mapping, if the match rule has no `priorityClassName` the `priorityClassName` of the pod will be ignored, if more than one `label: value` pairs are provided, all of them should match. +- The rules are evaluated in order. +- `skip: true` can be used to ignore the pods matching a rule. + +### Import After which, if `--dry-run=false` was specified, for each selected Pod the importer will: - Update the Pod's Kueue related labels. @@ -36,7 +69,31 @@ After which, if `--dry-run=false` was specified, for each selected Pod the impor ### Example +#### Simple mapping + ```bash ./bin/importer import -n ns1,ns2 --queuelabel=src.lbl --queuemapping=src-val=user-queue,src-val2=user-queue2 --dry-run=false ``` Will import all the pods in namespace `ns1` or `ns2` having the label `src.lbl` set in LocalQueues `user-queue` or `user-queue2` depending on `src.lbl` value. + + #### Advanced mapping + With mapping file: +```yaml +- match: + labels: + src.lbl: src-val + toLocalQueue: user-queue +- match: + priorityClassName: p-class + labels: + src.lbl: src-val2 + src2.lbl: src2-val + toLocalQueue: user-queue2 +``` + +```bash +./bin/importer import -n ns1,ns2 --queuemapping-file= --dry-run=false +``` + + Will import all the pods in namespace `ns1` or `ns2` having the label `src.lbl` set to `src-val` in LocalQueue `user-queue` regardless of their priorityClassName and those with `src.lbl==src-val2` ,`src2.lbl==src2-val` and `priorityClassName==p-class`in `user-queue2`. + diff --git a/cmd/importer/main.go b/cmd/importer/main.go index ef44dd2258..9544d1a420 100644 --- a/cmd/importer/main.go +++ b/cmd/importer/main.go @@ -20,12 +20,10 @@ import ( "context" "errors" "fmt" - "maps" "os" "github.com/spf13/cobra" "go.uber.org/zap/zapcore" - "gopkg.in/yaml.v2" "k8s.io/apimachinery/pkg/util/validation" "k8s.io/client-go/kubernetes/scheme" ctrl "sigs.k8s.io/controller-runtime" @@ -82,8 +80,10 @@ func setFlags(cmd *cobra.Command) { cmd.Flags().UintP(ConcurrencyFlag, ConcurrencyFlagShort, 8, "number of concurrent import workers") cmd.Flags().Bool(DryRunFlag, true, "don't import, check the config only") - _ = cmd.MarkFlagRequired(QueueLabelFlag) _ = cmd.MarkFlagRequired(NamespaceFlag) + cmd.MarkFlagsRequiredTogether(QueueLabelFlag, QueueMappingFlag) + cmd.MarkFlagsOneRequired(QueueLabelFlag, QueueMappingFileFlag) + cmd.MarkFlagsMutuallyExclusive(QueueLabelFlag, QueueMappingFileFlag) } func init() { @@ -116,12 +116,8 @@ func loadMappingCache(ctx context.Context, c client.Client, cmd *cobra.Command) if err != nil { return nil, err } - queueLabel, err := flags.GetString(QueueLabelFlag) - if err != nil { - return nil, err - } - mapping, err := flags.GetStringToString(QueueMappingFlag) + queueLabel, err := flags.GetString(QueueLabelFlag) if err != nil { return nil, err } @@ -131,17 +127,18 @@ func loadMappingCache(ctx context.Context, c client.Client, cmd *cobra.Command) return nil, err } + var mapping util.MappingRules if mappingFile != "" { - yamlFile, err := os.ReadFile(mappingFile) + mapping, err = util.MappingRulesFromFile(mappingFile) if err != nil { return nil, err } - extraMapping := map[string]string{} - err = yaml.Unmarshal(yamlFile, extraMapping) + } else { + queueLabelMapping, err := flags.GetStringToString(QueueMappingFlag) if err != nil { - return nil, fmt.Errorf("decoding %q: %w", mappingFile, err) + return nil, err } - maps.Copy(mapping, extraMapping) + mapping = util.MappingRulesForLabel(queueLabel, queueLabelMapping) } addLabels, err := flags.GetStringToString(AddLabelsFlag) @@ -162,7 +159,7 @@ func loadMappingCache(ctx context.Context, c client.Client, cmd *cobra.Command) return nil, fmt.Errorf("%s: %w", AddLabelsFlag, errors.Join(validationErrors...)) } - return util.LoadImportCache(ctx, c, namespaces, queueLabel, mapping, addLabels) + return util.LoadImportCache(ctx, c, namespaces, mapping, addLabels) } func getKubeClient(cmd *cobra.Command) (client.Client, error) { diff --git a/cmd/importer/pod/check.go b/cmd/importer/pod/check.go index e782d413bd..380a9771a5 100644 --- a/cmd/importer/pod/check.go +++ b/cmd/importer/pod/check.go @@ -32,40 +32,49 @@ import ( func Check(ctx context.Context, c client.Client, cache *util.ImportCache, jobs uint) error { ch := make(chan corev1.Pod) go func() { - err := util.PushPods(ctx, c, cache.Namespaces, cache.QueueLabel, ch) + err := util.PushPods(ctx, c, cache.Namespaces, ch) if err != nil { ctrl.LoggerFrom(ctx).Error(err, "Listing pods") } }() - summary := util.ConcurrentProcessPod(ch, jobs, func(p *corev1.Pod) error { + summary := util.ConcurrentProcessPod(ch, jobs, func(p *corev1.Pod) (bool, error) { log := ctrl.LoggerFrom(ctx).WithValues("pod", klog.KObj(p)) log.V(3).Info("Checking") - cq, err := cache.ClusterQueue(p) - if err != nil { - return err + cq, skip, err := cache.ClusterQueue(p) + if skip || err != nil { + return skip, err } if len(cq.Spec.ResourceGroups) == 0 { - return fmt.Errorf("%q has no resource groups: %w", cq.Name, util.ErrCQInvalid) + return false, fmt.Errorf("%q has no resource groups: %w", cq.Name, util.ErrCQInvalid) } if len(cq.Spec.ResourceGroups[0].Flavors) == 0 { - return fmt.Errorf("%q has no resource groups flavors: %w", cq.Name, util.ErrCQInvalid) + return false, fmt.Errorf("%q has no resource groups flavors: %w", cq.Name, util.ErrCQInvalid) } rfName := string(cq.Spec.ResourceGroups[0].Flavors[0].Name) rf, rfFound := cache.ResourceFalvors[rfName] if !rfFound { - return fmt.Errorf("%q flavor %q: %w", cq.Name, rfName, util.ErrCQInvalid) + return false, fmt.Errorf("%q flavor %q: %w", cq.Name, rfName, util.ErrCQInvalid) + } + + var pv int32 + if pc, found := cache.PriorityClasses[p.Spec.PriorityClassName]; found { + pv = pc.Value + } else { + if p.Spec.PriorityClassName != "" { + return false, fmt.Errorf("%q: %w", p.Spec.PriorityClassName, util.ErrPCNotFound) + } } - log.V(2).Info("Successfully checked", "pod", klog.KObj(p), "clusterQueue", klog.KObj(cq), "resourceFalvor", klog.KObj(rf)) - return nil + log.V(2).Info("Successfully checked", "clusterQueue", klog.KObj(cq), "resourceFalvor", klog.KObj(rf), "priority", pv) + return false, nil }) log := ctrl.LoggerFrom(ctx) - log.Info("Check done", "checked", summary.TotalPods, "failed", summary.FailedPods) + log.Info("Check done", "checked", summary.TotalPods, "skipped", summary.SkippedPods, "failed", summary.FailedPods) for e, pods := range summary.ErrorsForPods { log.Info("Validation failed for Pods", "err", e, "occurrences", len(pods), "obsevedFirstIn", pods[0]) } diff --git a/cmd/importer/pod/check_test.go b/cmd/importer/pod/check_test.go index d0e4e39c21..8c976ddef2 100644 --- a/cmd/importer/pod/check_test.go +++ b/cmd/importer/pod/check_test.go @@ -46,7 +46,7 @@ func TestCheckNamespace(t *testing.T) { pods []corev1.Pod clusteQueues []kueue.ClusterQueue localQueues []kueue.LocalQueue - mapping map[string]string + mapping util.MappingRules flavors []kueue.ResourceFlavor wantError error @@ -62,8 +62,16 @@ func TestCheckNamespace(t *testing.T) { pods: []corev1.Pod{ *basePodWrapper.Clone().Obj(), }, - mapping: map[string]string{ - "q1": "lq1", + mapping: util.MappingRules{ + util.MappingRule{ + Match: util.MappingMatch{ + PriorityClassName: "", + Labels: map[string]string{ + testingQueueLabel: "q1", + }, + }, + ToLocalQueue: "lq1", + }, }, wantError: util.ErrLQNotFound, }, @@ -71,8 +79,16 @@ func TestCheckNamespace(t *testing.T) { pods: []corev1.Pod{ *basePodWrapper.Clone().Obj(), }, - mapping: map[string]string{ - "q1": "lq1", + mapping: util.MappingRules{ + util.MappingRule{ + Match: util.MappingMatch{ + PriorityClassName: "", + Labels: map[string]string{ + testingQueueLabel: "q1", + }, + }, + ToLocalQueue: "lq1", + }, }, localQueues: []kueue.LocalQueue{ *baseLocalQueue.Obj(), @@ -83,8 +99,16 @@ func TestCheckNamespace(t *testing.T) { pods: []corev1.Pod{ *basePodWrapper.Clone().Obj(), }, - mapping: map[string]string{ - "q1": "lq1", + mapping: util.MappingRules{ + util.MappingRule{ + Match: util.MappingMatch{ + PriorityClassName: "", + Labels: map[string]string{ + testingQueueLabel: "q1", + }, + }, + ToLocalQueue: "lq1", + }, }, localQueues: []kueue.LocalQueue{ *baseLocalQueue.Obj(), @@ -98,8 +122,16 @@ func TestCheckNamespace(t *testing.T) { pods: []corev1.Pod{ *basePodWrapper.Clone().Obj(), }, - mapping: map[string]string{ - "q1": "lq1", + mapping: util.MappingRules{ + util.MappingRule{ + Match: util.MappingMatch{ + PriorityClassName: "", + Labels: map[string]string{ + testingQueueLabel: "q1", + }, + }, + ToLocalQueue: "lq1", + }, }, localQueues: []kueue.LocalQueue{ *baseLocalQueue.Obj(), @@ -126,7 +158,7 @@ func TestCheckNamespace(t *testing.T) { client := builder.Build() ctx := context.Background() - mpc, _ := util.LoadImportCache(ctx, client, []string{testingNamespace}, testingQueueLabel, tc.mapping, nil) + mpc, _ := util.LoadImportCache(ctx, client, []string{testingNamespace}, tc.mapping, nil) gotErr := Check(ctx, client, mpc, 8) if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); diff != "" { diff --git a/cmd/importer/pod/import.go b/cmd/importer/pod/import.go index 06f5f73f7c..10b0c3e950 100644 --- a/cmd/importer/pod/import.go +++ b/cmd/importer/pod/import.go @@ -43,34 +43,34 @@ import ( func Import(ctx context.Context, c client.Client, cache *util.ImportCache, jobs uint) error { ch := make(chan corev1.Pod) go func() { - err := util.PushPods(ctx, c, cache.Namespaces, cache.QueueLabel, ch) + err := util.PushPods(ctx, c, cache.Namespaces, ch) if err != nil { ctrl.LoggerFrom(ctx).Error(err, "Listing pods") } }() - summary := util.ConcurrentProcessPod(ch, jobs, func(p *corev1.Pod) error { + summary := util.ConcurrentProcessPod(ch, jobs, func(p *corev1.Pod) (bool, error) { log := ctrl.LoggerFrom(ctx).WithValues("pod", klog.KObj(p)) log.V(3).Info("Importing") - lq, err := cache.LocalQueue(p) - if err != nil { - return err + lq, skip, err := cache.LocalQueue(p) + if skip || err != nil { + return skip, err } oldLq, found := p.Labels[controllerconstants.QueueLabel] if !found { if err := addLabels(ctx, c, p, lq.Name, cache.AddLabels); err != nil { - return fmt.Errorf("cannot add queue label: %w", err) + return false, fmt.Errorf("cannot add queue label: %w", err) } } else if oldLq != lq.Name { - return fmt.Errorf("another local queue name is set %q expecting %q", oldLq, lq.Name) + return false, fmt.Errorf("another local queue name is set %q expecting %q", oldLq, lq.Name) } kp := pod.FromObject(p) // Note: the recorder is not used for single pods, we can just pass nil for now. wl, err := kp.ConstructComposableWorkload(ctx, c, nil) if err != nil { - return fmt.Errorf("construct workload: %w", err) + return false, fmt.Errorf("construct workload: %w", err) } maps.Copy(wl.Labels, cache.AddLabels) @@ -82,12 +82,11 @@ func Import(ctx context.Context, c client.Client, cache *util.ImportCache, jobs } if err := createWorkload(ctx, c, wl); err != nil { - return fmt.Errorf("creating workload: %w", err) + return false, fmt.Errorf("creating workload: %w", err) } //make its admission and update its status - info := workload.NewInfo(wl) cq := cache.ClusterQueues[string(lq.Spec.ClusterQueue)] admission := kueue.Admission{ @@ -122,14 +121,14 @@ func Import(ctx context.Context, c client.Client, cache *util.ImportCache, jobs } apimeta.SetStatusCondition(&wl.Status.Conditions, admittedCond) if err := admitWorkload(ctx, c, wl); err != nil { - return err + return false, err } log.V(2).Info("Successfully imported", "pod", klog.KObj(p), "workload", klog.KObj(wl)) - return nil + return false, nil }) log := ctrl.LoggerFrom(ctx) - log.Info("Import done", "checked", summary.TotalPods, "failed", summary.FailedPods) + log.Info("Import done", "checked", summary.TotalPods, "skipped", summary.SkippedPods, "failed", summary.FailedPods) for e, pods := range summary.ErrorsForPods { log.Info("Import failed for Pods", "err", e, "occurrences", len(pods), "obsevedFirstIn", pods[0]) } diff --git a/cmd/importer/pod/import_test.go b/cmd/importer/pod/import_test.go index e16970fb21..5246a9719d 100644 --- a/cmd/importer/pod/import_test.go +++ b/cmd/importer/pod/import_test.go @@ -86,7 +86,7 @@ func TestImportNamespace(t *testing.T) { pods []corev1.Pod clusteQueues []kueue.ClusterQueue localQueues []kueue.LocalQueue - mapping map[string]string + mapping util.MappingRules addLabels map[string]string wantPods []corev1.Pod @@ -98,8 +98,16 @@ func TestImportNamespace(t *testing.T) { pods: []corev1.Pod{ *basePodWrapper.Clone().Obj(), }, - mapping: map[string]string{ - "q1": "lq1", + mapping: util.MappingRules{ + util.MappingRule{ + Match: util.MappingMatch{ + PriorityClassName: "", + Labels: map[string]string{ + testingQueueLabel: "q1", + }, + }, + ToLocalQueue: "lq1", + }, }, localQueues: []kueue.LocalQueue{ *baseLocalQueue.Obj(), @@ -123,8 +131,16 @@ func TestImportNamespace(t *testing.T) { pods: []corev1.Pod{ *basePodWrapper.Clone().Obj(), }, - mapping: map[string]string{ - "q1": "lq1", + mapping: util.MappingRules{ + util.MappingRule{ + Match: util.MappingMatch{ + PriorityClassName: "", + Labels: map[string]string{ + testingQueueLabel: "q1", + }, + }, + ToLocalQueue: "lq1", + }, }, localQueues: []kueue.LocalQueue{ *baseLocalQueue.Obj(), @@ -165,7 +181,7 @@ func TestImportNamespace(t *testing.T) { client := builder.Build() ctx := context.Background() - mpc, _ := util.LoadImportCache(ctx, client, []string{testingNamespace}, testingQueueLabel, tc.mapping, tc.addLabels) + mpc, _ := util.LoadImportCache(ctx, client, []string{testingNamespace}, tc.mapping, tc.addLabels) gotErr := Import(ctx, client, mpc, 8) if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); diff != "" { diff --git a/cmd/importer/util/util.go b/cmd/importer/util/util.go index ffb4472a4e..55dc02a678 100644 --- a/cmd/importer/util/util.go +++ b/cmd/importer/util/util.go @@ -20,7 +20,7 @@ import ( "context" "errors" "fmt" - "maps" + "os" "slices" "sync" @@ -29,6 +29,7 @@ import ( "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/yaml" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" utilslices "sigs.k8s.io/kueue/pkg/util/slices" @@ -39,27 +40,26 @@ const ( ) var ( - ErrNoMappingKey = errors.New("no mapping key found") - ErrNoMapping = errors.New("no mapping found") - ErrLQNotFound = errors.New("localqueue not found") - ErrCQNotFound = errors.New("clusterqueue not found") - ErrCQInvalid = errors.New("clusterqueue invalid") + ErrNoMapping = errors.New("no mapping found") + ErrLQNotFound = errors.New("localqueue not found") + ErrCQNotFound = errors.New("clusterqueue not found") + ErrCQInvalid = errors.New("clusterqueue invalid") + ErrPCNotFound = errors.New("priorityclass not found") ) -func listOptions(namespace, queueLabel, continueToken string) []client.ListOption { +func listOptions(namespace, continueToken string) []client.ListOption { opts := []client.ListOption{ client.InNamespace(namespace), - client.HasLabels{queueLabel}, client.Limit(ListLength), client.Continue(continueToken), } + return opts } type ImportCache struct { Namespaces []string - QueueLabel string - Mapping map[string]string + MappingRules MappingRules LocalQueues map[string]map[string]*kueue.LocalQueue ClusterQueues map[string]*kueue.ClusterQueue ResourceFalvors map[string]*kueue.ResourceFlavor @@ -67,13 +67,74 @@ type ImportCache struct { AddLabels map[string]string } -func LoadImportCache(ctx context.Context, c client.Client, namespaces []string, queueLabel string, mapping, addLabels map[string]string) (*ImportCache, error) { +type MappingMatch struct { + PriorityClassName string `json:"priorityClassName"` + Labels map[string]string `json:"labels"` +} + +func (mm *MappingMatch) Match(priorityClassName string, labels map[string]string) bool { + if mm.PriorityClassName != "" && priorityClassName != mm.PriorityClassName { + return false + } + for l, lv := range mm.Labels { + if labels[l] != lv { + return false + } + } + return true +} + +type MappingRule struct { + Match MappingMatch `json:"match"` + ToLocalQueue string `json:"toLocalQueue"` + Skip bool `json:"skip"` +} + +type MappingRules []MappingRule + +func (mr MappingRules) QueueFor(priorityClassName string, labels map[string]string) (string, bool, bool) { + for i := range mr { + if mr[i].Match.Match(priorityClassName, labels) { + return mr[i].ToLocalQueue, mr[i].Skip, true + } + } + return "", false, false +} + +func MappingRulesFromFile(mappingFile string) (MappingRules, error) { + ret := MappingRules{} + yamlFile, err := os.ReadFile(mappingFile) + if err != nil { + return nil, err + } + err = yaml.Unmarshal(yamlFile, &ret) + if err != nil { + return nil, fmt.Errorf("decoding %q: %w", mappingFile, err) + } + return ret, nil +} + +func MappingRulesForLabel(label string, m map[string]string) MappingRules { + ret := make(MappingRules, 0, len(m)) + for labelValue, queue := range m { + ret = append(ret, MappingRule{ + Match: MappingMatch{ + Labels: map[string]string{ + label: labelValue, + }, + }, + ToLocalQueue: queue, + }) + } + return ret +} + +func LoadImportCache(ctx context.Context, c client.Client, namespaces []string, mappingRules MappingRules, addLabels map[string]string) (*ImportCache, error) { ret := ImportCache{ - Namespaces: slices.Clone(namespaces), - QueueLabel: queueLabel, - Mapping: maps.Clone(mapping), - LocalQueues: make(map[string]map[string]*kueue.LocalQueue), - AddLabels: addLabels, + Namespaces: slices.Clone(namespaces), + MappingRules: mappingRules, + LocalQueues: make(map[string]map[string]*kueue.LocalQueue), + AddLabels: addLabels, } // get the cluster queues @@ -108,43 +169,42 @@ func LoadImportCache(ctx context.Context, c client.Client, namespaces []string, return &ret, nil } -func (mappingCache *ImportCache) LocalQueue(p *corev1.Pod) (*kueue.LocalQueue, error) { - mappingKey, found := p.Labels[mappingCache.QueueLabel] +func (mappingCache *ImportCache) LocalQueue(p *corev1.Pod) (*kueue.LocalQueue, bool, error) { + queueName, skip, found := mappingCache.MappingRules.QueueFor(p.Spec.PriorityClassName, p.Labels) if !found { - return nil, ErrNoMappingKey + return nil, false, ErrNoMapping } - queueName, found := mappingCache.Mapping[mappingKey] - if !found { - return nil, fmt.Errorf("%s: %w", mappingKey, ErrNoMapping) + if skip { + return nil, true, nil } nqQueues, found := mappingCache.LocalQueues[p.Namespace] if !found { - return nil, fmt.Errorf("%s: %w", queueName, ErrLQNotFound) + return nil, false, fmt.Errorf("%s: %w", queueName, ErrLQNotFound) } lq, found := nqQueues[queueName] if !found { - return nil, fmt.Errorf("%s: %w", queueName, ErrLQNotFound) + return nil, false, fmt.Errorf("%s: %w", queueName, ErrLQNotFound) } - return lq, nil + return lq, false, nil } -func (mappingCache *ImportCache) ClusterQueue(p *corev1.Pod) (*kueue.ClusterQueue, error) { - lq, err := mappingCache.LocalQueue(p) - if err != nil { - return nil, err +func (mappingCache *ImportCache) ClusterQueue(p *corev1.Pod) (*kueue.ClusterQueue, bool, error) { + lq, skip, err := mappingCache.LocalQueue(p) + if skip || err != nil { + return nil, skip, err } queueName := string(lq.Spec.ClusterQueue) cq, found := mappingCache.ClusterQueues[queueName] if !found { - return nil, fmt.Errorf("cluster queue: %s: %w", queueName, ErrCQNotFound) + return nil, false, fmt.Errorf("cluster queue: %s: %w", queueName, ErrCQNotFound) } - return cq, nil + return cq, false, nil } -func PushPods(ctx context.Context, c client.Client, namespaces []string, queueLabel string, ch chan<- corev1.Pod) error { +func PushPods(ctx context.Context, c client.Client, namespaces []string, ch chan<- corev1.Pod) error { defer close(ch) for _, ns := range namespaces { lst := &corev1.PodList{} @@ -153,7 +213,7 @@ func PushPods(ctx context.Context, c client.Client, namespaces []string, queueLa defer log.V(3).Info("End pods list") page := 0 for { - err := c.List(ctx, lst, listOptions(ns, queueLabel, lst.Continue)...) + err := c.List(ctx, lst, listOptions(ns, lst.Continue)...) if err != nil { log.Error(err, "list") return fmt.Errorf("listing pods in %s, page %d: %w", ns, page, err) @@ -166,7 +226,6 @@ func PushPods(ctx context.Context, c client.Client, namespaces []string, queueLa ch <- p } } - page++ if lst.Continue == "" { log.V(2).Info("No more pods", "pages", page) @@ -178,18 +237,20 @@ func PushPods(ctx context.Context, c client.Client, namespaces []string, queueLa } type Result struct { - Pod string - Err error + Pod string + Err error + Skip bool } type ProcessSummary struct { TotalPods int + SkippedPods int FailedPods int ErrorsForPods map[string][]string Errors []error } -func ConcurrentProcessPod(ch <-chan corev1.Pod, jobs uint, f func(p *corev1.Pod) error) ProcessSummary { +func ConcurrentProcessPod(ch <-chan corev1.Pod, jobs uint, f func(p *corev1.Pod) (bool, error)) ProcessSummary { wg := sync.WaitGroup{} resultCh := make(chan Result) @@ -198,8 +259,8 @@ func ConcurrentProcessPod(ch <-chan corev1.Pod, jobs uint, f func(p *corev1.Pod) go func() { defer wg.Done() for pod := range ch { - err := f(&pod) - resultCh <- Result{Pod: client.ObjectKeyFromObject(&pod).String(), Err: err} + skip, err := f(&pod) + resultCh <- Result{Pod: client.ObjectKeyFromObject(&pod).String(), Err: err, Skip: skip} } }() } @@ -213,6 +274,9 @@ func ConcurrentProcessPod(ch <-chan corev1.Pod, jobs uint, f func(p *corev1.Pod) } for result := range resultCh { ps.TotalPods++ + if result.Skip { + ps.SkippedPods++ + } if result.Err != nil { ps.FailedPods++ estr := result.Err.Error() diff --git a/cmd/importer/util/util_test.go b/cmd/importer/util/util_test.go new file mode 100644 index 0000000000..1271eb151a --- /dev/null +++ b/cmd/importer/util/util_test.go @@ -0,0 +1,143 @@ +package util + +import ( + "os" + "path" + "testing" + + "github.com/google/go-cmp/cmp" +) + +const ( + testContent = ` +- match: + priorityClassName: preemptible + labels: + resource_type: cpu-only + toLocalQueue: preemptible-cpu +- match: + labels: + project_id: alpha + resource_type: gpu + toLocalQueue: alpha-gpu +- match: + labels: + project_id: alpha + resource_type: cpu + skip: true +` +) + +var testMappingRules = MappingRules{ + { + Match: MappingMatch{ + PriorityClassName: "preemptible", + Labels: map[string]string{ + "resource_type": "cpu-only", + }, + }, + ToLocalQueue: "preemptible-cpu", + }, + { + Match: MappingMatch{ + Labels: map[string]string{ + "project_id": "alpha", + "resource_type": "gpu", + }, + }, + ToLocalQueue: "alpha-gpu", + }, + { + Match: MappingMatch{ + Labels: map[string]string{ + "project_id": "alpha", + "resource_type": "cpu", + }, + }, + Skip: true, + }, +} + +func TestMappingFromFile(t *testing.T) { + tdir := t.TempDir() + fPath := path.Join(tdir, "mapping.yaml") + err := os.WriteFile(fPath, []byte(testContent), os.FileMode(0600)) + if err != nil { + t.Fatalf("unable to create the test file: %s", err) + } + + mapping, err := MappingRulesFromFile(fPath) + if err != nil { + t.Fatalf("unexpected load error: %s", err) + } + + if diff := cmp.Diff(testMappingRules, mapping); diff != "" { + t.Errorf("unexpected mapping(want-/ got+):\n%s", diff) + } +} + +func TestMappingMatch(t *testing.T) { + cases := map[string]struct { + className string + labels map[string]string + rules MappingRules + + wantMatch bool + wantSkip bool + wantQueue string + }{ + "missing one label": { + labels: map[string]string{"project_id": "alpha"}, + rules: testMappingRules, + }, + "priority class not checked if not part of the rule": { + className: "preemptible", + labels: map[string]string{"project_id": "alpha", "resource_type": "gpu"}, + rules: testMappingRules, + + wantMatch: true, + wantQueue: "alpha-gpu", + }, + "skip": { + className: "preemptible", + labels: map[string]string{"project_id": "alpha", "resource_type": "cpu"}, + rules: testMappingRules, + + wantMatch: true, + wantSkip: true, + }, + "priority class not matching": { + className: "preemptible-1", + labels: map[string]string{"resource_type": "cpu-only"}, + rules: testMappingRules, + }, + "priority class matching": { + className: "preemptible", + labels: map[string]string{"resource_type": "cpu-only"}, + rules: testMappingRules, + + wantMatch: true, + wantQueue: "preemptible-cpu", + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + gotQueue, gotSkip, gotMatch := tc.rules.QueueFor(tc.className, tc.labels) + + if tc.wantMatch != gotMatch { + t.Errorf("unexpected match %v", gotMatch) + } + + if tc.wantSkip != gotSkip { + t.Errorf("unexpected skip %v", gotSkip) + } + + if tc.wantQueue != gotQueue { + t.Errorf("unexpected queue want %q got %q", tc.wantQueue, gotQueue) + } + + }) + } + +} diff --git a/go.mod b/go.mod index 03d8e29a80..fae100da11 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( github.com/ray-project/kuberay/ray-operator v1.1.0-rc.1 github.com/spf13/cobra v1.8.0 go.uber.org/zap v1.27.0 - gopkg.in/yaml.v2 v2.4.0 k8s.io/api v0.29.3 k8s.io/apimachinery v0.29.3 k8s.io/apiserver v0.29.3 @@ -33,6 +32,7 @@ require ( sigs.k8s.io/controller-tools v0.14.0 sigs.k8s.io/jobset v0.4.0 sigs.k8s.io/structured-merge-diff/v4 v4.4.1 + sigs.k8s.io/yaml v1.4.0 ) require ( @@ -121,11 +121,11 @@ require ( google.golang.org/protobuf v1.33.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apiextensions-apiserver v0.29.0 // indirect k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 // indirect k8s.io/kms v0.29.3 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect - sigs.k8s.io/yaml v1.4.0 // indirect ) diff --git a/test/integration/importer/importer_test.go b/test/integration/importer/importer_test.go index 1f0a9c9bab..416a5c2d20 100644 --- a/test/integration/importer/importer_test.go +++ b/test/integration/importer/importer_test.go @@ -89,7 +89,7 @@ var _ = ginkgo.Describe("Importer", func() { }) ginkgo.By("Running the import", func() { - mapping, err := importerutil.LoadImportCache(ctx, k8sClient, []string{ns.Name}, "src.lbl", map[string]string{"src-val": "lq1"}, nil) + mapping, err := importerutil.LoadImportCache(ctx, k8sClient, []string{ns.Name}, importerutil.MappingRulesForLabel("src.lbl", map[string]string{"src-val": "lq1"}), nil) gomega.Expect(err).ToNot(gomega.HaveOccurred()) gomega.Expect(mapping).ToNot(gomega.BeNil())