2025-10-30 11:13:38 +01:00

# -*- coding: utf-8 -*-
# Copyright: (c) 2012, Jeroen Hoekx <jeroen@hoekx.be>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import annotations

DOCUMENTATION = r"""
---
module: wait_for
short_description: Waits for a condition before continuing
description:
  - You can wait for a set amount of time with O(timeout). This is the default behavior if nothing is specified or only O(timeout) is specified.
    This does not produce an error.
  - Waiting for a port to become available is useful when services are not immediately available after their init scripts return,
    which is true of certain Java application servers.
  - It is also useful when starting guests with the M(community.libvirt.virt) module and needing to pause until they are ready.
  - This module can also be used to wait for a regex to match a string in a file.
  - In Ansible 1.6 and later, this module can also be used to wait for a file to be available or
    absent on the filesystem.
  - In Ansible 1.8 and later, this module can also be used to wait for active connections to be closed before continuing, useful if a node
    is being rotated out of a load balancer pool.
  - For Windows targets, use the M(ansible.windows.win_wait_for) module instead.
version_added: "0.7"
options:
  host:
    description:
      - A resolvable hostname or IP address to wait for.
    type: str
    default: 127.0.0.1
  timeout:
    description:
      - Maximum number of seconds to wait. When used with another condition, it will force an error once exceeded.
      - When used without other conditions, it is equivalent to just sleeping.
    type: int
    default: 300
  connect_timeout:
    description:
      - Maximum number of seconds to wait for a connection to happen before closing and retrying.
    type: int
    default: 5
  delay:
    description:
      - Number of seconds to wait before starting to poll.
    type: int
    default: 0
  port:
    description:
      - Port number to poll.
      - O(path) and O(port) are mutually exclusive parameters.
    type: int
  active_connection_states:
    description:
      - The list of TCP connection states which are counted as active connections.
    type: list
    elements: str
    default: [ ESTABLISHED, FIN_WAIT1, FIN_WAIT2, SYN_RECV, SYN_SENT, TIME_WAIT ]
    version_added: "2.3"
  state:
    description:
      - One of V(present), V(started), V(stopped), V(absent), or V(drained).
      - When checking a port, V(started) will ensure the port is open, V(stopped) will check that it is closed,
        and V(drained) will check for active connections.
      - When checking for a file or a search string, V(present) or V(started) will ensure that the file or string is present before continuing,
        and V(absent) will check that the file is absent or removed.
    type: str
    choices: [ absent, drained, present, started, stopped ]
    default: started
  path:
    description:
      - Path to a file on the filesystem that must exist before continuing.
      - O(path) and O(port) are mutually exclusive parameters.
    type: path
    version_added: "1.4"
  search_regex:
    description:
      - Can be used to match a string in either a file or a socket connection.
      - Defaults to a multiline regex.
      - When inspecting a system log file and a static string, remember that Ansible by default logs its own actions there;
        see the notes and examples for information.
    type: str
    version_added: "1.4"
  exclude_hosts:
    description:
      - List of hosts or IPs to ignore when looking for active TCP connections for V(drained) state.
    type: list
    elements: str
    version_added: "1.8"
  sleep:
    description:
      - Number of seconds to sleep between checks.
      - Before Ansible 2.3 this was hardcoded to 1 second.
    type: int
    default: 1
    version_added: "2.3"
  msg:
    description:
      - This overrides the normal error message from a failure to meet the required conditions.
    type: str
    version_added: "2.4"
extends_documentation_fragment: action_common_attributes
attributes:
  check_mode:
    support: none
  diff_mode:
    support: none
  platform:
    platforms: posix
notes:
  - Under some circumstances when using mandatory access control, a path may always be treated as being absent even if it exists, but
    can't be modified or created by the remote user either.
  - When waiting for a path, symbolic links will be followed. Many other modules that manipulate files do not follow symbolic links,
    so operations on the path using other modules may not work exactly as expected.
  - When searching for a static string within a system log file, it is important to account for potential self-matching against log entries
    generated by the Ansible modules. To prevent this, add a regular expression construct into the search string. For example, to match
    a literal string 'this thing', one could use a regular expression like 'this t[h]ing'.
seealso:
  - module: ansible.builtin.wait_for_connection
  - module: ansible.windows.win_wait_for
  - module: community.windows.win_wait_for_process
author:
  - Jeroen Hoekx (@jhoekx)
  - John Jarvis (@jarv)
  - Andrii Radyk (@AnderEnder)
"""

EXAMPLES = r"""
- name: Sleep for 300 seconds and continue with play
  ansible.builtin.wait_for:
    timeout: 300
  delegate_to: localhost

- name: Wait for port 8000 to become open on the host, don't start checking for 10 seconds
  ansible.builtin.wait_for:
    port: 8000
    delay: 10

- name: Wait for port 8000 of any IP to close active connections, don't start checking for 10 seconds
  ansible.builtin.wait_for:
    host: 0.0.0.0
    port: 8000
    delay: 10
    state: drained

- name: Wait for port 8000 of any IP to close active connections, ignoring connections for specified hosts
  ansible.builtin.wait_for:
    host: 0.0.0.0
    port: 8000
    state: drained
    exclude_hosts: 10.2.1.2,10.2.1.3

- name: Wait until the file /tmp/foo is present before continuing
  ansible.builtin.wait_for:
    path: /tmp/foo

- name: Wait until the string "completed" is in the file /tmp/foo before continuing
  ansible.builtin.wait_for:
    path: /tmp/foo
    search_regex: completed

- name: Wait until the string "tomcat up" is in syslog, use regex character set to avoid self match
  ansible.builtin.wait_for:
    path: /var/log/syslog
    search_regex: 'tomcat [u]p'

- name: Wait until regex pattern matches in the file /tmp/foo and print the matched group
  ansible.builtin.wait_for:
    path: /tmp/foo
    search_regex: completed (?P<task>\w+)
  register: waitfor

- ansible.builtin.debug:
    msg: Completed {{ waitfor['match_groupdict']['task'] }}

- name: Wait until the lock file is removed
  ansible.builtin.wait_for:
    path: /var/lock/file.lock
    state: absent

- name: Wait until the process is finished and pid was destroyed
  ansible.builtin.wait_for:
    path: /proc/3466/status
    state: absent

- name: Output customized message when failed
  ansible.builtin.wait_for:
    path: /tmp/foo
    state: present
    msg: Timeout to find file /tmp/foo

# Do not assume the inventory_hostname is resolvable and delay 10 seconds at start
- name: Wait 300 seconds for port 22 to become open and contain "OpenSSH"
  ansible.builtin.wait_for:
    port: 22
    host: '{{ (ansible_ssh_host|default(ansible_host))|default(inventory_hostname) }}'
    search_regex: OpenSSH
    delay: 10
    timeout: 300
  delegate_to: localhost

# Same as above but using config lookup for the target,
# most plugins use 'remote_addr', but ssh uses 'host'
- name: Wait 300 seconds for port 22 to become open and contain "OpenSSH"
  ansible.builtin.wait_for:
    port: 22
    host: "{{ lookup('config', 'host', plugin_name='ssh', plugin_type='connection') }}"
    search_regex: OpenSSH
    delay: 10
    timeout: 300
  delegate_to: localhost
"""

RETURN = r"""
elapsed:
  description: The number of seconds that elapsed while waiting
  returned: always
  type: int
  sample: 23
match_groups:
  description: Tuple containing all the subgroups of the match as returned by U(https://docs.python.org/3/library/re.html#re.MatchObject.groups)
  returned: always
  type: list
  sample: ['match 1', 'match 2']
match_groupdict:
  description: Dictionary containing all the named subgroups of the match, keyed by the subgroup name,
    as returned by U(https://docs.python.org/3/library/re.html#re.MatchObject.groupdict)
  returned: always
  type: dict
  sample:
    {
      'group': 'match'
    }
"""

import binascii
import contextlib
import errno
import math
import mmap
import os
import re
import select
import socket
import time

from datetime import datetime, timedelta, timezone

from ansible.module_utils.basic import AnsibleModule, missing_required_lib
from ansible.module_utils.common.sys_info import get_platform_subclass
from ansible.module_utils.common.text.converters import to_bytes, to_native

HAS_PSUTIL = False
PSUTIL_IMP_ERR = None
try:
    import psutil
    HAS_PSUTIL = True
    # just because we can import it on Linux doesn't mean we will use it
except ImportError as ex:
    PSUTIL_IMP_ERR = ex


class TCPConnectionInfo(object):
    """
    This is a generic TCP Connection Info strategy class that relies
    on the psutil module, which is not ideal for targets, but necessary
    for cross platform support.

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).
    """
    platform = 'Generic'
    distribution = None

    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }
    ipv4_mapped_ipv6_address = {
        'prefix': '::ffff',
        'match_all': '::ffff:0.0.0.0'
    }

    def __new__(cls, *args, **kwargs):
        new_cls = get_platform_subclass(TCPConnectionInfo)
        return super(cls, new_cls).__new__(new_cls)

    def __init__(self, module):
        self.module = module
        self.ips = _convert_host_to_ip(module.params['host'])
        self.port = int(self.module.params['port'])
        self.exclude_ips = self._get_exclude_ips()
        if not HAS_PSUTIL:
            module.fail_json(msg=missing_required_lib('psutil'), exception=PSUTIL_IMP_ERR)

    def _get_exclude_ips(self):
        exclude_hosts = self.module.params['exclude_hosts']
        exclude_ips = []
        if exclude_hosts is not None:
            for host in exclude_hosts:
                exclude_ips.extend(_convert_host_to_ip(host))
        return exclude_ips

    def get_active_connections_count(self):
        active_connections = 0
        for p in psutil.process_iter():
            try:
                # Older psutil releases expose get_connections() instead of connections()
                if hasattr(p, 'get_connections'):
                    connections = p.get_connections(kind='inet')
                else:
                    connections = p.connections(kind='inet')
            except psutil.Error:
                # Process is Zombie or other error state
                continue
            for conn in connections:
                if conn.status not in self.module.params['active_connection_states']:
                    continue
                if hasattr(conn, 'local_address'):
                    (local_ip, local_port) = conn.local_address
                else:
                    (local_ip, local_port) = conn.laddr
                if self.port != local_port:
                    continue
                if hasattr(conn, 'remote_address'):
                    (remote_ip, remote_port) = conn.remote_address
                else:
                    (remote_ip, remote_port) = conn.raddr
                if (conn.family, remote_ip) in self.exclude_ips:
                    continue
                if any((
                    (conn.family, local_ip) in self.ips,
                    (conn.family, self.match_all_ips[conn.family]) in self.ips,
                    local_ip.startswith(self.ipv4_mapped_ipv6_address['prefix']) and
                        (conn.family, self.ipv4_mapped_ipv6_address['match_all']) in self.ips,
                )):
                    active_connections += 1
        return active_connections


# ===========================================
# Subclass: Linux

class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    it allows Linux targets to avoid requiring an additional library.
    """
    platform = 'Linux'
    distribution = None

    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    ipv4_mapped_ipv6_address = {
        'prefix': '0000000000000000FFFF0000',
        'match_all': '0000000000000000FFFF000000000000'
    }
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        self.module = module
        self.ips = _convert_host_to_hex(module.params['host'])
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        exclude_hosts = self.module.params['exclude_hosts']
        exclude_ips = []
        if exclude_hosts is not None:
            for host in exclude_hosts:
                exclude_ips.extend(_convert_host_to_hex(host))
        return exclude_ips

    def get_active_connections_count(self):
        active_connections = 0
        for family in self.source_file.keys():
            if not os.path.isfile(self.source_file[family]):
                continue
            try:
                with open(self.source_file[family]) as f:
                    for tcp_connection in f.readlines():
                        tcp_connection = tcp_connection.strip().split()
                        # Skip the header row
                        if tcp_connection[self.local_address_field] == 'local_address':
                            continue
                        if (tcp_connection[self.connection_state_field] not in
                                [get_connection_state_id(_connection_state) for _connection_state in self.module.params['active_connection_states']]):
                            continue
                        (local_ip, local_port) = tcp_connection[self.local_address_field].split(':')
                        if self.port != local_port:
                            continue
                        (remote_ip, remote_port) = tcp_connection[self.remote_address_field].split(':')
                        if (family, remote_ip) in self.exclude_ips:
                            continue
                        if any((
                            (family, local_ip) in self.ips,
                            (family, self.match_all_ips[family]) in self.ips,
                            local_ip.startswith(self.ipv4_mapped_ipv6_address['prefix']) and
                                (family, self.ipv4_mapped_ipv6_address['match_all']) in self.ips,
                        )):
                            active_connections += 1
            except OSError:
                pass
        return active_connections
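

# Illustrative sketch only; nothing in this module calls it. It shows, under
# the assumption that each row of /proc/net/tcp is whitespace separated with
# the local address, remote address and state in fields 1, 2 and 3 (the same
# offsets LinuxTCPConnectionInfo uses), how one row could be split into the
# values that get_active_connections_count() compares. The function name
# _example_parse_proc_net_tcp_row is hypothetical.
def _example_parse_proc_net_tcp_row(row):
    fields = row.strip().split()
    # Addresses are written as "HEXIP:HEXPORT"; the state is a two digit hex code.
    local_ip, local_port = fields[1].split(':')
    remote_ip, remote_port = fields[2].split(':')
    return {
        'local_ip': local_ip,
        'local_port': int(local_port, 16),
        'remote_ip': remote_ip,
        'remote_port': int(remote_port, 16),
        'state': fields[3],
    }


# Usage note: for a made-up row starting "0: 0100007F:1F40 0200007F:D431 01",
# the sketch would report local port 8000 and state '01', which
# get_connection_state_id() maps to ESTABLISHED.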


def _convert_host_to_ip(host):
    """
    Perform forward DNS resolution on host; an IP address resolves to itself

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        List of tuples containing address family and IP
    """
    addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)
    ips = []
    for family, socktype, proto, canonname, sockaddr in addrinfo:
        ip = sockaddr[0]
        ips.append((family, ip))
        # Also record the IPv4-mapped IPv6 form of every IPv4 address
        if family == socket.AF_INET:
            ips.append((socket.AF_INET6, "::ffff:" + ip))
    return ips


def _convert_host_to_hex(host):
    """
    Convert the provided host to the format in /proc/net/tcp*

    /proc/net/tcp uses little-endian four byte hex for ipv4
    /proc/net/tcp6 uses little-endian per 4B word for ipv6

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        List of tuples containing address family and the
        little-endian converted host
    """
    ips = []
    if host is not None:
        for family, ip in _convert_host_to_ip(host):
            hexip_nf = binascii.b2a_hex(socket.inet_pton(family, ip))
            hexip_hf = ""
            for i in range(0, len(hexip_nf), 8):
                ipgroup_nf = hexip_nf[i:i + 8]
                ipgroup_hf = socket.ntohl(int(ipgroup_nf, base=16))
                hexip_hf = "%s%08X" % (hexip_hf, ipgroup_hf)
            ips.append((family, hexip_hf))
    return ips
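

# Illustrative sketch only; nothing in this module calls it. It spells out the
# byte order described in the _convert_host_to_hex docstring by reversing every
# 4-byte word of the packed address. It assumes a little-endian host, whereas
# _convert_host_to_hex relies on socket.ntohl and so follows the byte order of
# the host it runs on. The function name _example_ip_to_proc_hex is
# hypothetical; the imports are repeated so the sketch stands alone.
import binascii
import socket


def _example_ip_to_proc_hex(family, ip):
    # Network byte order, e.g. b'\x7f\x00\x00\x01' for 127.0.0.1
    packed = socket.inet_pton(family, ip)
    # Reverse each 4-byte word independently (one word for IPv4, four for IPv6)
    words = [packed[i:i + 4][::-1] for i in range(0, len(packed), 4)]
    return binascii.b2a_hex(b''.join(words)).decode('ascii').upper()


# Usage note: under the little-endian assumption, 127.0.0.1 becomes '0100007F',
# the form seen in the local_address column of /proc/net/tcp.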


def _timedelta_total_seconds(timedelta):
    return (
        timedelta.microseconds + 0.0 +
        (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6) / 10 ** 6
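

# Illustrative sketch only; nothing in this module calls it.
# _timedelta_total_seconds computes the same value as
# timedelta.total_seconds(). main() rounds that value up and uses it to cap
# connect_timeout, so a single connection attempt cannot outlast the overall
# deadline. The function name _example_capped_connect_timeout and its
# parameters are hypothetical; the imports are repeated so the sketch stands
# alone.
import math
from datetime import datetime, timedelta, timezone


def _example_capped_connect_timeout(end, now, connect_timeout):
    # Whole seconds left until the deadline, rounded up
    remaining = math.ceil((end - now).total_seconds())
    return min(connect_timeout, remaining)


# Usage note: with 2.3 seconds left and connect_timeout=5 the attempt would be
# capped at 3 seconds; with a minute left the full 5 seconds would apply.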


def get_connection_state_id(state):
    connection_state_id = {
        'ESTABLISHED': '01',
        'SYN_SENT': '02',
        'SYN_RECV': '03',
        'FIN_WAIT1': '04',
        'FIN_WAIT2': '05',
        'TIME_WAIT': '06',
    }
    return connection_state_id[state]


def main():
    module = AnsibleModule(
        argument_spec=dict(
            host=dict(type='str', default='127.0.0.1'),
            timeout=dict(type='int', default=300),
            connect_timeout=dict(type='int', default=5),
            delay=dict(type='int', default=0),
            port=dict(type='int'),
            active_connection_states=dict(type='list', elements='str', default=['ESTABLISHED', 'FIN_WAIT1', 'FIN_WAIT2', 'SYN_RECV', 'SYN_SENT', 'TIME_WAIT']),
            path=dict(type='path'),
            search_regex=dict(type='str'),
            state=dict(type='str', default='started', choices=['absent', 'drained', 'present', 'started', 'stopped']),
            exclude_hosts=dict(type='list', elements='str'),
            sleep=dict(type='int', default=1),
            msg=dict(type='str'),
        ),
    )

    host = module.params['host']
    timeout = module.params['timeout']
    connect_timeout = module.params['connect_timeout']
    delay = module.params['delay']
    port = module.params['port']
    state = module.params['state']

    path = module.params['path']
    b_path = to_bytes(path, errors='surrogate_or_strict', nonstring='passthru')

    search_regex = module.params['search_regex']
    b_search_regex = to_bytes(search_regex, errors='surrogate_or_strict', nonstring='passthru')

    msg = module.params['msg']

    if search_regex is not None:
        try:
            b_compiled_search_re = re.compile(b_search_regex, re.MULTILINE)
        except re.error as e:
            module.fail_json(msg="Invalid regular expression: %s" % e)
    else:
        b_compiled_search_re = None

    match_groupdict = {}
    match_groups = ()

    if port and path:
        module.fail_json(msg="port and path parameter can not both be passed to wait_for", elapsed=0)
    if path and state == 'stopped':
        module.fail_json(msg="state=stopped should only be used for checking a port in the wait_for module", elapsed=0)
    if path and state == 'drained':
        module.fail_json(msg="state=drained should only be used for checking a port in the wait_for module", elapsed=0)
    if module.params['exclude_hosts'] is not None and state != 'drained':
        module.fail_json(msg="exclude_hosts should only be used with state=drained", elapsed=0)
    for _connection_state in module.params['active_connection_states']:
        try:
            get_connection_state_id(_connection_state)
        except Exception:
            module.fail_json(msg="unknown active_connection_state (%s) defined" % _connection_state, elapsed=0)

    start = datetime.now(timezone.utc)

    if delay:
        time.sleep(delay)

    if not port and not path and state != 'drained':
        time.sleep(timeout)
    elif state in ['absent', 'stopped']:
        # first wait for the stop condition
        end = start + timedelta(seconds=timeout)

        while datetime.now(timezone.utc) < end:
            if path:
                try:
                    if not os.access(b_path, os.F_OK):
                        break
                except OSError:
                    break
            elif port:
                try:
                    s = socket.create_connection((host, port), connect_timeout)
                    s.shutdown(socket.SHUT_RDWR)
                    s.close()
                except Exception:
                    break
            # Conditions not yet met, wait and try again
            time.sleep(module.params['sleep'])
        else:
            elapsed = datetime.now(timezone.utc) - start
            if port:
                module.fail_json(msg=msg or "Timeout when waiting for %s:%s to stop." % (host, port), elapsed=elapsed.seconds)
            elif path:
                module.fail_json(msg=msg or "Timeout when waiting for %s to be absent." % (path), elapsed=elapsed.seconds)

    elif state in ['started', 'present']:
        # wait for start condition
        end = start + timedelta(seconds=timeout)
        while datetime.now(timezone.utc) < end:
            if path:
                try:
                    os.stat(b_path)
                except OSError as e:
                    # If anything except file not present, throw an error
                    if e.errno != errno.ENOENT:
                        elapsed = datetime.now(timezone.utc) - start
                        module.fail_json(msg=msg or "Failed to stat %s, %s" % (path, e.strerror), elapsed=elapsed.seconds)
                    # file doesn't exist yet, so continue
                else:
                    # File exists. Are there additional things to check?
                    if not b_compiled_search_re:
                        # nope, succeed!
                        break

                    try:
                        with open(b_path, 'rb') as f:
                            try:
                                with contextlib.closing(mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)) as mm:
                                    search = b_compiled_search_re.search(mm)
                                    if search:
                                        if search.groupdict():
                                            match_groupdict = search.groupdict()
                                        if search.groups():
                                            match_groups = search.groups()
                                        break
                            except (ValueError, OSError) as e:
                                module.debug('wait_for failed to use mmap on "%s": %s. Falling back to file read().' % (path, to_native(e)))
                                # cannot mmap this file, try normal read
                                search = re.search(b_compiled_search_re, f.read())
                                if search:
                                    if search.groupdict():
                                        match_groupdict = search.groupdict()
                                    if search.groups():
                                        match_groups = search.groups()
                                    break
                            except Exception as e:
                                module.warn('wait_for failed on "%s", unexpected exception (%s): %s.' % (path, to_native(e.__class__), to_native(e)))
                    except OSError:
                        pass
            elif port:
                # Never let a single connection attempt outlast the overall deadline
                alt_connect_timeout = math.ceil(
                    _timedelta_total_seconds(end - datetime.now(timezone.utc)),
                )
                try:
                    s = socket.create_connection((host, int(port)), min(connect_timeout, alt_connect_timeout))
                except Exception:
                    # Failed to connect by connect_timeout. wait and try again
                    pass
                else:
                    # Connected -- are there additional conditions?
                    if b_compiled_search_re:
                        b_data = b''
                        matched = False
                        while datetime.now(timezone.utc) < end:
                            max_timeout = math.ceil(
                                _timedelta_total_seconds(
                                    end - datetime.now(timezone.utc),
                                ),
                            )
                            readable = select.select([s], [], [], max_timeout)[0]
                            if not readable:
                                # No new data. Probably means our timeout
                                # expired
                                continue
                            response = s.recv(1024)
                            if not response:
                                # Server shutdown
                                break
                            b_data += response
                            if b_compiled_search_re.search(b_data):
                                matched = True
                                break

                        # Shutdown the client socket
                        try:
                            s.shutdown(socket.SHUT_RDWR)
                        except OSError as ex:
                            if ex.errno != errno.ENOTCONN:
                                raise
                        # else, the server broke the connection on its end, assume it's not ready
                        else:
                            s.close()
                        if matched:
                            # Found our string, success!
                            break
                    else:
                        # Connection established, success!
                        try:
                            s.shutdown(socket.SHUT_RDWR)
                        except OSError as ex:
                            if ex.errno != errno.ENOTCONN:
                                raise
                        # else, the server broke the connection on its end, assume it's not ready
                        else:
                            s.close()
                        break

            # Conditions not yet met, wait and try again
            time.sleep(module.params['sleep'])

        else:   # while-else
            # Timeout expired
            elapsed = datetime.now(timezone.utc) - start
            if port:
                if search_regex:
                    module.fail_json(msg=msg or "Timeout when waiting for search string %s in %s:%s" % (search_regex, host, port), elapsed=elapsed.seconds)
                else:
                    module.fail_json(msg=msg or "Timeout when waiting for %s:%s" % (host, port), elapsed=elapsed.seconds)
            elif path:
                if search_regex:
                    module.fail_json(msg=msg or "Timeout when waiting for search string %s in %s" % (search_regex, path), elapsed=elapsed.seconds)
                else:
                    module.fail_json(msg=msg or "Timeout when waiting for file %s" % (path), elapsed=elapsed.seconds)

    elif state == 'drained':
        # wait until all active connections are gone
        end = start + timedelta(seconds=timeout)
        tcpconns = TCPConnectionInfo(module)
        while datetime.now(timezone.utc) < end:
            if tcpconns.get_active_connections_count() == 0:
                break

            # Conditions not yet met, wait and try again
            time.sleep(module.params['sleep'])
        else:
            elapsed = datetime.now(timezone.utc) - start
            module.fail_json(msg=msg or "Timeout when waiting for %s:%s to drain" % (host, port), elapsed=elapsed.seconds)

    elapsed = datetime.now(timezone.utc) - start
    module.exit_json(state=state, port=port, search_regex=search_regex, match_groups=match_groups, match_groupdict=match_groupdict, path=path,
                     elapsed=elapsed.seconds)


if __name__ == '__main__':
    main()