Added much more authentication infrastructure for various kinds of authentication.

This commit is contained in:
Andrew Lalis 2023-11-07 16:24:24 -05:00
parent c7bf9d9058
commit f4ba6b8bec
16 changed files with 401 additions and 77 deletions

View File

@ -9,28 +9,68 @@ import org.springframework.web.bind.annotation.*;
import java.util.Map;
/**
* Controller for authentication-related tasks.
*/
@RestController
@RequiredArgsConstructor
public class AuthController {
private final TokenService tokenService;
/**
* The endpoint for users of this Onyx node to login, that is, obtain a
* refresh token and access token in exchange for valid credentials.
* @param loginRequest The login request.
* @return A token pair, if successful.
*/
@PostMapping("/auth/login")
public TokenPair login(@RequestBody LoginRequest loginRequest) {
return tokenService.generateTokenPair(loginRequest);
}
/**
* Endpoint for obtaining a new access token using a valid refresh token.
* @param request The HTTP request.
* @return The new access token.
*/
@GetMapping("/auth/access")
public AccessTokenResponse getAccessToken(HttpServletRequest request) {
return tokenService.generateAccessToken(request);
}
/**
* Endpoint used to remove all refresh tokens, essentially logging the user
* out of all devices that may have stored a refresh token.
* @param user The user who is removing their tokens.
*/
@DeleteMapping("/auth/refresh-tokens")
public void removeAllRefreshTokens(@AuthenticationPrincipal User user) {
tokenService.removeAllRefreshTokens(user);
}
/**
* Endpoint for determining the expiration time of an access token.
* @param request The HTTP request.
* @return An object containing an "expiresAt" field, in milliseconds since
* the unix epoch.
*/
@GetMapping("/auth/token-expiration")
public Object getTokenExpiration(HttpServletRequest request) {
return Map.of("expiresAt", tokenService.getTokenExpiration(request));
}
/**
* Validates a token belonging to a user of this Onyx node, as requested by
* another node. The request itself should have an Authorization header
* with a bearer token that proves the identity of the onyx node that's
* requesting to verify the user.
* @param request The HTTP request.
* @param validationData The data needed to validate the user.
* @return An object that tells whether the user is verified.
*/
@PostMapping("/auth/validate-foreign-token")
public Object validateToken(HttpServletRequest request, @RequestBody ForeignTokenValidationRequest validationData) {
// TODO: Implement this!
return null;
}
}

View File

@ -0,0 +1,5 @@
package com.andrewlalis.onyx.auth.api;
public record ForeignTokenValidationRequest(
String accessToken
) {}

View File

@ -3,6 +3,7 @@ package com.andrewlalis.onyx.auth.components;
import com.andrewlalis.onyx.auth.model.User;
import com.andrewlalis.onyx.content.dao.ContentNodeRepository;
import com.andrewlalis.onyx.content.model.ContentNode;
import com.andrewlalis.onyx.content.service.ContentAccessService;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
@ -25,6 +26,7 @@ import java.io.IOException;
@Slf4j
public class ContentAccessFilter extends OncePerRequestFilter {
private final ContentNodeRepository contentNodeRepository;
private final ContentAccessService contentAccessService;
@Override
protected void doFilterInternal(
@ -48,8 +50,6 @@ public class ContentAccessFilter extends OncePerRequestFilter {
log.warn("Node doesn't exist!");
return;
}
TokenAuthentication auth = (TokenAuthentication) SecurityContextHolder.getContext().getAuthentication();
User user = auth.getPrincipal();
// TODO: Actually check access rules.
filterChain.doFilter(request, response);
}

View File

@ -1,45 +0,0 @@
package com.andrewlalis.onyx.auth.components;
import com.andrewlalis.onyx.auth.service.TokenService;
import com.andrewlalis.onyx.auth.service.UserService;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
@Component
@RequiredArgsConstructor
@Slf4j
public class JwtFilter extends OncePerRequestFilter {
private final UserService userService;
private final TokenService tokenService;
@Override
protected void doFilterInternal(
HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain
) throws ServletException, IOException {
try {
Jws<Claims> jws = tokenService.getToken(request);
if (jws != null) {
long userId = Long.parseLong(jws.getBody().getSubject());
userService.findById(userId).ifPresent(user -> SecurityContextHolder.getContext()
.setAuthentication(new TokenAuthentication(user, jws))
);
}
} catch (Exception e) {
log.warn("Exception occurred in JwtFilter.", e);
}
filterChain.doFilter(request, response);
}
}

View File

@ -0,0 +1,26 @@
package com.andrewlalis.onyx.auth.components;
import com.andrewlalis.onyx.auth.model.User;
/**
* A type of token authentication that's used for users of this Onyx node,
* where we have full access to the user and their info.
*/
public class LocalUserAuth extends TokenAuth {
private final User user;
public LocalUserAuth(String token, User user) {
super(token);
this.user = user;
}
@Override
public User getPrincipal() {
return user;
}
@Override
public String getName() {
return user.getUsername();
}
}

View File

@ -0,0 +1,27 @@
package com.andrewlalis.onyx.auth.components;
import com.andrewlalis.onyx.auth.model.NetworkUser;
/**
* A type of token authentication that's used for users of Onyx nodes networked
* with this one. We don't have full access to their user data, but we can talk
* to the networked node to get some basic information about the user.
*/
public class NetworkUserAuth extends TokenAuth {
private final NetworkUser user;
public NetworkUserAuth(String token, NetworkUser user) {
super(token);
this.user = user;
}
@Override
public NetworkUser getPrincipal() {
return user;
}
@Override
public String getName() {
return user.username();
}
}

View File

@ -0,0 +1,49 @@
package com.andrewlalis.onyx.auth.components;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import java.util.Collection;
/**
* An authentication instance that's used to represent public access to the
* API.
*/
public class PublicUserAuth implements Authentication {
// TODO: Add some sort of info here, not just authentication.
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return null;
}
@Override
public Object getCredentials() {
return null;
}
@Override
public Object getDetails() {
return null;
}
@Override
public Object getPrincipal() {
return null;
}
@Override
public boolean isAuthenticated() {
return false;
}
@Override
public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
}
@Override
public String getName() {
return null;
}
}

View File

@ -1,8 +1,5 @@
package com.andrewlalis.onyx.auth.components;
import com.andrewlalis.onyx.auth.model.User;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
@ -10,12 +7,16 @@ import java.util.Collection;
import java.util.Collections;
/**
* The authentication implementation that's used when a user logs in with an
* access token.
* @param user The user that the token belongs to.
* @param jws The raw token.
* An abstract base class for any authentication instance based on the use of
* a JWT for authentication.
*/
public record TokenAuthentication(User user, Jws<Claims> jws) implements Authentication {
public abstract class TokenAuth implements Authentication {
public final String token;
protected TokenAuth(String token) {
this.token = token;
}
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return Collections.emptyList();
@ -23,7 +24,7 @@ public record TokenAuthentication(User user, Jws<Claims> jws) implements Authent
@Override
public Object getCredentials() {
return this.jws;
return this.token;
}
@Override
@ -31,11 +32,6 @@ public record TokenAuthentication(User user, Jws<Claims> jws) implements Authent
return null;
}
@Override
public User getPrincipal() {
return this.user;
}
@Override
public boolean isAuthenticated() {
return true;
@ -45,9 +41,4 @@ public record TokenAuthentication(User user, Jws<Claims> jws) implements Authent
public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
throw new RuntimeException("Cannot set the authenticated status of TokenAuthentication.");
}
@Override
public String getName() {
return user.getUsername();
}
}

View File

@ -0,0 +1,69 @@
package com.andrewlalis.onyx.auth.components;
import com.andrewlalis.onyx.auth.model.User;
import com.andrewlalis.onyx.auth.service.TokenService;
import com.andrewlalis.onyx.auth.service.UserService;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.Optional;
/**
* A filter that extracts and verifies an HTTP request's JWT, and uses that
* to set the current security context's authentication instance accordingly.
* <p>
* If the JWT originates from this Onyx node, then we simply fetch the User
* and set a new {@link LocalUserAuth}.
* </p>
* <p>
* If the JWT originates from a networked Onyx node, then we'll try and
* confirm with the original node that the JWT is valid, and then set a new
* {@link NetworkUserAuth}.
* </p>
*/
@Component
@RequiredArgsConstructor
@Slf4j
public class TokenAuthFilter extends OncePerRequestFilter {
private final UserService userService;
private final TokenService tokenService;
@Override
protected void doFilterInternal(
HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain
) throws ServletException, IOException {
final var securityContext = SecurityContextHolder.getContext();
boolean authContextSet = false;
try {
String token = tokenService.extractBearerToken(request);
Jws<Claims> jws = tokenService.parseToken(token);
if (jws != null) {
long userId = Long.parseLong(jws.getBody().getSubject());
Optional<User> optionalUser = userService.findById(userId);
if (optionalUser.isPresent()) {
securityContext.setAuthentication(new LocalUserAuth(token, optionalUser.get()));
authContextSet = true;
}
}
} catch (Exception e) {
log.warn("Exception occurred in JwtFilter.", e);
}
// TODO: Check if the request is coming from a network user, then validate their token.
if (!authContextSet) {
securityContext.setAuthentication(new PublicUserAuth());
}
filterChain.doFilter(request, response);
}
}

View File

@ -0,0 +1,15 @@
package com.andrewlalis.onyx.auth.model;
/**
* Similar to a {@link User}, the NetworkUser contains the information for a
* user who's coming to this node from another in the network. We don't have
* as much information as with a normal user, but enough to work with.
* @param id The user's id, as it is defined by its host Onyx node.
* @param username The user's username, as it is defined by its host Onyx node.
* @param displayName The user's display name.
*/
public record NetworkUser(
long id,
String username,
String displayName
) {}

View File

@ -142,7 +142,7 @@ public class TokenService {
public long getTokenExpiration(HttpServletRequest request) {
try {
Jws<Claims> jws = getToken(request);
Jws<Claims> jws = parseToken(extractBearerToken(request));
return jws.getBody().getExpiration().getTime();
} catch (Exception e) {
log.warn("Exception occurred while getting token expiration.", e);
@ -150,16 +150,15 @@ public class TokenService {
}
}
public Jws<Claims> getToken(HttpServletRequest request) throws Exception {
String rawToken = extractBearerToken(request);
if (rawToken == null) return null;
public Jws<Claims> parseToken(String token) throws Exception {
if (token == null) return null;
JwtParserBuilder parserBuilder = Jwts.parserBuilder()
.setSigningKey(getSigningKey())
.requireIssuer(ISSUER);
return parserBuilder.build().parseClaimsJws(rawToken);
return parserBuilder.build().parseClaimsJws(token);
}
private String extractBearerToken(HttpServletRequest request) {
public String extractBearerToken(HttpServletRequest request) {
String authorizationHeader = request.getHeader("Authorization");
if (authorizationHeader == null || !authorizationHeader.startsWith(BEARER_PREFIX)) return null;
String rawToken = authorizationHeader.substring(BEARER_PREFIX.length());

View File

@ -1,7 +1,7 @@
package com.andrewlalis.onyx.config;
import com.andrewlalis.onyx.auth.components.ContentAccessFilter;
import com.andrewlalis.onyx.auth.components.JwtFilter;
import com.andrewlalis.onyx.auth.components.TokenAuthFilter;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@ -12,28 +12,30 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@Configuration
@EnableWebSecurity
@RequiredArgsConstructor
public class SecurityConfig {
private final JwtFilter jwtFilter;
private final TokenAuthFilter tokenAuthFilter;
private final ContentAccessFilter contentAccessFilter;
@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http.authorizeHttpRequests(registry -> {
// Public endpoints that require no authentication.
registry.requestMatchers(AntPathRequestMatcher.antMatcher(HttpMethod.POST, "/auth/login")).permitAll();
registry.requestMatchers(AntPathRequestMatcher.antMatcher(HttpMethod.GET, "/auth/access")).permitAll();
registry.requestMatchers(AntPathRequestMatcher.antMatcher(HttpMethod.GET, "/auth/token-expiration")).permitAll();
// Any path not explicitly listed here requires authentication to access.
registry.anyRequest().authenticated();
});
http.csrf(AbstractHttpConfigurer::disable);
http.sessionManagement(configurer -> configurer.sessionCreationPolicy(SessionCreationPolicy.NEVER));
http.addFilterBefore(jwtFilter, UsernamePasswordAuthenticationFilter.class);
http.addFilterAfter(contentAccessFilter, JwtFilter.class);
http.addFilterBefore(tokenAuthFilter, UsernamePasswordAuthenticationFilter.class);
http.addFilterAfter(contentAccessFilter, TokenAuthFilter.class);
http.cors(configurer -> configurer.configure(http));
return http.build();
}

View File

@ -0,0 +1,12 @@
package com.andrewlalis.onyx.content.dao;
import com.andrewlalis.onyx.content.model.access.ContentAccessRules;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
import java.util.Optional;
@Repository
public interface ContentAccessRulesRepository extends JpaRepository<ContentAccessRules, Long> {
Optional<ContentAccessRules> findByContentNodeId(long contentNodeId);
}

View File

@ -13,4 +13,11 @@ public interface ContentNodeRepository extends JpaRepository<ContentNode, Long>
@Query("SELECT cn FROM ContentNode cn WHERE cn.name = '" + ContentNode.ROOT_NODE_NAME + "'")
ContentNode findRoot();
interface ParentContainerId {
long getParentContainerId();
}
@Query("SELECT cn.parentContainer.id FROM ContentNode cn WHERE cn.id = :nodeId")
ParentContainerId getParentId(long nodeId);
}

View File

@ -0,0 +1,20 @@
package com.andrewlalis.onyx.content.dao;
import com.andrewlalis.onyx.content.model.access.UserContentAccessRule;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Optional;
@Repository
public interface UserContentAccessRuleRepository extends JpaRepository<UserContentAccessRule, Long> {
@Query("SELECT ucar FROM UserContentAccessRule ucar " +
"WHERE ucar.contentAccessRules.contentNode.id = :nodeId AND " +
"ucar.user.id = :userId")
Optional<UserContentAccessRule> findByContentNodeIdAndUser(long nodeId, long userId);
@Query("SELECT ucar FROM UserContentAccessRule ucar WHERE ucar.contentAccessRules.contentNode.id IN :nodeIds")
List<UserContentAccessRule> findAllByContentNodeIds(List<Long> nodeIds);
}

View File

@ -0,0 +1,107 @@
package com.andrewlalis.onyx.content.service;
import com.andrewlalis.onyx.auth.model.User;
import com.andrewlalis.onyx.content.dao.ContentAccessRulesRepository;
import com.andrewlalis.onyx.content.dao.ContentNodeRepository;
import com.andrewlalis.onyx.content.dao.UserContentAccessRuleRepository;
import com.andrewlalis.onyx.content.model.ContentNode;
import com.andrewlalis.onyx.content.model.access.ContentAccessLevel;
import com.andrewlalis.onyx.content.model.access.ContentAccessRules;
import com.andrewlalis.onyx.content.model.access.UserContentAccessRule;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.server.ResponseStatusException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
* A service that's responsible for determining if a user has permission to
* interact with certain content nodes.
* TODO: Add some sort of caching so recursive traversal of the content graph isn't needed.
*/
@Service
@RequiredArgsConstructor
public class ContentAccessService {
private final ContentAccessRulesRepository accessRulesRepository;
private final ContentNodeRepository contentNodeRepository;
private final UserContentAccessRuleRepository userContentAccessRuleRepository;
@Transactional(readOnly = true)
public boolean currentAuthCanReadContent(long nodeId) {
User user = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal();
var a = getEffectiveAccessLevel(user, nodeId);
return a == ContentAccessLevel.VIEW || a == ContentAccessLevel.EDIT;
}
@Transactional(readOnly = true)
public boolean currentAuthCanEditContent(long nodeId) {
User user = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal();
var a = getEffectiveAccessLevel(user, nodeId);
return a == ContentAccessLevel.EDIT;
}
/**
* Gets the effective access level that a user has to a particular content
* node. It finds this using the following algorithm:
* <ol>
* <li>Get a list of ids for this node and all its parents.</li>
* <li>
* First, recursively check the user-specific access levels for
* this and all parent nodes. If there exists a node with a user-
* specific access level for the user, then that's returned.
* </li>
* <li>
* Otherwise, recursively check the generic access levels for this
* and all parent nodes. The first non-INHERIT access level is
* returned.
* </li>
* </ol>
* The root node should not logically ever have an INHERIT access level, so
* it's the last resort if no others are found.
* @param user The user to get the access level for.
* @param nodeId The id of the content node to get the access level for.
* @return The access level that the given user has to the given node.
*/
private ContentAccessLevel getEffectiveAccessLevel(User user, long nodeId) {
List<Long> nodeIds = getAllNodeIds(nodeId);
for (long nId : nodeIds) {
Optional<UserContentAccessRule> userAccessRule = userContentAccessRuleRepository.findByContentNodeIdAndUser(nId, user.getId());
if (userAccessRule.isPresent() && userAccessRule.get().getAccessLevel() != ContentAccessLevel.INHERIT) {
return userAccessRule.get().getAccessLevel();
}
}
for (long nId : nodeIds) {
ContentAccessRules accessRules = accessRulesRepository.findByContentNodeId(nId).orElseThrow();
// TODO: Check the user's origin: anonymous, network, or node.
// For now, we assume node.
if (accessRules.getNodeAccessLevel() != ContentAccessLevel.INHERIT) {
return accessRules.getNodeAccessLevel();
}
}
return ContentAccessLevel.NONE;
}
private List<Long> getAllNodeIds(long nodeId) {
List<Long> nodeIds = new ArrayList<>();
nodeIds.add(nodeId);
ContentNode node = contentNodeRepository.findById(nodeId)
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND));
ContentNode parent = node.getParentContainer();
while (parent != null) {
nodeIds.add(parent.getId());
parent = parent.getParentContainer();
}
return nodeIds;
}
private enum ContentAccessType {
PUBLIC,
NETWORK,
NODE
}
}