96 lines
2.9 KiB
C#
96 lines
2.9 KiB
C#
using System.Text;
|
||
using TinfoilVibeServer.Authentication;
|
||
|
||
namespace TinfoilVibeServer.Middleware;
|
||
|
||
/// <summary>
/// Minimal HTTP Basic authentication middleware.
/// </summary>
/// <remarks>
/// For every request it (1) rejects clients whose IP is reported by <c>AuthStore.IsBlacklisted</c> with 403,
/// (2) parses the <c>Authorization: Basic ...</c> header, (3) reads an optional numeric <c>UID</c> header,
/// and (4) delegates the credential check to <c>AuthStore.TryValidate</c>.
/// Failure counting / lockout, if any, presumably happens inside <see cref="AuthStore"/> — it is not
/// visible in this class.
/// </remarks>
public sealed class BasicAuthMiddleware
{
    // Scheme prefix including the separating space, as the original parser expected it.
    private const string BasicScheme = "Basic ";

    // Value sent back with every 401 so clients know to (re)prompt for Basic credentials.
    private const string ChallengeHeaderValue = "Basic realm=\"FileSnapshot\"";

    private readonly RequestDelegate _next;

    /// <summary>Creates the middleware.</summary>
    /// <param name="next">The next delegate in the pipeline, invoked only after successful authentication.</param>
    public BasicAuthMiddleware(RequestDelegate next)
    {
        _next = next;
    }

    /// <summary>
    /// Authenticates the current request and forwards it down the pipeline on success.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="store">Credential / blacklist store used for validation.</param>
    /// <param name="logger">Logger for blocked and failed attempts.</param>
    /// <remarks>
    /// Responses produced here:
    /// <list type="bullet">
    /// <item>403 with body "Forbidden" when the client IP is blacklisted;</item>
    /// <item>401 + <c>WWW-Authenticate</c> (no body) when the Authorization header is missing or malformed;</item>
    /// <item>401 + <c>WWW-Authenticate</c> with the store's error text (or "Unauthorized") when validation fails.</item>
    /// </list>
    /// On success the username is stored in <c>context.Items["User"]</c>.
    /// </remarks>
    public async Task InvokeAsync(HttpContext context, AuthStore store, ILogger<BasicAuthMiddleware> logger)
    {
        var ip = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";

        // 1) IP blacklist
        if (store.IsBlacklisted(ip))
        {
            logger.LogWarning("Blocked request from blacklisted IP {IP}", ip);
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            await context.Response.WriteAsync("Forbidden", context.RequestAborted);
            return;
        }

        // 2) Authorization header. A missing header yields an empty value, which fails the
        //    scheme check below exactly like a non-Basic header does.
        var authHeader = context.Request.Headers["Authorization"].FirstOrDefault() ?? string.Empty;
        if (!TryParseCredentials(authHeader, out var username, out var password))
        {
            Challenge(context);
            return;
        }

        // 3) UID header (optional). An absent or non-numeric value leaves it null.
        var uid = ParseUid(context.Request.Headers["UID"].ToString());

        // 4) Validate
        if (!store.TryValidate(username, password, uid, ip, out var error))
        {
            logger.LogWarning("Auth failed for user {User} from {IP}: {Error}", username, ip, error);
            Challenge(context);
            // NOTE(review): the store's error text is echoed to the client; confirm it never
            // distinguishes "unknown user" from "wrong password" (user-enumeration risk).
            await context.Response.WriteAsync(error ?? "Unauthorized", context.RequestAborted);
            return;
        }

        // Authentication succeeded – attach username for downstream handlers if needed
        context.Items["User"] = username;

        await _next(context);
    }

    /// <summary>
    /// Extracts the username and password from a <c>Basic</c> Authorization header value.
    /// </summary>
    /// <param name="authHeader">Raw header value, e.g. <c>"Basic dXNlcjpwYXNz"</c>; never null.</param>
    /// <param name="username">Text before the first ':' of the decoded payload (may be empty).</param>
    /// <param name="password">Text after the first ':' of the decoded payload (may contain further ':').</param>
    /// <returns>
    /// <c>false</c> when the scheme is not Basic (case-insensitive), the payload is not valid Base64,
    /// or the decoded payload has no ':' separator.
    /// </returns>
    private static bool TryParseCredentials(string authHeader, out string username, out string password)
    {
        username = string.Empty;
        password = string.Empty;

        if (!authHeader.StartsWith(BasicScheme, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        string decoded;
        try
        {
            var b64 = authHeader[BasicScheme.Length..].Trim();
            decoded = Encoding.UTF8.GetString(Convert.FromBase64String(b64));
        }
        catch (FormatException)
        {
            // Only malformed Base64 is expected here; anything else should surface.
            return false;
        }

        // Split at the FIRST colon only: RFC 7617 allows ':' inside the password.
        var separator = decoded.IndexOf(':');
        if (separator < 0)
        {
            return false;
        }

        username = decoded[..separator];
        password = decoded[(separator + 1)..];
        return true;
    }

    /// <summary>
    /// Parses the optional <c>UID</c> header as an invariant-culture integer.
    /// </summary>
    /// <param name="raw">Header text; empty when the header was not sent.</param>
    /// <returns>The parsed value, or <c>null</c> when absent or not a valid integer.</returns>
    private static int? ParseUid(string? raw) =>
        int.TryParse(
            raw,
            System.Globalization.NumberStyles.Integer,
            System.Globalization.CultureInfo.InvariantCulture,
            out var parsed)
            ? parsed
            : (int?)null;

    /// <summary>
    /// Marks the response as 401 and advertises the Basic scheme. No body is written.
    /// </summary>
    /// <param name="ctx">The current HTTP context.</param>
    private static void Challenge(HttpContext ctx)
    {
        ctx.Response.StatusCode = StatusCodes.Status401Unauthorized;
        // Indexer assignment replaces any existing value; Headers.Add would throw on a duplicate key.
        ctx.Response.Headers["WWW-Authenticate"] = ChallengeHeaderValue;
    }
}