Taylohtio/GeneralSSO/GeneralSSO.Server/CodeFiles/Infrastructure/WCF/OAuthAuthorizationManager.cs

112 lines
4.6 KiB
C#

using System;
using System.Collections.Generic;
using System.IdentityModel.Policy;
using System.Net;
using System.Security.Cryptography;
using System.Security.Principal;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Security;
using System.ServiceModel.Web;
using System.Web;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.OAuth2;
using Microsoft.Practices.ServiceLocation;
using Taloyhtio.GeneralSSO.Server.CodeFiles.Common;
using Taloyhtio.GeneralSSO.Server.CodeFiles.Infrastructure.OAuth;
using Taloyhtio.GeneralSSO.Server.CodeFiles.Repositories;
using Taloyhtio.GeneralSSO.Server.CodeFiles.Services;
using ProtocolException = System.ServiceModel.ProtocolException;
namespace Taloyhtio.GeneralSSO.Server.CodeFiles.Infrastructure.WCF
{
    /// <summary>
    /// WCF extension that authorizes incoming messages using OAuth 2: the access token on the
    /// HTTP request is verified with a DotNetOpenAuth <see cref="ResourceServer"/>, and on success
    /// the resulting principal is attached to the message's <see cref="ServiceSecurityContext"/>.
    /// </summary>
    public class OAuthAuthorizationManager : ServiceAuthorizationManager
    {
        private readonly ILogger logger;

        /// <summary>
        /// Creates the manager, resolving its <see cref="ILogger"/> through the service locator.
        /// </summary>
        public OAuthAuthorizationManager()
        {
            this.logger = ServiceLocator.Current.GetInstance<ILogger>();
        }

        /// <summary>
        /// Grants access only if the base check passes and the request carries a valid OAuth 2
        /// access token for the scope required by the requested URL.
        /// </summary>
        /// <param name="operationContext">Context of the operation being authorized.</param>
        /// <returns><c>true</c> if access is granted; otherwise <c>false</c>.</returns>
        /// <remarks>
        /// OAuth protocol errors are logged and result in <c>false</c>. For protocol fault responses
        /// the error response produced by DotNetOpenAuth is also written to the outgoing web response
        /// when a <see cref="WebOperationContext"/> is available. An empty request URL still raises
        /// the <see cref="HttpException"/> thrown by <c>getScopeByUrl</c>.
        /// </remarks>
        protected override bool CheckAccessCore(OperationContext operationContext)
        {
            if (!base.CheckAccessCore(operationContext))
            {
                return false;
            }

            var requestMessage = operationContext.RequestContext.RequestMessage;

            // The MessageProperties indexer throws when a property is missing (e.g. non-HTTP transport),
            // so probe with TryGetValue and deny instead of failing with KeyNotFoundException.
            object httpProperty;
            requestMessage.Properties.TryGetValue(HttpRequestMessageProperty.Name, out httpProperty);
            var httpDetails = httpProperty as HttpRequestMessageProperty;
            if (httpDetails == null)
            {
                this.logger.Error(Constants.LogComponents.RESOURCE_SERVER, "Error processing OAuth messages: HTTP request details are not available");
                return false;
            }

            var requestUri = requestMessage.Properties.Via;

            try
            {
                // A missing To header is treated like an empty URL (getScopeByUrl rejects it).
                var to = operationContext.IncomingMessageHeaders.To;
                string scope = this.getScopeByUrl(to != null ? to.AbsolutePath : null);

                var principal = this.verifyOAuth2(httpDetails, requestUri, scope);
                if (principal == null)
                {
                    return false;
                }

                this.attachPrincipal(operationContext, principal);
                return true;
            }
            catch (ProtocolFaultResponseException ex)
            {
                this.logger.Error(Constants.LogComponents.RESOURCE_SERVER, string.Format("Error processing OAuth messages (protocol fault response):\n{0}", ex.ToInfo()));

                // Send the error response DotNetOpenAuth built for this fault back to the client,
                // if this is a web (REST) operation where an outgoing web response exists.
                var outgoingResponse = ex.CreateErrorResponse();
                var webContext = WebOperationContext.Current;
                if (webContext != null)
                {
                    outgoingResponse.Respond(webContext.OutgoingResponse);
                }
            }
            catch (DotNetOpenAuth.Messaging.ProtocolException ex)
            {
                // DotNetOpenAuth signals token/protocol problems with its own ProtocolException type,
                // which is unrelated to the WCF ProtocolException aliased at the top of this file.
                this.logger.Error(Constants.LogComponents.RESOURCE_SERVER, string.Format("Error processing OAuth messages:\n{0}", ex.ToInfo()));
            }
            catch (System.ServiceModel.ProtocolException ex)
            {
                this.logger.Error(Constants.LogComponents.RESOURCE_SERVER, string.Format("Error processing OAuth messages:\n{0}", ex.ToInfo()));
            }

            return false;
        }

        /// <summary>
        /// Attaches <paramref name="principal"/> to the incoming message's security context so that
        /// service operations see it as the caller.
        /// </summary>
        /// <param name="operationContext">Context of the operation being authorized.</param>
        /// <param name="principal">Principal produced by a successful token verification.</param>
        private void attachPrincipal(OperationContext operationContext, IPrincipal principal)
        {
            var policies = new List<IAuthorizationPolicy> { new OAuthPrincipalAuthorizationPolicy(principal) };
            var securityContext = new ServiceSecurityContext(policies.AsReadOnly());

            var incomingProperties = operationContext.IncomingMessageProperties;
            if (incomingProperties.Security != null)
            {
                incomingProperties.Security.ServiceSecurityContext = securityContext;
            }
            else
            {
                incomingProperties.Security = new SecurityMessageProperty
                {
                    ServiceSecurityContext = securityContext,
                };
            }

            securityContext.AuthorizationContext.Properties["Identities"] = new List<IIdentity> { principal.Identity };
        }

        /// <summary>
        /// Verifies the OAuth 2 access token on the request and returns the principal it represents.
        /// </summary>
        /// <param name="httpDetails">HTTP details of the incoming request.</param>
        /// <param name="requestUri">URI the request was received on.</param>
        /// <param name="requiredScopes">Scopes the token must grant.</param>
        /// <returns>The principal returned by <see cref="ResourceServer.GetPrincipal"/>.</returns>
        /// <remarks>
        /// Tokens are checked against the authorization server's signing certificate (public key)
        /// and decrypted with the resource server's encryption certificate (private key).
        /// Exceptions thrown by DotNetOpenAuth propagate to the caller.
        /// </remarks>
        private IPrincipal verifyOAuth2(HttpRequestMessageProperty httpDetails, Uri requestUri, params string[] requiredScopes)
        {
            var tokenAnalyzer = new StandardAccessTokenAnalyzer(
                (RSACryptoServiceProvider)Cert.AuthServerSigningCertificate.PublicKey.Key,
                (RSACryptoServiceProvider)Cert.ResourceServerEncyptionCertificate.PrivateKey);
            var resourceServer = new ResourceServer(tokenAnalyzer);
            return resourceServer.GetPrincipal(httpDetails, requestUri, requiredScopes);
        }

        /// <summary>
        /// Maps a request path to the OAuth scope required to call it.
        /// </summary>
        /// <param name="requestUrl">Absolute path of the request.</param>
        /// <returns>
        /// <see cref="Scope.ReadRoles"/> as a string for paths ending in <c>/getroles</c>
        /// (case-insensitive); otherwise an empty string.
        /// </returns>
        /// <exception cref="HttpException">With status 400 when <paramref name="requestUrl"/> is null or empty.</exception>
        private string getScopeByUrl(string requestUrl)
        {
            if (string.IsNullOrEmpty(requestUrl))
            {
                throw new HttpException((int)HttpStatusCode.BadRequest, "Request url is empty");
            }

            // Ordinal comparison: URL paths are identifiers, not culture-dependent text.
            if (requestUrl.EndsWith("/getroles", StringComparison.OrdinalIgnoreCase))
            {
                return Scope.ReadRoles.ToString();
            }

            // Empty scope will cause an insufficient_scope error.
            return string.Empty;
        }
    }
}