blob: 7e1e65cfdb0113fb470023aef098a034285b044f [file] [log] [blame]
khenaidood948f772021-08-11 17:49:24 -04001package client
2
3import (
4 "encoding/json"
5 "fmt"
6 "sort"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/jcmturner/gokrb5/v8/iana/nametype"
12 "github.com/jcmturner/gokrb5/v8/krberror"
13 "github.com/jcmturner/gokrb5/v8/messages"
14 "github.com/jcmturner/gokrb5/v8/types"
15)
16
17// sessions is the client's cache of TGT sessions, keyed on the realm name.
18type sessions struct {
19 Entries map[string]*session // realm name -> session held for that realm
20 mux sync.RWMutex // guards Entries; taken by every method on sessions
21}
22
23// destroy erases all sessions
24func (s *sessions) destroy() {
25 s.mux.Lock()
26 defer s.mux.Unlock()
27 for k, e := range s.Entries {
28 e.destroy()
29 delete(s.Entries, k)
30 }
31}
32
33// update replaces a session with the one provided or adds it as a new one
34func (s *sessions) update(sess *session) {
35 s.mux.Lock()
36 defer s.mux.Unlock()
37 // if a session already exists for this, cancel its auto renew.
38 if i, ok := s.Entries[sess.realm]; ok {
39 if i != sess {
40 // Session in the sessions cache is not the same as one provided.
41 // Cancel the one in the cache and add this one.
42 i.mux.Lock()
43 defer i.mux.Unlock()
Abhay Kumara2ae5992025-11-10 14:02:24 +000044 if i.cancel != nil {
45 i.cancel <- true
46 }
khenaidood948f772021-08-11 17:49:24 -040047 s.Entries[sess.realm] = sess
48 return
49 }
50 }
51 // No session for this realm was found so just add it
52 s.Entries[sess.realm] = sess
53}
54
55// get returns the session for the realm specified
56func (s *sessions) get(realm string) (*session, bool) {
57 s.mux.RLock()
58 defer s.mux.RUnlock()
59 sess, ok := s.Entries[realm]
60 return sess, ok
61}
62
63// session holds the TGT details for a realm
64type session struct {
65 realm string
66 authTime time.Time
67 endTime time.Time
68 renewTill time.Time
69 tgt messages.Ticket
70 sessionKey types.EncryptionKey
71 sessionKeyExpiration time.Time
72 cancel chan bool
73 mux sync.RWMutex
74}
75
// jsonSession is the subset of a session's details that is exposed when the
// session cache is marshaled to JSON: the realm and the validity times only.
type jsonSession struct {
	Realm                                              string
	AuthTime, EndTime, RenewTill, SessionKeyExpiration time.Time
}
84
85// AddSession adds a session for a realm with a TGT to the client's session cache.
86// A goroutine is started to automatically renew the TGT before expiry.
87func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
88 if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
89 // Not a TGT
90 return
91 }
92 realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
93 s := &session{
94 realm: realm,
95 authTime: dep.AuthTime,
96 endTime: dep.EndTime,
97 renewTill: dep.RenewTill,
98 tgt: tgt,
99 sessionKey: dep.Key,
100 sessionKeyExpiration: dep.KeyExpiration,
101 }
102 cl.sessions.update(s)
103 cl.enableAutoSessionRenewal(s)
104 cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
105}
106
107// update overwrites the session details with those from the TGT and decrypted encPart
108func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
109 s.mux.Lock()
110 defer s.mux.Unlock()
111 s.authTime = dep.AuthTime
112 s.endTime = dep.EndTime
113 s.renewTill = dep.RenewTill
114 s.tgt = tgt
115 s.sessionKey = dep.Key
116 s.sessionKeyExpiration = dep.KeyExpiration
117}
118
119// destroy will cancel any auto renewal of the session and set the expiration times to the current time
120func (s *session) destroy() {
121 s.mux.Lock()
122 defer s.mux.Unlock()
123 if s.cancel != nil {
124 s.cancel <- true
125 }
126 s.endTime = time.Now().UTC()
127 s.renewTill = s.endTime
128 s.sessionKeyExpiration = s.endTime
129}
130
131// valid informs if the TGT is still within the valid time window
132func (s *session) valid() bool {
133 s.mux.RLock()
134 defer s.mux.RUnlock()
135 t := time.Now().UTC()
136 if t.Before(s.endTime) && s.authTime.Before(t) {
137 return true
138 }
139 return false
140}
141
142// tgtDetails is a thread safe way to get the session's realm, TGT and session key values
143func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
144 s.mux.RLock()
145 defer s.mux.RUnlock()
146 return s.realm, s.tgt, s.sessionKey
147}
148
149// timeDetails is a thread safe way to get the session's validity time values
150func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
151 s.mux.RLock()
152 defer s.mux.RUnlock()
153 return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
154}
155
156// JSON return information about the held sessions in a JSON format.
157func (s *sessions) JSON() (string, error) {
158 s.mux.RLock()
159 defer s.mux.RUnlock()
160 var js []jsonSession
161 keys := make([]string, 0, len(s.Entries))
162 for k := range s.Entries {
163 keys = append(keys, k)
164 }
165 sort.Strings(keys)
166 for _, k := range keys {
167 r, at, et, rt, kt := s.Entries[k].timeDetails()
168 j := jsonSession{
169 Realm: r,
170 AuthTime: at,
171 EndTime: et,
172 RenewTill: rt,
173 SessionKeyExpiration: kt,
174 }
175 js = append(js, j)
176 }
177 b, err := json.MarshalIndent(js, "", " ")
178 if err != nil {
179 return "", err
180 }
181 return string(b), nil
182}
183
184// enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
185func (cl *Client) enableAutoSessionRenewal(s *session) {
186 var timer *time.Timer
187 s.mux.Lock()
188 s.cancel = make(chan bool, 1)
189 s.mux.Unlock()
190 go func(s *session) {
191 for {
192 s.mux.RLock()
193 w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
194 s.mux.RUnlock()
195 if w < 0 {
196 return
197 }
198 timer = time.NewTimer(w)
199 select {
200 case <-timer.C:
201 renewal, err := cl.refreshSession(s)
202 if err != nil {
203 cl.Log("error refreshing session: %v", err)
204 }
205 if !renewal && err == nil {
206 // end this goroutine as there will have been a new login and new auto renewal goroutine created.
207 return
208 }
209 case <-s.cancel:
210 // cancel has been called. Stop the timer and exit.
211 timer.Stop()
212 return
213 }
214 }
215 }(s)
216}
217
218// renewTGT renews the client's TGT session.
219func (cl *Client) renewTGT(s *session) error {
220 realm, tgt, skey := s.tgtDetails()
221 spn := types.PrincipalName{
222 NameType: nametype.KRB_NT_SRV_INST,
223 NameString: []string{"krbtgt", realm},
224 }
225 _, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
226 if err != nil {
227 return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
228 }
229 s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
230 cl.sessions.update(s)
231 cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
232 return nil
233}
234
235// refreshSession updates either through renewal or creating a new login.
236// The boolean indicates if the update was a renewal.
237func (cl *Client) refreshSession(s *session) (bool, error) {
238 s.mux.RLock()
239 realm := s.realm
240 renewTill := s.renewTill
241 s.mux.RUnlock()
242 cl.Log("refreshing TGT session for %s", realm)
243 if time.Now().UTC().Before(renewTill) {
244 err := cl.renewTGT(s)
245 return true, err
246 }
247 err := cl.realmLogin(realm)
248 return false, err
249}
250
251// ensureValidSession makes sure there is a valid session for the realm
252func (cl *Client) ensureValidSession(realm string) error {
253 s, ok := cl.sessions.get(realm)
254 if ok {
255 s.mux.RLock()
256 d := s.endTime.Sub(s.authTime) / 6
257 if s.endTime.Sub(time.Now().UTC()) > d {
258 s.mux.RUnlock()
259 return nil
260 }
261 s.mux.RUnlock()
262 _, err := cl.refreshSession(s)
263 return err
264 }
265 return cl.realmLogin(realm)
266}
267
268// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
269func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
270 err = cl.ensureValidSession(realm)
271 if err != nil {
272 return
273 }
274 s, ok := cl.sessions.get(realm)
275 if !ok {
276 err = fmt.Errorf("could not find TGT session for %s", realm)
277 return
278 }
279 _, tgt, sessionKey = s.tgtDetails()
280 return
281}
282
283// sessionTimes provides the timing information with regards to a session for the realm specified.
284func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
285 s, ok := cl.sessions.get(realm)
286 if !ok {
287 err = fmt.Errorf("could not find TGT session for %s", realm)
288 return
289 }
290 _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
291 return
292}
293
294// spnRealm resolves the realm name of a service principal name
295func (cl *Client) spnRealm(spn types.PrincipalName) string {
296 return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
297}