# pyre-ignore-all-errors[14, 15, 16]
from __future__ import annotations
import random
from collections.abc import Callable, Generator, Iterable
from itertools import chain
from typing import TYPE_CHECKING, Any
from sc2.ids.unit_typeid import UnitTypeId
from sc2.position import Point2
from sc2.unit import Unit
if TYPE_CHECKING:
from sc2.bot_ai import BotAI
[docs]
class Units(list):
"""A collection of Unit objects. Makes it easy to select units by selectors."""
@classmethod
def from_proto(cls, units, bot_object: BotAI) -> Units:
return cls((Unit(raw_unit, bot_object=bot_object) for raw_unit in units), bot_object)
def __init__(self, units: Iterable[Unit], bot_object: BotAI) -> None:
"""
:param units:
:param bot_object:
"""
super().__init__(units)
self._bot_object = bot_object
def __call__(self, unit_types: UnitTypeId | Iterable[UnitTypeId]) -> Units:
"""Creates a new mutable Units object from Units or list object.
:param unit_types:
"""
return self.of_type(unit_types)
def __iter__(self) -> Generator[Unit, None, None]:
return (item for item in super().__iter__())
[docs]
def copy(self) -> Units:
"""Creates a new mutable Units object from Units or list object.
:param units:
"""
return Units(self, self._bot_object)
def __or__(self, other: Units) -> Units:
"""
:param other:
"""
return Units(
chain(
iter(self),
(other_unit for other_unit in other if other_unit.tag not in (self_unit.tag for self_unit in self)),
),
self._bot_object,
)
def __add__(self, other: Units) -> Units:
"""
:param other:
"""
return Units(
chain(
iter(self),
(other_unit for other_unit in other if other_unit.tag not in (self_unit.tag for self_unit in self)),
),
self._bot_object,
)
def __and__(self, other: Units) -> Units:
"""
:param other:
"""
return Units(
(other_unit for other_unit in other if other_unit.tag in (self_unit.tag for self_unit in self)),
self._bot_object,
)
def __sub__(self, other: Units) -> Units:
"""
:param other:
"""
return Units(
(self_unit for self_unit in self if self_unit.tag not in (other_unit.tag for other_unit in other)),
self._bot_object,
)
def __hash__(self) -> int:
return hash(unit.tag for unit in self)
@property
def amount(self) -> int:
return len(self)
@property
def empty(self) -> bool:
return not bool(self)
@property
def exists(self) -> bool:
return bool(self)
[docs]
def find_by_tag(self, tag: int) -> Unit | None:
"""
:param tag:
"""
for unit in self:
if unit.tag == tag:
return unit
return None
[docs]
def by_tag(self, tag: int) -> Unit:
"""
:param tag:
"""
unit = self.find_by_tag(tag)
if unit is None:
raise KeyError("Unit not found")
return unit
@property
def first(self) -> Unit:
assert self, "Units object is empty"
return self[0]
[docs]
def take(self, n: int) -> Units:
"""
:param n:
"""
if n >= self.amount:
return self
return self.subgroup(self[:n])
@property
def random(self) -> Unit:
assert self, "Units object is empty"
return random.choice(self)
def random_or(self, other: Any) -> Unit:
return random.choice(self) if self else other
[docs]
def random_group_of(self, n: int) -> Units:
"""Returns self if n >= self.amount."""
if n < 1:
return Units([], self._bot_object)
if n >= self.amount:
return self
return self.subgroup(random.sample(self, n))
[docs]
def in_attack_range_of(self, unit: Unit, bonus_distance: float = 0) -> Units:
"""Filters units that are in attack range of the given unit.
This uses the unit and target unit.radius when calculating the distance, so it should be accurate.
Caution: This may not work well for static structures (bunker, sieged tank, planetary fortress, photon cannon, spine and spore crawler) because it seems attack ranges differ for static / immovable units.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
all_zerglings_my_marine_can_attack = enemy_zerglings.in_attack_range_of(my_marine)
Example::
enemy_mutalisks = self.enemy_units(UnitTypeId.MUTALISK)
my_marauder = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARAUDER), None)
if my_marauder:
all_mutalisks_my_marauder_can_attack = enemy_mutaliskss.in_attack_range_of(my_marauder)
# Is empty because mutalisk are flying and marauder cannot attack air
:param unit:
:param bonus_distance:
"""
return self.filter(lambda x: unit.target_in_range(x, bonus_distance=bonus_distance))
[docs]
def closest_distance_to(self, position: Unit | Point2) -> float:
"""Returns the distance between the closest unit from this group to the target unit.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
closest_zergling_distance = enemy_zerglings.closest_distance_to(my_marine)
# Contains the distance between the marine and the closest zergling
:param position:
"""
assert self, "Units object is empty"
if isinstance(position, Unit):
return min(self._bot_object._distance_squared_unit_to_unit(unit, position) for unit in self) ** 0.5
return min(self._bot_object._distance_units_to_pos(self, position))
[docs]
def furthest_distance_to(self, position: Unit | Point2) -> float:
"""Returns the distance between the furthest unit from this group to the target unit
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
furthest_zergling_distance = enemy_zerglings.furthest_distance_to(my_marine)
# Contains the distance between the marine and the furthest away zergling
:param position:
"""
assert self, "Units object is empty"
if isinstance(position, Unit):
return max(self._bot_object._distance_squared_unit_to_unit(unit, position) for unit in self) ** 0.5
return max(self._bot_object._distance_units_to_pos(self, position))
[docs]
def closest_to(self, position: Unit | Point2) -> Unit:
"""Returns the closest unit (from this Units object) to the target unit or position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
closest_zergling = enemy_zerglings.closest_to(my_marine)
# Contains the zergling that is closest to the target marine
:param position:
"""
assert self, "Units object is empty"
if isinstance(position, Unit):
return min(
(unit1 for unit1 in self),
key=lambda unit2: self._bot_object._distance_squared_unit_to_unit(unit2, position),
)
distances = self._bot_object._distance_units_to_pos(self, position)
return min(((unit, dist) for unit, dist in zip(self, distances)), key=lambda my_tuple: my_tuple[1])[0]
[docs]
def furthest_to(self, position: Unit | Point2) -> Unit:
"""Returns the furhest unit (from this Units object) to the target unit or position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
furthest_zergling = enemy_zerglings.furthest_to(my_marine)
# Contains the zergling that is furthest away to the target marine
:param position:
"""
assert self, "Units object is empty"
if isinstance(position, Unit):
return max(
(unit1 for unit1 in self),
key=lambda unit2: self._bot_object._distance_squared_unit_to_unit(unit2, position),
)
distances = self._bot_object._distance_units_to_pos(self, position)
return max(((unit, dist) for unit, dist in zip(self, distances)), key=lambda my_tuple: my_tuple[1])[0]
[docs]
def closer_than(self, distance: float, position: Unit | Point2) -> Units:
"""Returns all units (from this Units object) that are closer than 'distance' away from target unit or position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
close_zerglings = enemy_zerglings.closer_than(3, my_marine)
# Contains all zerglings that are distance 3 or less away from the marine (does not include unit radius in calculation)
:param distance:
:param position:
"""
if not self:
return self
if isinstance(position, Unit):
distance_squared = distance**2
return self.subgroup(
unit
for unit in self
if self._bot_object._distance_squared_unit_to_unit(unit, position) < distance_squared
)
distances = self._bot_object._distance_units_to_pos(self, position)
return self.subgroup(unit for unit, dist in zip(self, distances) if dist < distance)
[docs]
def further_than(self, distance: float, position: Unit | Point2) -> Units:
"""Returns all units (from this Units object) that are further than 'distance' away from target unit or position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
far_zerglings = enemy_zerglings.further_than(3, my_marine)
# Contains all zerglings that are distance 3 or more away from the marine (does not include unit radius in calculation)
:param distance:
:param position:
"""
if not self:
return self
if isinstance(position, Unit):
distance_squared = distance**2
return self.subgroup(
unit
for unit in self
if distance_squared < self._bot_object._distance_squared_unit_to_unit(unit, position)
)
distances = self._bot_object._distance_units_to_pos(self, position)
return self.subgroup(unit for unit, dist in zip(self, distances) if distance < dist)
[docs]
def in_distance_between(
self, position: Unit | Point2 | tuple[float, float], distance1: float, distance2: float
) -> Units:
"""Returns units that are further than distance1 and closer than distance2 to unit or position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
zerglings_filtered = enemy_zerglings.in_distance_between(my_marine, 3, 5)
# Contains all zerglings that are between distance 3 and 5 away from the marine (does not include unit radius in calculation)
:param position:
:param distance1:
:param distance2:
"""
if not self:
return self
if isinstance(position, Unit):
distance1_squared = distance1**2
distance2_squared = distance2**2
return self.subgroup(
unit
for unit in self
if distance1_squared
< self._bot_object._distance_squared_unit_to_unit(unit, position)
< distance2_squared
)
distances = self._bot_object._distance_units_to_pos(self, position)
return self.subgroup(unit for unit, dist in zip(self, distances) if distance1 < dist < distance2)
[docs]
def closest_n_units(self, position: Unit | Point2, n: int) -> Units:
"""Returns the n closest units in distance to position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
zerglings_filtered = enemy_zerglings.closest_n_units(my_marine, 5)
# Contains 5 zerglings that are the closest to the marine
:param position:
:param n:
"""
if not self:
return self
return self.subgroup(self._list_sorted_by_distance_to(position)[:n])
[docs]
def furthest_n_units(self, position: Unit | Point2, n: int) -> Units:
"""Returns the n furhest units in distance to position.
Example::
enemy_zerglings = self.enemy_units(UnitTypeId.ZERGLING)
my_marine = next((unit for unit in self.units if unit.type_id == UnitTypeId.MARINE), None)
if my_marine:
zerglings_filtered = enemy_zerglings.furthest_n_units(my_marine, 5)
# Contains 5 zerglings that are the furthest to the marine
:param position:
:param n:
"""
if not self:
return self
return self.subgroup(self._list_sorted_by_distance_to(position)[-n:])
[docs]
def in_distance_of_group(self, other_units: Units, distance: float) -> Units:
"""Returns units that are closer than distance from any unit in the other units object.
:param other_units:
:param distance:
"""
assert other_units, "Other units object is empty"
# Return self because there are no enemies
if not self:
return self
distance_squared = distance**2
if len(self) == 1:
if any(
self._bot_object._distance_squared_unit_to_unit(self[0], target) < distance_squared
for target in other_units
):
return self
return self.subgroup([])
return self.subgroup(
self_unit
for self_unit in self
if any(
self._bot_object._distance_squared_unit_to_unit(self_unit, other_unit) < distance_squared
for other_unit in other_units
)
)
[docs]
def in_closest_distance_to_group(self, other_units: Units) -> Unit:
"""Returns unit in shortest distance from any unit in self to any unit in group.
Loops over all units in self, then loops over all units in other_units and calculates the shortest distance. Returns the units that is closest to any unit of 'other_units'.
:param other_units:
"""
assert self, "Units object is empty"
assert other_units, "Given units object is empty"
return min(
self,
key=lambda self_unit: min(
self._bot_object._distance_squared_unit_to_unit(self_unit, other_unit) for other_unit in other_units
),
)
def _list_sorted_closest_to_distance(self, position: Unit | Point2, distance: float) -> list[Unit]:
"""This function should be a bit faster than using units.sorted(key=lambda u: u.distance_to(position))
:param position:
:param distance:
"""
if isinstance(position, Unit):
return sorted(
self,
key=lambda unit: abs(self._bot_object._distance_squared_unit_to_unit(unit, position) - distance),
reverse=True,
)
distances = self._bot_object._distance_units_to_pos(self, position)
unit_dist_dict = {unit.tag: dist for unit, dist in zip(self, distances)}
return sorted(self, key=lambda unit2: abs(unit_dist_dict[unit2.tag] - distance), reverse=True)
[docs]
def n_closest_to_distance(self, position: Point2, distance: float, n: int) -> Units:
"""Returns n units that are the closest to distance away.
For example if the distance is set to 5 and you want 3 units, from units with distance [3, 4, 5, 6, 7] to position,
the units with distance [4, 5, 6] will be returned
:param position:
:param distance:
"""
return self.subgroup(self._list_sorted_closest_to_distance(position=position, distance=distance)[:n])
[docs]
def n_furthest_to_distance(self, position: Point2, distance: float, n: int) -> Units:
"""Inverse of the function 'n_closest_to_distance', returns the furthest units instead
:param position:
:param distance:
"""
return self.subgroup(self._list_sorted_closest_to_distance(position=position, distance=distance)[-n:])
[docs]
def subgroup(self, units: Iterable[Unit]) -> Units:
"""Creates a new mutable Units object from Units or list object.
:param units:
"""
return Units(units, self._bot_object)
[docs]
def filter(self, pred: Callable[[Unit], Any]) -> Units:
"""Filters the current Units object and returns a new Units object.
Example::
from sc2.ids.unit_typeid import UnitTypeId
my_marines = self.units.filter(lambda unit: unit.type_id == UnitTypeId.MARINE)
completed_structures = self.structures.filter(lambda structure: structure.is_ready)
queens_with_energy_to_inject = self.units.filter(lambda unit: unit.type_id == UnitTypeId.QUEEN and unit.energy >= 25)
orbitals_with_energy_to_mule = self.structures.filter(lambda structure: structure.type_id == UnitTypeId.ORBITALCOMMAND and structure.energy >= 50)
my_units_that_can_shoot_up = self.units.filter(lambda unit: unit.can_attack_air)
See more unit properties in unit.py
:param pred:
"""
assert callable(pred), "Function is not callable"
return self.subgroup(filter(pred, self))
def sorted(self, key: Callable[[Unit], Any], reverse: bool = False) -> Units:
return self.subgroup(sorted(self, key=key, reverse=reverse))
def _list_sorted_by_distance_to(self, position: Unit | Point2, reverse: bool = False) -> list[Unit]:
"""This function should be a bit faster than using units.sorted(key=lambda u: u.distance_to(position))
:param position:
:param reverse:
"""
if isinstance(position, Unit):
return sorted(
self, key=lambda unit: self._bot_object._distance_squared_unit_to_unit(unit, position), reverse=reverse
)
distances = self._bot_object._distance_units_to_pos(self, position)
unit_dist_dict = {unit.tag: dist for unit, dist in zip(self, distances)}
return sorted(self, key=lambda unit2: unit_dist_dict[unit2.tag], reverse=reverse)
[docs]
def sorted_by_distance_to(self, position: Unit | Point2, reverse: bool = False) -> Units:
"""This function should be a bit faster than using units.sorted(key=lambda u: u.distance_to(position))
:param position:
:param reverse:
"""
return self.subgroup(self._list_sorted_by_distance_to(position, reverse=reverse))
[docs]
def of_type(self, other: UnitTypeId | Iterable[UnitTypeId]) -> Units:
"""Filters all units that are of a specific type
Example::
# Use a set instead of lists in the argument
some_attack_units = self.units.of_type({ZERGLING, ROACH, HYDRALISK, BROODLORD})
:param other:
"""
if isinstance(other, UnitTypeId):
other = {other}
elif isinstance(other, list):
other = set(other)
return self.filter(lambda unit: unit.type_id in other)
[docs]
def exclude_type(self, other: UnitTypeId | Iterable[UnitTypeId]) -> Units:
"""Filters all units that are not of a specific type
Example::
# Use a set instead of lists in the argument
ignore_units = self.enemy_units.exclude_type({LARVA, EGG, OVERLORD})
:param other:
"""
if isinstance(other, UnitTypeId):
other = {other}
elif isinstance(other, list):
other = set(other)
return self.filter(lambda unit: unit.type_id not in other)
[docs]
def same_tech(self, other: set[UnitTypeId]) -> Units:
"""Returns all structures that have the same base structure.
Untested: This should return the equivalents for WarpPrism, Observer, Overseer, SupplyDepot and others
Example::
# All command centers, flying command centers, orbital commands, flying orbital commands, planetary fortress
terran_townhalls = self.townhalls.same_tech(UnitTypeId.COMMANDCENTER)
# All hatcheries, lairs and hives
zerg_townhalls = self.townhalls.same_tech({UnitTypeId.HATCHERY})
# All spires and greater spires
spires = self.townhalls.same_tech({UnitTypeId.SPIRE})
# The following returns the same
spires = self.townhalls.same_tech({UnitTypeId.GREATERSPIRE})
# This also works with multiple unit types
zerg_townhalls_and_spires = self.structures.same_tech({UnitTypeId.HATCHERY, UnitTypeId.SPIRE})
:param other:
"""
assert isinstance(other, set), (
"Please use a set as this filter function is already fairly slow. For example"
+ " 'self.units.same_tech({UnitTypeId.LAIR})'"
)
tech_alias_types: set[int] = {u.value for u in other}
unit_data = self._bot_object.game_data.units
for unit_type in other:
for same in unit_data[unit_type.value]._proto.tech_alias:
tech_alias_types.add(same)
return self.filter(
lambda unit: unit._proto.unit_type in tech_alias_types
or any(same in tech_alias_types for same in unit._type_data._proto.tech_alias)
)
[docs]
def same_unit(self, other: UnitTypeId | Iterable[UnitTypeId]) -> Units:
"""Returns all units that have the same base unit while being in different modes.
Untested: This should return the equivalents for WarpPrism, Observer, Overseer, SupplyDepot and other units that have different modes but still act as the same unit
Example::
# All command centers on the ground and flying
ccs = self.townhalls.same_unit(UnitTypeId.COMMANDCENTER)
# All orbital commands on the ground and flying
ocs = self.townhalls.same_unit(UnitTypeId.ORBITALCOMMAND)
# All roaches and burrowed roaches
roaches = self.units.same_unit(UnitTypeId.ROACH)
# This is useful because roach has a different type id when burrowed
burrowed_roaches = self.units(UnitTypeId.ROACHBURROWED)
:param other:
"""
if isinstance(other, UnitTypeId):
other = {other}
unit_alias_types: set[int] = {u.value for u in other}
unit_data = self._bot_object.game_data.units
for unit_type in other:
unit_alias_types.add(unit_data[unit_type.value]._proto.unit_alias)
unit_alias_types.discard(0)
return self.filter(
lambda unit: unit._proto.unit_type in unit_alias_types
or unit._type_data._proto.unit_alias in unit_alias_types
)
@property
def center(self) -> Point2:
"""Returns the central position of all units."""
assert self, "Units object is empty"
return Point2(
(
sum(unit._proto.pos.x for unit in self) / self.amount,
sum(unit._proto.pos.y for unit in self) / self.amount,
)
)
@property
def selected(self) -> Units:
"""Returns all units that are selected by the human player."""
return self.filter(lambda unit: unit.is_selected)
@property
def tags(self) -> set[int]:
"""Returns all unit tags as a set."""
return {unit.tag for unit in self}
@property
def ready(self) -> Units:
"""Returns all structures that are ready (construction complete)."""
return self.filter(lambda unit: unit.is_ready)
@property
def not_ready(self) -> Units:
"""Returns all structures that are not ready (construction not complete)."""
return self.filter(lambda unit: not unit.is_ready)
@property
def idle(self) -> Units:
"""Returns all units or structures that are doing nothing (unit is standing still, structure is doing nothing)."""
return self.filter(lambda unit: unit.is_idle)
@property
def owned(self) -> Units:
"""Deprecated: All your units."""
return self.filter(lambda unit: unit.is_mine)
@property
def enemy(self) -> Units:
"""Deprecated: All enemy units."""
return self.filter(lambda unit: unit.is_enemy)
@property
def flying(self) -> Units:
"""Returns all units that are flying."""
return self.filter(lambda unit: unit.is_flying)
@property
def not_flying(self) -> Units:
"""Returns all units that not are flying."""
return self.filter(lambda unit: not unit.is_flying)
@property
def structure(self) -> Units:
"""Deprecated: All structures."""
return self.filter(lambda unit: unit.is_structure)
@property
def not_structure(self) -> Units:
"""Deprecated: All units that are not structures."""
return self.filter(lambda unit: not unit.is_structure)
@property
def gathering(self) -> Units:
"""Returns all workers that are mining minerals or vespene (gather command)."""
return self.filter(lambda unit: unit.is_gathering)
@property
def returning(self) -> Units:
"""Returns all workers that are carrying minerals or vespene and are returning to a townhall."""
return self.filter(lambda unit: unit.is_returning)
@property
def collecting(self) -> Units:
"""Returns all workers that are mining or returning resources."""
return self.filter(lambda unit: unit.is_collecting)
@property
def visible(self) -> Units:
"""Returns all units or structures that are visible.
TODO: add proper description on which units are exactly visible (not snapshots?)"""
return self.filter(lambda unit: unit.is_visible)
@property
def mineral_field(self) -> Units:
"""Returns all units that are mineral fields."""
return self.filter(lambda unit: unit.is_mineral_field)
@property
def vespene_geyser(self) -> Units:
"""Returns all units that are vespene geysers."""
return self.filter(lambda unit: unit.is_vespene_geyser)
@property
def prefer_idle(self) -> Units:
"""Sorts units based on if they are idle. Idle units come first."""
return self.sorted(lambda unit: unit.is_idle, reverse=True)