A Simple On-Demand NAT Gateway for AWS

Adam Luchjenbroers
5 min read · Feb 27, 2021

This pet project grew out of a need to tinker with the AWS CI/CD tools (CodePipeline / CodeBuild / CodeDeploy) while studying for the AWS Developer Associate exam. I quickly realised that if I wanted to use CodeBuild, I’d have to run either a NAT gateway or several VPC endpoints (at a minimum, it seemed I’d need an S3 gateway endpoint and interface endpoints for at least CodeCommit and CloudWatch Logs). These can be costly to keep running, especially if you’re only likely to use them for a few hours a week.

Since it’s possible to manage everything in AWS via APIs, I decided that working out how to provision and deactivate the NAT gateway automatically would make for a great learning project.

The core of the solution consists of two Lambda functions. The first, RequestNatGateway, determines if a NAT gateway is currently available and, if not, handles the provisioning of a new gateway and the creation of the required routes. The build process calls this early in each run to ensure the gateway is available before it is needed.

The second, CheckGatewayRequired, checks if the gateway is running and whether it is still required. It runs every 30 minutes or so and will shut down the gateway if there has been no request for it within the last 45 minutes. These intervals seemed ideal to ensure that the gateway would not shut itself down prematurely. Since NAT gateways bill for each hour or part thereof, this should be slightly cheaper and more convenient than a solution that shuts the gateway down more quickly (which could incur two hours’ worth of cost within a single hour if the gateway is shut down and then relaunched).
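To make the timing concrete, here is a small sketch of how long the gateway can outlive its last request under these settings. It assumes the check fires on a fixed 30-minute schedule and applies the same 45-minute threshold; the shutdown_time helper and the example times are made up for illustration and are not part of the project.

```python
from datetime import datetime, timedelta

CHECK_INTERVAL = timedelta(minutes=30)   # how often CheckGatewayRequired runs
IDLE_THRESHOLD = timedelta(minutes=45)   # idle time after which the gateway is removed


def shutdown_time(last_request, first_check):
    """Return the time of the first scheduled check that finds the gateway idle.

    Hypothetical helper: assumes checks run at first_check, first_check + 30 min, ...
    """
    check = first_check
    while check - last_request < IDLE_THRESHOLD:
        check += CHECK_INTERVAL
    return check


start = datetime(2021, 2, 27, 9, 0)  # checks at 09:00, 09:30, 10:00, ...
print(shutdown_time(datetime(2021, 2, 27, 9, 10), start))
print(shutdown_time(datetime(2021, 2, 27, 9, 16), start))
```

With a last request at 09:10 the gateway is removed at the 10:00 check (50 minutes idle); with a last request at 09:16 it survives the 10:00 check by one minute and is removed at 10:30 (74 minutes idle). In general the gateway lingers for somewhere between 45 and just under 75 minutes after the last request.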

Tags are used for several purposes as part of the solution. They’re used to identify which subnets are public, which route tables need to be updated, and to track how recently the gateway was last requested.

The source code for this solution, together with a CloudFormation template that deploys it alongside a properly tagged and configured VPC, can be found in this GitHub repository: https://github.com/AdamLuchjenbroers/cheapseats-vpc

Code Walkthrough

So that’s the overview done; let’s dig into the actual code. It’s all written in Python and uses boto3 to interact with the AWS API and jmespath to parse the JSON data returned by boto3/AWS. This solution uses simple print statements to log to CloudWatch Logs (there’s probably a more elegant way to do this, but it works well enough for this purpose). The source code for these Lambda functions can be found here: https://github.com/AdamLuchjenbroers/cheapseats-vpc/blob/master/Lambda/OnDemandNAT/RequestGateway.py

Let’s start with RequestNatGateway, which is implemented in this code as request_gateway_handler(). This function first checks to see if there are any active NAT gateways within the VPC using list_nat_gateways():

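# Imports and EC2 client assumed by the excerpts in this walkthrough
import json
import os
import random
from datetime import datetime, timedelta

import boto3
import jmespath
from botocore.exceptions import ClientError
from dateutil import parser

ec2 = boto3.client('ec2')

def list_nat_gateways():
    # Only consider gateways in this VPC that are running or starting up
    filters = [
        {'Name': 'state', 'Values' : ['pending', 'available']},
        {'Name': 'vpc-id' , 'Values' : [ os.environ['VPC_ID'] ]}
    ]

    gateway_json = ec2.describe_nat_gateways(Filters=filters)
    # One row per gateway: [id, state, creation time, LastRequested tag value or None]
    gateways = jmespath.search('NatGateways[*].[NatGatewayId, State, CreateTime, Tags[?Key==\'LastRequested\'].Value | [0]]', gateway_json)

    if len(gateways) > 0:
        print("Checking for existing gateways, found ---\n%s\n---\n" % gateways)
        return gateways
    else:
        print("Checking for existing gateways, found none\n")
        return None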
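The JMESPath expression is the densest part of that function, so here is a rough plain-Python equivalent of what it extracts. This is only an illustration: the sample response is hand-written (the IDs and timestamps are made up) and project_gateways is a hypothetical helper, not something from the repository.

```python
from datetime import datetime, timezone

# Hand-written sample shaped like a describe_nat_gateways response (values are made up).
sample_response = {
    "NatGateways": [
        {
            "NatGatewayId": "nat-0123456789abcdef0",
            "State": "available",
            "CreateTime": datetime(2021, 2, 27, 9, 0, tzinfo=timezone.utc),
            "Tags": [
                {"Key": "Name", "Value": "OnDemandNAT-Gateway"},
                {"Key": "LastRequested", "Value": "2021-02-27 09:10:00.000000"},
            ],
        },
        {
            "NatGatewayId": "nat-0fedcba9876543210",
            "State": "pending",
            "CreateTime": datetime(2021, 2, 27, 9, 30, tzinfo=timezone.utc),
            # No Tags key at all: this gateway has not been tagged yet.
        },
    ]
}


def project_gateways(response):
    """Plain-Python equivalent of the JMESPath projection (illustrative only)."""
    rows = []
    for gw in response.get("NatGateways", []):
        # Mirrors Tags[?Key=='LastRequested'].Value | [0]
        requested = [t["Value"] for t in gw.get("Tags", []) if t["Key"] == "LastRequested"]
        rows.append([
            gw.get("NatGatewayId"),
            gw.get("State"),
            gw.get("CreateTime"),
            requested[0] if requested else None,
        ])
    return rows


for row in project_gateways(sample_response):
    print(row)
```

Each row has the same four fields that the handlers later unpack as (gatewayId, state, created, lastRequested), with None in the last slot when the tag is missing.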

If we determine that a gateway doesn't already exist then we will provision it. If there is already a gateway in operation, we update a tag that contains the last requested timestamp so that CheckGatewayRequired knows that we’re still using it.

def request_gateway_handler(event, context):
    print("NAT Gateway Requested\n")
    try:
        gateway_list = list_nat_gateways()

        info = {
            'nat_needed' : 'requested'
        }

        if gateway_list is None:
            print("New Gateway Required, Launching\n")
            gatewayId = create_nat_gateway()
            update_route_tables(gatewayId)
            info['nat-launched'] = gatewayId
        else:
            print("NAT Gateway already provisioned - updating Last Requested Timestamp\n")
            info['nat-existing'] = True
            for (gatewayId, state, created, lastRequested) in gateway_list:
                ec2.create_tags(
                    Resources=[gatewayId]
                  , Tags=[ {'Key' : 'LastRequested', 'Value' : '%s' % datetime.utcnow() } ]
                )
        ...
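The LastRequested tag is just the string form of a naive UTC datetime, which has to be parsed back into a datetime later on. The sketch below shows that round trip in isolation. The project parses the value with dateutil's parser.isoparse; I've swapped in the standard library's datetime.fromisoformat here so the example has no dependencies, and both helper names are made up for illustration.

```python
from datetime import datetime, timedelta


def format_last_requested(moment):
    """Render a timestamp the way '%s' % datetime.utcnow() does."""
    return "%s" % moment


def minutes_since(tag_value, now):
    """Parse a LastRequested tag value and return the elapsed minutes."""
    # fromisoformat accepts the space-separated form produced by str(datetime).
    requested = datetime.fromisoformat(tag_value)
    return (now - requested) / timedelta(minutes=1)


requested_at = datetime(2021, 2, 27, 9, 10, 0, 123456)
tag_value = format_last_requested(requested_at)
print(tag_value)
print(minutes_since(tag_value, requested_at + timedelta(minutes=50)))
```

Both sides of the subtraction are naive datetimes that are understood to be UTC, so the writer and the reader of the tag need to stick to that convention.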

Creating a gateway uses two more utility functions. First, we call create_nat_gateway() to launch a NAT gateway in one of the public subnets (chosen at random). We identify these using a tag that must be set on the public subnets (this is done automatically as part of the CloudFormation template). We ensure that the gateway is fully started and operational before we allow the function to return.

def create_nat_gateway():
    # Look up the Elastic IP reserved for the on-demand gateway
    alloc_json = ec2.describe_addresses(Filters=[
        {'Name' : 'tag:Name', 'Values' : ['OnDemandNAT-IPAddr']},
        {'Name' : 'tag:ForVpc', 'Values' : [os.environ['VPC_NAME']]}
    ])
    allocId = jmespath.search('Addresses[0].AllocationId', alloc_json )

    # Pick one of the subnets tagged as public at random
    subnet_json = ec2.describe_subnets(Filters=[
        {'Name' : 'tag:Public', 'Values' : ['Yes']},
        {'Name' : 'vpc-id', 'Values' : [os.environ['VPC_ID']]}
    ])
    subnet_list = jmespath.search('Subnets[*].SubnetId', subnet_json)
    subnetId = random.choice(subnet_list)

    new_gw_json = ec2.create_nat_gateway(AllocationId=allocId, SubnetId=subnetId)
    gatewayId = jmespath.search('NatGateway.NatGatewayId' , new_gw_json)

    print('NAT Gateway Created\n\tID: %s\tInfo ---\n%s\n---\n' % (gatewayId, new_gw_json))

    ec2.create_tags(
        Resources=[gatewayId]
      , Tags=[
            {'Key' : 'OnDemandNAT', 'Value' : 'True'}
          , {'Key' : 'Name', 'Value' : 'OnDemandNAT-Gateway'}
          , {'Key' : 'LastRequested', 'Value' : '%s' % datetime.utcnow()}
          , {'Key' : 'ForVpc', 'Value' : os.environ['VPC_NAME']}
          , {'Key' : 'Application', 'Value' : 'OnDemandNAT'}
          , {'Key' : 'Environment', 'Value' : 'Infrastructure'}
        ]
    )
    # Wait for gateway to finish starting.
    waiter = ec2.get_waiter('nat_gateway_available')
    waiter.wait(NatGatewayIds = [gatewayId])

    return gatewayId
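Conceptually, the waiter at the end of that function polls the gateway's state until it reports available. The loop below is a rough stand-alone sketch of that idea and not boto3's actual implementation: get_state is a hypothetical callable standing in for a describe call, and the delay and max_attempts values are illustrative rather than boto3's defaults.

```python
import time


def wait_until_available(get_state, delay=15, max_attempts=40, sleep=time.sleep):
    """Poll get_state() until it returns 'available'.

    Returns the number of polls that were needed. Raises if the gateway
    ends up in a terminal state or the attempts run out.
    """
    for attempt in range(1, max_attempts + 1):
        state = get_state()
        if state == "available":
            return attempt
        if state in ("failed", "deleting", "deleted"):
            raise RuntimeError("gateway entered state %r" % state)
        sleep(delay)
    raise TimeoutError("gateway not available after %d attempts" % max_attempts)


# Simulated gateway that needs three polls before it becomes available.
states = iter(["pending", "pending", "available"])
attempts = wait_until_available(lambda: next(states), sleep=lambda seconds: None)
print(attempts)
```

Because the handler blocks while it waits, the Lambda function's timeout has to be long enough to cover the time the gateway takes to start.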

Next, we call update_route_tables() to add the necessary routes. Each route table that is tagged as requiring on-demand NAT is updated with a new default route to the NAT gateway. The function first removes any default route that already exists and then adds the new one, which avoids the error that AWS raises when you add a route for a destination that is already present in the route table.

def update_route_tables(gatewayId):
    routes_json = ec2.describe_route_tables(Filters=[{'Name' : 'tag:OnDemandNAT', 'Values' : ['Yes', 'True']}])
    routes_list = jmespath.search('RouteTables[*].RouteTableId', routes_json)

    print("Fetched Routes List for update ---\n%s\n---\n" % routes_list)

    for routeTableId in routes_list:
        print("Updating Route Table %s" % routeTableId)
        try:
            ec2.delete_route(RouteTableId = routeTableId, DestinationCidrBlock = '0.0.0.0/0')
        except ClientError as e:
            # We expect the occasional failure where the route doesn't exist - this can be safely ignored.
            if e.response['Error']['Code'] != 'InvalidRoute.NotFound':
                raise e
        ec2.create_route(RouteTableId = routeTableId, DestinationCidrBlock = '0.0.0.0/0', NatGatewayId = gatewayId)

        print("Update Completed for %s\n" % routeTableId)
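The delete-then-create pattern is easier to see without AWS in the way. The toy model below is an in-memory stand-in I've invented for illustration (ToyRouteTable is not an AWS API); it assumes, like the real thing, that creating a route for an existing destination is an error and that deleting a missing route is also an error.

```python
class ToyRouteTable:
    """In-memory stand-in for a route table (not an AWS API)."""

    def __init__(self):
        self.routes = {}

    def create_route(self, destination, target):
        if destination in self.routes:
            raise ValueError("route for %s already exists" % destination)
        self.routes[destination] = target

    def delete_route(self, destination):
        if destination not in self.routes:
            raise KeyError(destination)
        del self.routes[destination]


def point_default_route_at(table, gateway_id):
    """Delete-then-create, tolerating a missing default route."""
    try:
        table.delete_route("0.0.0.0/0")
    except KeyError:
        pass  # No default route yet; nothing to remove.
    table.create_route("0.0.0.0/0", gateway_id)


table = ToyRouteTable()
point_default_route_at(table, "nat-old")   # first run: no existing route
point_default_route_at(table, "nat-new")   # later run: stale route replaced
print(table.routes)
```

The second call is the interesting one: the route left over from the earlier gateway is cleared before the new one is added, so the same function works whether or not a default route is already present.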

Last of all, we return to request_gateway_handler() to check if we were invoked via CodePipeline. CodePipeline needs us to post a result back to it using its API; failing to do this will result in the CodePipeline step timing out and failing regardless of whether the gateway was successfully provisioned or not.

        ...
        if 'CodePipeline.job' in event:
            job = event['CodePipeline.job']

            cp = boto3.client('codepipeline')
            cp.put_job_success_result(
                jobId=job['id']
            )

        print("SUMMARY:\n%s\n" % json.dumps(info))
        return info
    except BaseException as e:
        if 'CodePipeline.job' in event:
            job = event['CodePipeline.job']

            cp = boto3.client('codepipeline')
            cp.put_job_failure_result(
                jobId=job['id'],
                failureDetails= {
                    "type" : "JobFailed"
                  , "message" : '%s' % e
                }
            )
        raise e
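The success and failure branches share the same shape, so one way to look at them is as a single reporting step. The sketch below is my own restructuring for illustration, not how the repository organises it: report_to_codepipeline is a hypothetical helper, and RecordingClient is a test double that stands in for boto3.client('codepipeline') so the example runs without AWS credentials.

```python
def report_to_codepipeline(event, client, error=None):
    """Report success or failure for a CodePipeline-invoked run.

    client is expected to behave like boto3.client('codepipeline').
    Returns False when the event did not come from CodePipeline.
    """
    job = event.get("CodePipeline.job")
    if job is None:
        return False
    if error is None:
        client.put_job_success_result(jobId=job["id"])
    else:
        client.put_job_failure_result(
            jobId=job["id"],
            failureDetails={"type": "JobFailed", "message": str(error)},
        )
    return True


class RecordingClient:
    """Test double that records calls instead of contacting AWS."""

    def __init__(self):
        self.calls = []

    def put_job_success_result(self, **kwargs):
        self.calls.append(("success", kwargs))

    def put_job_failure_result(self, **kwargs):
        self.calls.append(("failure", kwargs))


client = RecordingClient()
event = {"CodePipeline.job": {"id": "example-job-id"}}
report_to_codepipeline(event, client)
report_to_codepipeline(event, client, error=RuntimeError("no public subnet"))
print(client.calls)
```

An event without the 'CodePipeline.job' key (for example, a manual test invocation) produces no calls at all, which matches the behaviour of the handler above.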

CheckGatewayRequired is implemented by check_gateway_required(). It simply checks for active NAT gateways using list_nat_gateways() and then cycles through them to check when each was last requested. Any that have not been requested in the past 45 minutes are terminated.

def check_gateway_required(event, context):
    gateway_list = list_nat_gateways()

    info = {}
    gw_change_list = []

    if gateway_list is None:
        # Nothing to do.
        print("No Gateway running, nothing to do")
        return
    print("Active gateway detected, checking gateway ages")
    for (gatewayId, state, created, lastRequested) in gateway_list:
        age = datetime.now(created.tzinfo) - created

        # If we have a last requested date, we use that, but if not we fall back
        # to using the age of the gateway
        if lastRequested is not None:
            inactive = datetime.utcnow() - parser.isoparse(lastRequested)
        else:
            inactive = age

        if inactive >= timedelta(minutes=45):
            ec2.delete_nat_gateway(NatGatewayId = gatewayId)
            print("Gateway %s detected as inactive, terminated" % gatewayId)
            gw_change_list.append({'action' : 'deleted', 'gatewayId' : gatewayId, 'age' : ('%s' % age), 'inactive' : ('%s' % inactive)})
        else:
            print("Gateway %s is still active, skipped" % gatewayId)
            gw_change_list.append({'action' : 'skipped', 'gatewayId' : gatewayId, 'age' : ('%s' % age), 'inactive' : ('%s' % inactive)})
    info['nat-changed'] = gw_change_list

    print("SUMMARY:\n%s\n" % json.dumps(info))
    return info
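The keep-or-delete decision at the heart of that loop can be pulled out into a pure function, which makes it easy to try with fixed timestamps. This is a sketch under a few assumptions of my own: the current time is passed in as a parameter rather than read from the clock, the tag value is treated as UTC, and idle_time and should_delete are hypothetical names.

```python
from datetime import datetime, timedelta, timezone

IDLE_THRESHOLD = timedelta(minutes=45)


def idle_time(created, last_requested, now):
    """Return how long the gateway has been idle.

    created and now are timezone-aware; last_requested is the naive-UTC
    tag string, or None if the gateway was never tagged.
    """
    if last_requested is None:
        # Fall back to the gateway's age, as the handler does.
        return now - created
    requested = datetime.fromisoformat(last_requested).replace(tzinfo=timezone.utc)
    return now - requested


def should_delete(created, last_requested, now):
    """True once the gateway has been idle for the threshold or longer."""
    return idle_time(created, last_requested, now) >= IDLE_THRESHOLD


created = datetime(2021, 2, 27, 9, 0, tzinfo=timezone.utc)
now = datetime(2021, 2, 27, 10, 0, tzinfo=timezone.utc)
print(should_delete(created, None, now))                    # no tag: idle 60 minutes
print(should_delete(created, "2021-02-27 09:30:00", now))   # tagged: idle 30 minutes
```

The untagged gateway is judged on its age and is due for deletion, while the one requested half an hour ago is kept.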

Adam Luchjenbroers

I’m a data warehousing engineer in my professional life, but I enjoy dabbling in a variety of pet projects when I get the time.