Picture of the author
Published on

ทำ JWT Refresh Token ด้วย Spring Boot

Auth API Spring Boot

ขอบคุณรูปภาพจาก https://i.ytimg.com

แนะนำว่าควรเข้าใจเรื่อง Json Web Token (JWT) มาก่อน เพื่อความเข้าใจของผู้อ่านเอง

Spring Boot ผมจะไม่สร้างใหม่ แต่ผมจะใช้อันเดียวกับบทความนี้ Authentication ด้วย Spring Security & JWT เอามา Implement ต่อ Source Code และผมจะ Config ให้ดูแค่ฝั่งของ Back-end แล้วทดสอบด้วย Postman

ทำไมต้องทำ Refresh Token

แค่มี Access Token ปกติก็พอแล้วนี่ ? เรามาดูข้อเสียของมันก่อนดีกว่า

  1. การทำ JWT เป็นการทำ Authorize แบบสั้นๆ Access Token มีเวลาหมดอายุ พอหมดอายุแล้ว User ต้อง Login ใหม่เพื่อสร้าง Access Token ใหม่ ซึ่งการที่ User ต้อง Login ใหม่บ่อยๆ ก็คงจะน่ารำคาญไม่ใช่น้อย ทำให้ประสบการณ์การใช้งานอาจจะไม่ดีเท่าไหร่
  2. ไม่สามารถระงับหรือบังคับให้ Access Token หมดอายุได้ ถ้าเราถูกขโมย Access Token ไปไม่สามารถทำอะไรได้เลย User ต้อง Logout และ Login อีกครั้ง เพื่อเคลียร์​ Access Token อันเดิมออกไป หรือจนกว่า Access Token จะหมดอายุไปเอง
  3. ย้อนไปข้อ 1. ถ้าบอกว่างั้นก็ทำให้ Access Token ใช้งานได้นานๆสิ นั่นแหล่ะครับ ยิ่งไม่มีความปลอดภัยเลย ถ้ามีคนมาขโมย Access Token ของเราไปได้ แบบที่เราไม่รู้ตัว คนๆนั้นก็จะยิ่งมีเวลาใช้งาน User ของเรานานมากขึ้นตามที่เรากำหนดไว้ การกำหนดเวลาหมดอายุอย่างมากเลย ไม่ควรมากกว่า 15 นาที

แล้วการทำ Refresh Token มันช่วยอะไรล่ะ ?

หลักการทำงานของ Refresh Token จะประมาณนี้

เราจะกำหนดเวลาหมดอายุของ Refresh Token ไว้ที่ 1 วัน และกำหนดเวลาหมดอายุของ Access Token อยู่ที่ 5 นาที และเมื่อครบ 5 นาที ระบบของเราจะรู้ว่า Access Token หมดอายุแล้ว และจะ Auto Request ไปสร้าง Access Token ใหม่ โดยส่ง Refresh Token ไปด้วยเสมอเพื่อเช็คว่า Refresh Token หมดอายุหรือยัง ถ้ายังก็จะสร้าง Access Token ใหม่ โดยที่ User ไม่รู้ตัวและไม่จำเป็นต้อง Login ใหม่ User จะ Login ใหม่ก็ต่อเมื่อ Refresh Token หมดอายุแล้วเท่านั้น

  1. มีความปลอดภัยมากขึ้น (เพราะปรับให้ Access Token หมดอายุเร็วกกว่าปกติได้) โดยที่ User ไม่ต้อง Login ใหม่บ่อยๆ ประสบการณ์การใข้งานจะดีขึ้น
  2. สามารถเลิกใช้งาน Access Token ได้ เพราะเราเก็บ Refresh Token ล่าสุดของ User แต่ล่ะคนไว้ใน DB แค่ปรับให้หมดอายุ เมื่อ User เข้ามาทำอะไรกับ Website สักอย่างก็จะเด้งออกไปให้ Login ใหม่ทันที

Implements

application.propperties

เพิ่มเวลาหมดอายุของ Refresh Token

application.properties
# JWT Config
jwtSecret=auth-spring-jwt-docker-secret-key
# 5 minutes.
jwtExpirationMs=300000
# 24 hours.
jwtRefreshExpirationMs=86400000

## For test
# 1 minute.
#jwtExpirationMs=60000
# 2 minutes.
#jwtRefreshExpirationMs= 120000

. . .

Entity

เพิ่มตารางสำหรับเก็บ Refresh Token ของแต่ล่ะ User

/entity/app/RefreshToken.java
package com.demo.auth.entity.app;

import java.time.Instant;

import com.demo.auth.entity.base.BaseEntity;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.OneToOne;
import lombok.Data;

@Data
@Entity
public class RefreshToken extends BaseEntity{
    @OneToOne
    @JoinColumn(name = "user_id", referencedColumnName = "id")
    private User user;

    @Column(nullable = false, unique = true)
    private String token;

    @Column(nullable = false)
    private Instant expiryDate;
}


. . .

Repository

มี Entity แล้วก็ต้องมี Repository สำหรับทำ CRUD

/repository/RefreshTokenRepository.java
package com.demo.auth.repository;

import java.util.Optional;

import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.stereotype.Repository;

import com.demo.auth.entity.app.RefreshToken;
import com.demo.auth.entity.app.User;

@Repository
public interface RefreshTokenRepository extends JpaRepository<RefreshToken,Long>{
    Optional<RefreshToken> findByToken(String token);
    Optional<RefreshToken> findByUser(User user);

    @Modifying
    int deleteByUser(User user);
}

สำหรับ 1 User จะมีแค่ 1 Refresh Token เท่านั้นเมื่อได้ Refresh Token ใหม่จะต้องลบอันเก่าทิ้ง และถ้าไม่ได้ SELECT จะต้องมี @Modifying


. . .

Exception

เราจะแยก Exception ออกมาต่างหาก เพื่อ Custom ได้ว่าจะมีค่าอะไรบ้าง เอาไว้ใช้กับ @RestControllerAdvice

/exception/TokenRefreshException.java
package com.demo.auth.exception;

import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.http.HttpStatus;

@ResponseStatus(HttpStatus.FORBIDDEN)
public class TokenRefreshException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    public TokenRefreshException(String token, String message) {
      super(String.format("Failed for [%s]: %s", token, message));
    }
}

. . .

Service

/security/services/RefreshTokenService.java
package com.demo.auth.security.services;

import java.time.Instant;
import java.util.Optional;
import java.util.UUID;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import com.demo.auth.entity.app.RefreshToken;
import com.demo.auth.entity.app.User;
import com.demo.auth.repository.RefreshTokenRepository;
import com.demo.auth.repository.UserRepository;
import com.demo.auth.exception.TokenRefreshException;

@Service
public class RefreshTokenService {

    @Value("${jwtRefreshExpirationMs}")
    private Long refreshTokenDurationMs;

    @Autowired
    private RefreshTokenRepository refreshTokenRepository;

    @Autowired
    private UserRepository userRepository;

    public Optional<RefreshToken> findByToken(String token) {
        return refreshTokenRepository.findByToken(token);
    }

    public RefreshToken createRefreshToken(Long userId) {
        User user = userRepository.findById(userId).get();

        Optional<RefreshToken> oldRefreshToken = refreshTokenRepository.findByUser(user);
        if(oldRefreshToken.isPresent()){
            refreshTokenRepository.deleteByUser(user);
        }

        RefreshToken refreshToken = new RefreshToken();

        refreshToken.setUser(user);
        refreshToken.setExpiryDate(Instant.now().plusMillis(refreshTokenDurationMs));
        refreshToken.setToken(UUID.randomUUID().toString());

        refreshToken = refreshTokenRepository.save(refreshToken);
        return refreshToken;
    }

    public RefreshToken verifyExpiration(RefreshToken token) {
        if (token.getExpiryDate().compareTo(Instant.now()) < 0) {
            refreshTokenRepository.delete(token);
            throw new TokenRefreshException(token.getToken(), "Refresh token was expired. Please make a new signin request");
        }

        return token;
    }
}

มีด้วยกัน 4 Functions

  • ค้นหา Refresh Token ด้วย Token
  • สร้าง Refresh Token ก่อนจะสร้าง มีการเช็คเงื่อนไขเพื่อลบ Refresh Token เก่า
  • เช็คว่า Refresh Token หมดอายุหรือยัง ถ้าหมดอายุแล้วจะลบทิ้ง

. . .

POJO Class

/payload/request/TokenRefreshRequest.java
package com.demo.auth.payload.request;

import jakarta.validation.constraints.NotBlank;
import lombok.Data;

@Data
public class TokenRefreshRequest {
    @NotBlank
    private String refreshToken;
}

/payload/response/TokenRefreshResponse.java
package com.demo.auth.payload.response;

import lombok.Data;

@Data
public class TokenRefreshResponse {
    private String accessToken;
    private String refreshToken;
    private String tokenType = "Bearer";

    public TokenRefreshResponse(String accessToken, String refreshToken) {
        this.accessToken = accessToken;
        this.refreshToken = refreshToken;
    }
}

สำหรับใช้กับ @RestControllerAdvice

/advice/ErrorMessage.java
package com.demo.auth.advice;

import java.util.Date;

import lombok.Data;

@Data
public class ErrorMessage {
    private int statusCode;
    private Date timestamp;
    private String message;
    private String description;
  
    public ErrorMessage(int statusCode, Date timestamp, String message, String description) {
      this.statusCode = statusCode;
      this.timestamp = timestamp;
      this.message = message;
      this.description = description;
    }
}

เพิ่ม refreshToken

/payload/response/JwtResponse.java

//...
public class JwtResponse {

    //...
    private String refreshToken;

    //...

    public JwtResponse(
        //...
        String refreshToken, 
        //...
        ) 
    {
        //...
        this.refreshToken = refreshToken;
        //...
    }
}

. . .

Filter

ปรับ AuthTokenFilter.java เช็ค RefreshToken ใน Headers เพิ่ม

/security/jwt/AuthTokenFilter.java
package com.demo.auth.security.jwt;

import java.io.IOException;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import com.demo.auth.entity.app.RefreshToken;
import com.demo.auth.security.services.RefreshTokenService;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.log4j.Log4j2;

import java.util.Optional;

@Log4j2
@Component
public class AuthTokenFilter extends OncePerRequestFilter{

    //...

	@Autowired
	private RefreshTokenService refreshTokenService;

    @Override
	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
			throws ServletException, IOException {
		try {
			//...

			String rft = parseRft(request);
			Optional<RefreshToken> refreshTokenOpt = refreshTokenService.findByToken(rft);
			
			RefreshToken refreshToken = null;
			if(refreshTokenOpt.isPresent()){
				refreshToken = refreshTokenService.verifyExpiration(refreshTokenOpt.get());
			}

			//...
		} catch (Exception e) {
			log.error("Cannot set user authentication: {}", e);
		}

		filterChain.doFilter(request, response);
	}

    //...

	private String parseRft(HttpServletRequest request) {
		String headerRft = request.getHeader("RefreshToken");
		if (StringUtils.hasText(headerRft)) {
			return headerRft;
		}

		return null;
	}
    
}

ที่ต้องเพิ่มการส่ง RefreshToken ใน Header มาด้วยทุกครั้ง เพราะว่าถ้าเราปรับ RefreshToken ให้หมดอายุ AccessToken ก็จะไม่สามารถใช้งานได้ไปด้วย

ถ้าไม่มีส่ง RefreshToken มาด้วย เมื่อเราปรับ RefreshToken ให้หมดอายุ แต่ถ้า AcessToken ยังไม่หมดอายุ ก็จะยังสามารถใช้งานได้เหมือนเดิม


. . .

Advice

การทำ @RestControllerAdvide เป็นการทำ Handler สำหรับ Custom Error ว่าเวลาเกิด Error อยากให้มันส่งค่าอะไรกลับไปบอกบ้าง

/advice/TokenControllerAdvice.java
package com.demo.auth.advice;

import java.util.Date;

import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import org.springframework.web.context.request.WebRequest;

import com.demo.auth.exception.TokenRefreshException;

@RestControllerAdvice
public class TokenControllerAdvice {

    @ExceptionHandler(value = TokenRefreshException.class)
    @ResponseStatus(HttpStatus.FORBIDDEN)
    public ErrorMessage handleTokenRefreshException(TokenRefreshException ex, WebRequest request) {
        return new ErrorMessage(
            HttpStatus.FORBIDDEN.value(),
            new Date(),
            ex.getMessage(),
            request.getDescription(false));
    }

}

เรา Trigger ไว้กับ TokenRefreshException.java เมื่อเรียกใช้ก็จะเข้ามาที่นี่ก่อน Return error ออกไป


. . .

Controller

ปรับ Logic ใน /auth/signin นิดหน่อย

/controllers/AuthController.java

import com.demo.auth.entity.app.RefreshToken;

@CrossOrigin(origins = "*", maxAge = 3600)
@RestController
@RequestMapping("/auth")
public class AuthController {
    //...

    @Autowired
  	RefreshTokenService refreshTokenService;

    @PostMapping("/signin")
	public ResponseEntity<?> authenticateUser(@Valid @RequestBody SigninRequest req) {
        //...
        UserDetailsImpl userDetails = (UserDetailsImpl) authentication.getPrincipal();	
		String jwt = jwtUtil.generateJwtToken(userDetails);
        //...
        RefreshToken refreshToken = refreshTokenService.createRefreshToken(userDetails.getId());

        return ResponseEntity.ok(new JwtResponse(
                                                //...
												refreshToken.getToken()
                                                //...
                                                 ));
    }
}

เพิ่ม /refreshtoken

/controllers/AuthController.java

import com.demo.auth.entity.app.RefreshToken;

import com.demo.auth.payload.request.TokenRefreshRequest;
import com.demo.auth.payload.response.TokenRefreshResponse;
import com.demo.auth.exception.TokenRefreshException;

@CrossOrigin(origins = "*", maxAge = 3600)
@RestController
@RequestMapping("/auth")
public class AuthController {
    //...

    @PostMapping("/refreshtoken")
	public ResponseEntity<?> refreshtoken(@Valid @RequestBody TokenRefreshRequest request) {
		String requestRefreshToken = request.getRefreshToken();

		return refreshTokenService.findByToken(requestRefreshToken)
			.map(refreshTokenService::verifyExpiration)
			.map(RefreshToken::getUser)
			.map(user -> {
					String token = jwtUtil.generateTokenFromUsername(user.getUsername());
					return ResponseEntity.ok(new TokenRefreshResponse(token, requestRefreshToken));
				}
			)
			.orElseThrow(() -> new TokenRefreshException(requestRefreshToken,"Refresh token is not in database!"));
	}
}

ทดสอบด้วย Postman

  1. เพิ่มตัวแปร RefreshToken

  1. เขียน Script ใน Tests ของ /auth/signin เพื่อ Set ค่าตัวแปรให้กับ RefreshToken
if(pm.response.code === 200){
    pm.environment.set("AccessToken",pm.response.json().token);
    pm.environment.set("RefreshToken",pm.response.json().refreshToken);
}

  1. /auth/signin
PM Signin

จะเห็นว่าได้ refreshToken กลับมาด้วย


ตาราง refresh_token

DB

  1. /user เพิ่มการส่ง RefreshToken ใน Header โดยเอาค่ามาจากตัวแปร {{RefreshToken}}
PM User
PM User Authen

  1. สร้าง Request /auth/refreshtoken
PM Refresh Token

ใน Script Tests ทำเหมือนข้อ 2


  1. เมื่อ AccessToken หมดอายุ เราจะได้ Response Code = 401 และฝั่ง Front-End ต้องทำ Interceptor ของ Response ไว้ เมื่อได้รับ Response Code = 401 ให้ Request ไปที่ /auth/refreshtoken เพื่อสร้าง AcessToken อันใหม่ และนำกลับมาอัพเดตฝั่ง Front-End และถ้า RefreshToken หมดอายุก็จะ Logout ให้อัตโนมัติ

ตัวอย่างประมาณนี้

setupinterceptors.js
import http from '@/constants/api';
import TokenService from '@/services/token'
import EventBus from '@/utils/event-bus'

const setup = (store) => {
    http.interceptors.response.use(
        (res) => {
          return res;
        },
        async (err) => {
          const originalConfig = err.config;
    
          if (originalConfig.url !== "/auth/signin" && err.response) {
            // Access Token was expired
            if (err.response.status === 401 && !originalConfig._retry) {
              originalConfig._retry = true;
    
              try {
                const rs = await http.post("/auth/refreshtoken", {
                  refreshToken: TokenService.getLocalRefreshToken(),
                });
    
                const { accessToken } = rs.data;
    
                store.dispatch('auth/refreshtoken', accessToken);
                TokenService.updateLocalAccessToken(accessToken);
    
                return http(originalConfig);
              } catch (_error) {
                EventBus.dispatch("logout");
                return Promise.reject(_error);
              }
            }
          }
    
          return Promise.reject(err);
        }
      );
};

export default setup;

Log ใน Spring Boot ถ้า AccessToken หมดอายุ

JWT token is expired: JWT expired at 2023-03-18T13:51:18Z. Current time: 2023-03-18T14:06:59Z, 
a difference of 941979 milliseconds.  Allowed clock skew: 0 milliseconds.