| Abhay Kumar | a61c522 | 2025-11-10 07:32:50 +0000 | [diff] [blame^] | 1 | // Copyright 2013-2023 The Cobra Authors |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | package cobra |
| 16 | |
| 17 | import ( |
| 18 | "fmt" |
| 19 | "sort" |
| 20 | "strings" |
| 21 | |
| 22 | flag "github.com/spf13/pflag" |
| 23 | ) |
| 24 | |
// Flag annotation keys used to record flag-group membership. Under each key a
// flag stores a list of groups, each group being its flag names joined by a
// single space.
const (
	// requiredAsGroupAnnotation tags groups whose flags must all be set once any of them is.
	requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"

	// oneRequiredAnnotation tags groups in which at least one flag must be set.
	oneRequiredAnnotation = "cobra_annotation_one_required"

	// mutuallyExclusiveAnnotation tags groups in which at most one flag may be set.
	mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
| 30 | |
| 31 | // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors |
| 32 | // if the command is invoked with a subset (but not all) of the given flags. |
| 33 | func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { |
| 34 | c.mergePersistentFlags() |
| 35 | for _, v := range flagNames { |
| 36 | f := c.Flags().Lookup(v) |
| 37 | if f == nil { |
| 38 | panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) |
| 39 | } |
| 40 | if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil { |
| 41 | // Only errs if the flag isn't found. |
| 42 | panic(err) |
| 43 | } |
| 44 | } |
| 45 | } |
| 46 | |
| 47 | // MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors |
| 48 | // if the command is invoked without at least one flag from the given set of flags. |
| 49 | func (c *Command) MarkFlagsOneRequired(flagNames ...string) { |
| 50 | c.mergePersistentFlags() |
| 51 | for _, v := range flagNames { |
| 52 | f := c.Flags().Lookup(v) |
| 53 | if f == nil { |
| 54 | panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) |
| 55 | } |
| 56 | if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { |
| 57 | // Only errs if the flag isn't found. |
| 58 | panic(err) |
| 59 | } |
| 60 | } |
| 61 | } |
| 62 | |
| 63 | // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors |
| 64 | // if the command is invoked with more than one flag from the given set of flags. |
| 65 | func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { |
| 66 | c.mergePersistentFlags() |
| 67 | for _, v := range flagNames { |
| 68 | f := c.Flags().Lookup(v) |
| 69 | if f == nil { |
| 70 | panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) |
| 71 | } |
| 72 | // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. |
| 73 | if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil { |
| 74 | panic(err) |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | // ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the |
| 80 | // first error encountered. |
| 81 | func (c *Command) ValidateFlagGroups() error { |
| 82 | if c.DisableFlagParsing { |
| 83 | return nil |
| 84 | } |
| 85 | |
| 86 | flags := c.Flags() |
| 87 | |
| 88 | // groupStatus format is the list of flags as a unique ID, |
| 89 | // then a map of each flag name and whether it is set or not. |
| 90 | groupStatus := map[string]map[string]bool{} |
| 91 | oneRequiredGroupStatus := map[string]map[string]bool{} |
| 92 | mutuallyExclusiveGroupStatus := map[string]map[string]bool{} |
| 93 | flags.VisitAll(func(pflag *flag.Flag) { |
| 94 | processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) |
| 95 | processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) |
| 96 | processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) |
| 97 | }) |
| 98 | |
| 99 | if err := validateRequiredFlagGroups(groupStatus); err != nil { |
| 100 | return err |
| 101 | } |
| 102 | if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { |
| 103 | return err |
| 104 | } |
| 105 | if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { |
| 106 | return err |
| 107 | } |
| 108 | return nil |
| 109 | } |
| 110 | |
| 111 | func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { |
| 112 | for _, fname := range flagnames { |
| 113 | f := fs.Lookup(fname) |
| 114 | if f == nil { |
| 115 | return false |
| 116 | } |
| 117 | } |
| 118 | return true |
| 119 | } |
| 120 | |
| 121 | func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { |
| 122 | groupInfo, found := pflag.Annotations[annotation] |
| 123 | if found { |
| 124 | for _, group := range groupInfo { |
| 125 | if groupStatus[group] == nil { |
| 126 | flagnames := strings.Split(group, " ") |
| 127 | |
| 128 | // Only consider this flag group at all if all the flags are defined. |
| 129 | if !hasAllFlags(flags, flagnames...) { |
| 130 | continue |
| 131 | } |
| 132 | |
| 133 | groupStatus[group] = make(map[string]bool, len(flagnames)) |
| 134 | for _, name := range flagnames { |
| 135 | groupStatus[group][name] = false |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | groupStatus[group][pflag.Name] = pflag.Changed |
| 140 | } |
| 141 | } |
| 142 | } |
| 143 | |
| 144 | func validateRequiredFlagGroups(data map[string]map[string]bool) error { |
| 145 | keys := sortedKeys(data) |
| 146 | for _, flagList := range keys { |
| 147 | flagnameAndStatus := data[flagList] |
| 148 | |
| 149 | unset := []string{} |
| 150 | for flagname, isSet := range flagnameAndStatus { |
| 151 | if !isSet { |
| 152 | unset = append(unset, flagname) |
| 153 | } |
| 154 | } |
| 155 | if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { |
| 156 | continue |
| 157 | } |
| 158 | |
| 159 | // Sort values, so they can be tested/scripted against consistently. |
| 160 | sort.Strings(unset) |
| 161 | return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) |
| 162 | } |
| 163 | |
| 164 | return nil |
| 165 | } |
| 166 | |
| 167 | func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { |
| 168 | keys := sortedKeys(data) |
| 169 | for _, flagList := range keys { |
| 170 | flagnameAndStatus := data[flagList] |
| 171 | var set []string |
| 172 | for flagname, isSet := range flagnameAndStatus { |
| 173 | if isSet { |
| 174 | set = append(set, flagname) |
| 175 | } |
| 176 | } |
| 177 | if len(set) >= 1 { |
| 178 | continue |
| 179 | } |
| 180 | |
| 181 | // Sort values, so they can be tested/scripted against consistently. |
| 182 | sort.Strings(set) |
| 183 | return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList) |
| 184 | } |
| 185 | return nil |
| 186 | } |
| 187 | |
| 188 | func validateExclusiveFlagGroups(data map[string]map[string]bool) error { |
| 189 | keys := sortedKeys(data) |
| 190 | for _, flagList := range keys { |
| 191 | flagnameAndStatus := data[flagList] |
| 192 | var set []string |
| 193 | for flagname, isSet := range flagnameAndStatus { |
| 194 | if isSet { |
| 195 | set = append(set, flagname) |
| 196 | } |
| 197 | } |
| 198 | if len(set) == 0 || len(set) == 1 { |
| 199 | continue |
| 200 | } |
| 201 | |
| 202 | // Sort values, so they can be tested/scripted against consistently. |
| 203 | sort.Strings(set) |
| 204 | return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) |
| 205 | } |
| 206 | return nil |
| 207 | } |
| 208 | |
// sortedKeys returns the keys of m in ascending lexical order. The result is
// always a non-nil slice, even when m is empty or nil.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
| 219 | |
| 220 | // enforceFlagGroupsForCompletion will do the following: |
| 221 | // - when a flag in a group is present, other flags in the group will be marked required |
| 222 | // - when none of the flags in a one-required group are present, all flags in the group will be marked required |
| 223 | // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden |
| 224 | // This allows the standard completion logic to behave appropriately for flag groups |
| 225 | func (c *Command) enforceFlagGroupsForCompletion() { |
| 226 | if c.DisableFlagParsing { |
| 227 | return |
| 228 | } |
| 229 | |
| 230 | flags := c.Flags() |
| 231 | groupStatus := map[string]map[string]bool{} |
| 232 | oneRequiredGroupStatus := map[string]map[string]bool{} |
| 233 | mutuallyExclusiveGroupStatus := map[string]map[string]bool{} |
| 234 | c.Flags().VisitAll(func(pflag *flag.Flag) { |
| 235 | processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) |
| 236 | processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) |
| 237 | processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) |
| 238 | }) |
| 239 | |
| 240 | // If a flag that is part of a group is present, we make all the other flags |
| 241 | // of that group required so that the shell completion suggests them automatically |
| 242 | for flagList, flagnameAndStatus := range groupStatus { |
| 243 | for _, isSet := range flagnameAndStatus { |
| 244 | if isSet { |
| 245 | // One of the flags of the group is set, mark the other ones as required |
| 246 | for _, fName := range strings.Split(flagList, " ") { |
| 247 | _ = c.MarkFlagRequired(fName) |
| 248 | } |
| 249 | } |
| 250 | } |
| 251 | } |
| 252 | |
| 253 | // If none of the flags of a one-required group are present, we make all the flags |
| 254 | // of that group required so that the shell completion suggests them automatically |
| 255 | for flagList, flagnameAndStatus := range oneRequiredGroupStatus { |
| 256 | isSet := false |
| 257 | |
| 258 | for _, isSet = range flagnameAndStatus { |
| 259 | if isSet { |
| 260 | break |
| 261 | } |
| 262 | } |
| 263 | |
| 264 | // None of the flags of the group are set, mark all flags in the group |
| 265 | // as required |
| 266 | if !isSet { |
| 267 | for _, fName := range strings.Split(flagList, " ") { |
| 268 | _ = c.MarkFlagRequired(fName) |
| 269 | } |
| 270 | } |
| 271 | } |
| 272 | |
| 273 | // If a flag that is mutually exclusive to others is present, we hide the other |
| 274 | // flags of that group so the shell completion does not suggest them |
| 275 | for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { |
| 276 | for flagName, isSet := range flagnameAndStatus { |
| 277 | if isSet { |
| 278 | // One of the flags of the mutually exclusive group is set, mark the other ones as hidden |
| 279 | // Don't mark the flag that is already set as hidden because it may be an |
| 280 | // array or slice flag and therefore must continue being suggested |
| 281 | for _, fName := range strings.Split(flagList, " ") { |
| 282 | if fName != flagName { |
| 283 | flag := c.Flags().Lookup(fName) |
| 284 | flag.Hidden = true |
| 285 | } |
| 286 | } |
| 287 | } |
| 288 | } |
| 289 | } |
| 290 | } |