# Source code for sc2.bot_ai_internal

# pyre-ignore-all-errors[6, 16, 29]
from __future__ import annotations

import itertools
import math
import time
import warnings
from abc import ABC
from collections import Counter
from collections.abc import Generator, Iterable
from contextlib import suppress
from typing import TYPE_CHECKING, Any, final

import numpy as np
from loguru import logger

# pyre-ignore[21]
from s2clientprotocol import sc2api_pb2 as sc_pb

from sc2.cache import property_cache_once_per_frame
from sc2.constants import (
    ALL_GAS,
    CREATION_ABILITY_FIX,
    IS_PLACEHOLDER,
    TERRAN_STRUCTURES_REQUIRE_SCV,
    FakeEffectID,
    abilityid_to_unittypeid,
    geyser_ids,
    mineral_ids,
)
from sc2.data import ActionResult, Race, race_townhalls
from sc2.game_data import Cost, GameData
from sc2.game_state import Blip, EffectData, GameState
from sc2.ids.ability_id import AbilityId
from sc2.ids.unit_typeid import UnitTypeId
from sc2.ids.upgrade_id import UpgradeId
from sc2.pixel_map import PixelMap
from sc2.position import Point2
from sc2.unit import Unit
from sc2.unit_command import UnitCommand
from sc2.units import Units

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    # pyre-ignore[21]
    from scipy.spatial.distance import cdist, pdist

if TYPE_CHECKING:
    from sc2.client import Client
    from sc2.game_info import GameInfo


class BotAIInternal(ABC):
    """Base class for bots.

    Holds the per-game and per-frame state (unit collections, resources, supply, timing statistics),
    the action queue, the event dispatching and the unit distance caches. The methods prefixed with
    an underscore are driven by main.py and are not meant to be called by bot authors.
    """

    def __init__(self) -> None:
        self._initialize_variables()

    @final
    def _reset_unit_collections(self) -> None:
        """Replace every unit collection and add-on tag set with a fresh, empty one."""
        self.all_units: Units = Units([], self)
        self.units: Units = Units([], self)
        self.workers: Units = Units([], self)
        self.larva: Units = Units([], self)
        self.structures: Units = Units([], self)
        self.townhalls: Units = Units([], self)
        self.gas_buildings: Units = Units([], self)
        self.all_own_units: Units = Units([], self)
        self.enemy_units: Units = Units([], self)
        self.enemy_structures: Units = Units([], self)
        self.all_enemy_units: Units = Units([], self)
        self.resources: Units = Units([], self)
        self.destructables: Units = Units([], self)
        self.watchtowers: Units = Units([], self)
        self.mineral_field: Units = Units([], self)
        self.vespene_geyser: Units = Units([], self)
        self.placeholders: Units = Units([], self)
        self.techlab_tags: set[int] = set()
        self.reactor_tags: set[int] = set()

    @final
    def _initialize_variables(self) -> None:
        """Called from main.py internally"""
        self.cache: dict[str, Any] = {}
        # Specific opponent bot ID used in sc2ai ladder games http://sc2ai.net/ and on ai arena https://aiarena.net
        # The bot ID will stay the same each game so your bot can "adapt" to the opponent
        if not hasattr(self, "opponent_id"):
            # Prevent overwriting the opponent_id which is set here https://github.com/Hannessa/python-sc2-ladderbot/blob/master/__init__.py#L40
            # otherwise set it to None
            self.opponent_id: str | None = None
        # Select distance calculation method, see _distances_override_functions function
        if not hasattr(self, "distance_calculation_method"):
            self.distance_calculation_method: int = 2
        # Select if the Unit.command should return UnitCommand objects. Set this to True if your bot uses 'self.do(unit(ability, target))'
        if not hasattr(self, "unit_command_uses_self_do"):
            self.unit_command_uses_self_do: bool = False
        # This value will be set to True by main.py in self._prepare_start if game is played in realtime (if true, the bot will have limited time per step)
        self.realtime: bool = False
        self.base_build: int = -1
        self._reset_unit_collections()
        self.minerals: int = 50
        self.vespene: int = 0
        self.supply_army: float = 0
        self.supply_workers: float = 12  # Doesn't include workers in production
        self.supply_cap: float = 15
        self.supply_used: float = 12
        self.supply_left: float = 3
        self.idle_worker_count: int = 0
        self.army_count: int = 0
        self.warp_gate_count: int = 0
        self.actions: list[UnitCommand] = []
        self.blips: set[Blip] = set()
        # pyre-ignore[11]
        self.race: Race | None = None
        self.enemy_race: Race | None = None
        # Game loop for which the distance cache was generated; -100 never matches a real game loop
        self._generated_frame = -100
        self._units_created: Counter = Counter()
        self._unit_tags_seen_this_game: set[int] = set()
        self._units_previous_map: dict[int, Unit] = {}
        self._structures_previous_map: dict[int, Unit] = {}
        self._enemy_units_previous_map: dict[int, Unit] = {}
        self._enemy_structures_previous_map: dict[int, Unit] = {}
        self._all_units_previous_map: dict[int, Unit] = {}
        self._previous_upgrades: set[UpgradeId] = set()
        self._expansion_positions_list: list[Point2] = []
        self._resource_location_to_expansion_position_dict: dict[Point2, Point2] = {}
        self._time_before_step: float = 0
        self._time_after_step: float = 0
        self._min_step_time: float = math.inf
        self._max_step_time: float = 0
        self._last_step_step_time: float = 0
        self._total_time_in_on_step: float = 0
        self._total_steps_iterations: int = 0
        # Internally used to keep track which units received an action in this frame, so that self.train() function does not give the same larva two orders - cleared every frame
        self.unit_tags_received_action: set[int] = set()

    @final
    @property
    def _game_info(self) -> GameInfo:
        """See game_info.py"""
        warnings.warn(
            "Using self._game_info is deprecated and may be removed soon. Please use self.game_info directly.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.game_info

    @final
    @property
    def _game_data(self) -> GameData:
        """See game_data.py"""
        warnings.warn(
            "Using self._game_data is deprecated and may be removed soon. Please use self.game_data directly.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.game_data

    @final
    @property
    def _client(self) -> Client:
        """See client.py"""
        warnings.warn(
            "Using self._client is deprecated and may be removed soon. Please use self.client directly.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.client

    @final
    @property_cache_once_per_frame
    def expansion_locations(self) -> dict[Point2, Units]:
        """Deprecated alias that emits a DeprecationWarning and returns 'self.expansion_locations_dict'."""
        assert self._expansion_positions_list, "self._find_expansion_locations() has not been run yet, so accessing the list of expansion locations is pointless."
        warnings.warn(
            "You are using 'self.expansion_locations', please use 'self.expansion_locations_list' (fast) or 'self.expansion_locations_dict' (slow) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.expansion_locations_dict

    @final
    def _find_expansion_locations(self) -> None:
        """Ran once at the start of the game to calculate expansion locations.

        Fills 'self._expansion_positions_list' and 'self._resource_location_to_expansion_position_dict'.
        A resource group for which no buildable position exists is skipped (with a warning) instead of
        aborting the whole calculation.
        """
        # Idea: create a group for every resource, then merge these groups if
        # any resource in a group is closer than a threshold to any resource of another group

        # Distance we group resources by
        resource_spread_threshold: float = 8.5
        # Create a group for every resource
        resource_groups: list[list[Unit]] = [
            [resource]
            for resource in self.resources
            if resource.name != "MineralField450"  # dont use low mineral count patches
        ]
        # Loop the merging process as long as we change something
        merged_group = True
        height_grid: PixelMap = self.game_info.terrain_height
        while merged_group:
            merged_group = False
            # Check every combination of two groups
            for group_a, group_b in itertools.combinations(resource_groups, 2):
                # Check if any pair of resource of these groups is closer than threshold together
                # And that they are on the same terrain level
                if any(
                    resource_a.distance_to(resource_b) <= resource_spread_threshold
                    # check if terrain height measurement at resources is within 10 units
                    # this is since some older maps have inconsistent terrain height
                    # tiles at certain expansion locations
                    and abs(height_grid[resource_a.position.rounded] - height_grid[resource_b.position.rounded]) <= 10
                    for resource_a, resource_b in itertools.product(group_a, group_b)
                ):
                    # Remove the single groups and add the merged group
                    resource_groups.remove(group_a)
                    resource_groups.remove(group_b)
                    resource_groups.append(group_a + group_b)
                    merged_group = True
                    # The list was modified, restart the pairing with the updated groups
                    break
        # Distance offsets we apply to center of each resource group to find expansion position
        offset_range = 7
        offsets = [
            (x, y)
            for x, y in itertools.product(range(-offset_range, offset_range + 1), repeat=2)
            if 4 < math.hypot(x, y) <= 8
        ]
        # For every resource group:
        for resources in resource_groups:
            # Possible expansion points
            amount = len(resources)
            # Calculate center, round and add 0.5 because expansion location will have (x.5, y.5)
            # coordinates because bases have size 5.
            center_x = int(sum(resource.position.x for resource in resources) / amount) + 0.5
            center_y = int(sum(resource.position.y for resource in resources) / amount) + 0.5
            possible_points = (Point2((offset[0] + center_x, offset[1] + center_y)) for offset in offsets)
            # Filter out points that are too near
            possible_points = (
                point
                for point in possible_points
                # Check if point can be built on
                if self.game_info.placement_grid[point.rounded] == 1
                # Check if all resources have enough space to point
                and all(
                    point.distance_to(resource) >= (7 if resource._proto.unit_type in geyser_ids else 6)
                    for resource in resources
                )
            )
            # Choose best fitting point: the candidate with the smallest summed distance to the resources.
            # 'default=None' prevents a ValueError when no candidate survived the filter above.
            result: Point2 | None = min(
                possible_points,
                key=lambda point: sum(point.distance_to(resource_) for resource_ in resources),
                default=None,
            )
            if result is None:
                logger.warning(
                    f"Could not find a valid expansion position for the resource group around ({center_x}, {center_y}), skipping it"
                )
                continue
            # Put all expansion locations in a list
            self._expansion_positions_list.append(result)
            # Maps all resource positions to the expansion position
            for resource in resources:
                self._resource_location_to_expansion_position_dict[resource.position] = result

    @final
    def _correct_zerg_supply(self) -> None:
        """The client incorrectly rounds zerg supply down instead of up (see
        https://github.com/Blizzard/s2client-proto/issues/123), so self.supply_used
        and friends return the wrong value when there are an odd number of zerglings
        and banelings. This function corrects the bad values."""
        # TODO: remove when Blizzard/sc2client-proto#123 gets fixed.
        half_supply_units = {
            UnitTypeId.ZERGLING,
            UnitTypeId.ZERGLINGBURROWED,
            UnitTypeId.BANELING,
            UnitTypeId.BANELINGBURROWED,
            UnitTypeId.BANELINGCOCOON,
        }
        # 1 if the amount of half-supply units is odd, else 0
        correction = self.units(half_supply_units).amount % 2
        self.supply_used += correction
        self.supply_army += correction
        self.supply_left -= correction

    @final
    @property_cache_once_per_frame
    def _abilities_count_and_build_progress(self) -> tuple[Counter[AbilityId], dict[AbilityId, float]]:
        """Cache for the already_pending function, includes protoss units warping in,
        all units in production and all structures, and all morphs

        :return: Tuple of (amount per ability, highest build progress per creation ability)
        """
        abilities_amount: Counter[AbilityId] = Counter()
        max_build_progress: dict[AbilityId, float] = {}
        unit: Unit
        for unit in self.units + self.structures:
            for order in unit.orders:
                abilities_amount[order.ability.exact_id] += 1
            if not unit.is_ready and (self.race != Race.Terran or not unit.is_structure):
                # If an SCV is constructing a building, already_pending would count this structure twice
                # (once from the SCV order, and once from "not structure.is_ready")
                if unit.type_id in CREATION_ABILITY_FIX:
                    if unit.type_id == UnitTypeId.ARCHON:
                        # Hotfix for archons in morph state
                        creation_ability = AbilityId.ARCHON_WARP_TARGET
                        abilities_amount[creation_ability] += 2
                    else:
                        # Hotfix for rich geysirs
                        creation_ability = CREATION_ABILITY_FIX[unit.type_id]
                        abilities_amount[creation_ability] += 1
                else:
                    creation_ability: AbilityId = self.game_data.units[unit.type_id.value].creation_ability.exact_id
                    abilities_amount[creation_ability] += 1
                max_build_progress[creation_ability] = max(
                    max_build_progress.get(creation_ability, 0), unit.build_progress
                )

        return abilities_amount, max_build_progress

    @final
    @property_cache_once_per_frame
    def _worker_orders(self) -> Counter[AbilityId]:
        """This function is used internally, do not use! It is to store all worker abilities."""
        abilities_amount: Counter[AbilityId] = Counter()
        # Holds both the position and the tag of every structure that requires an SCV to be built
        structures_in_production: set[Point2 | int] = set()
        for structure in self.structures:
            if structure.type_id in TERRAN_STRUCTURES_REQUIRE_SCV:
                structures_in_production.add(structure.position)
                structures_in_production.add(structure.tag)
        for worker in self.workers:
            for order in worker.orders:
                # Skip if the SCV is constructing (not isinstance(order.target, int))
                # or resuming construction (isinstance(order.target, int))
                if order.target in structures_in_production:
                    continue
                abilities_amount[order.ability.exact_id] += 1
        return abilities_amount

    @final
    def do(
        self,
        action: UnitCommand,
        subtract_cost: bool = False,
        subtract_supply: bool = False,
        can_afford_check: bool = False,
        ignore_warning: bool = False,
    ) -> bool:
        """Adds a unit action to the 'self.actions' list which is then executed at the end of the frame.

        Training a unit::

            # Train an SCV from a random idle command center
            cc = self.townhalls.idle.random_or(None)
            # self.townhalls can be empty or there are no idle townhalls
            if cc and self.can_afford(UnitTypeId.SCV):
                cc.train(UnitTypeId.SCV)

        Building a building::

            # Building a barracks at the main ramp, requires 150 minerals and a depot
            worker = self.workers.random_or(None)
            barracks_placement_position = self.main_base_ramp.barracks_correct_placement
            if worker and self.can_afford(UnitTypeId.BARRACKS):
                worker.build(UnitTypeId.BARRACKS, barracks_placement_position)

        Moving a unit::

            # Move a random worker to the center of the map
            worker = self.workers.random_or(None)
            # worker can be None if all are dead
            if worker:
                worker.move(self.game_info.map_center)

        :param action: The unit command to queue
        :param subtract_cost: Subtract the ability cost from self.minerals and self.vespene
        :param subtract_supply: Add the supply cost of the trained unit to self.supply_used
        :param can_afford_check: Only used with subtract_cost; drop the action if it can not be afforded
        :param ignore_warning: Suppress the deprecation warning when a bool is passed as action
        :return: False if the action was dropped by the afford check, True if it was queued
            (or the given bool if 'action' was a bool and 'unit_command_uses_self_do' is False)
        """
        if not self.unit_command_uses_self_do and isinstance(action, bool):
            # The unit command was already handled by the unit itself and only returned a bool
            if not ignore_warning:
                warnings.warn(
                    "You have used self.do(). Please consider putting 'self.unit_command_uses_self_do = True' in your bot __init__() function or removing self.do().",
                    DeprecationWarning,
                    stacklevel=2,
                )
            return action

        assert isinstance(
            action, UnitCommand
        ), f"Given unit command is not a command, but instead of type {type(action)}"
        if subtract_cost:
            cost: Cost = self.game_data.calculate_ability_cost(action.ability)
            if can_afford_check and not (self.minerals >= cost.minerals and self.vespene >= cost.vespene):
                # Dont do action if can't afford
                return False
            self.minerals -= cost.minerals
            self.vespene -= cost.vespene
        if subtract_supply and action.ability in abilityid_to_unittypeid:
            unit_type = abilityid_to_unittypeid[action.ability]
            required_supply = self.calculate_supply_cost(unit_type)
            # Overlord has -8
            if required_supply > 0:
                self.supply_used += required_supply
                self.supply_left -= required_supply
        self.actions.append(action)
        self.unit_tags_received_action.add(action.unit.tag)
        return True

    @final
    async def synchronous_do(self, action: UnitCommand):
        """
        Not recommended. Use self.do instead to reduce lag.
        This function is only useful for realtime=True in the first frame of the game to instantly produce a worker
        and split workers on the mineral patches.

        :param action: The unit command to send to the client immediately
        :return: ActionResult.Error if the action can not be afforded, otherwise the client response
            (a falsy response means success)
        """
        assert isinstance(
            action, UnitCommand
        ), f"Given unit command is not a command, but instead of type {type(action)}"
        if not self.can_afford(action.ability):
            logger.warning(f"Cannot afford action {action}")
            return ActionResult.Error
        r = await self.client.actions(action)
        if not r:  # success
            cost = self.game_data.calculate_ability_cost(action.ability)
            self.minerals -= cost.minerals
            self.vespene -= cost.vespene
            self.unit_tags_received_action.add(action.unit.tag)
        else:
            logger.error(f"Error: {r} (action: {action})")
        return r

    @final
    async def _do_actions(self, actions: list[UnitCommand], prevent_double: bool = True):
        """Used internally by main.py automatically, use self.do() instead!

        :param actions: Unit commands to send to the client
        :param prevent_double: Drop commands for which self.prevent_double_actions returns False
        :return: None if there was nothing to send, otherwise the client response
        """
        if not actions:
            return None
        if prevent_double:
            actions = list(filter(self.prevent_double_actions, actions))
        result = await self.client.actions(actions)
        return result

    @final
    @staticmethod
    def prevent_double_actions(action) -> bool:
        """Return False if the action repeats the first order of its unit and can be dropped.

        :param action: The unit command to check
        :return: True if the action should be sent, False if it is a duplicate
        """
        # Always add actions if queued
        if action.queue:
            return True
        if action.unit.orders:
            # action: UnitCommand
            # current_action: UnitOrder
            current_action = action.unit.orders[0]
            if action.ability not in {current_action.ability.id, current_action.ability.exact_id}:
                # Different action, return True
                return True
            # The target may be a unit, a position or missing, so attribute errors simply mean "not this kind"
            with suppress(AttributeError):
                if current_action.target == action.target.tag:
                    # Same action, remove action if same target unit
                    return False
            with suppress(AttributeError):
                if action.target.x == current_action.target.x and action.target.y == current_action.target.y:
                    # Same action, remove action if same target position
                    return False
            return True
        return True

    @final
    def _prepare_start(
        self, client, player_id: int, game_info, game_data, realtime: bool = False, base_build: int = -1
    ) -> None:
        """
        Ran until game start to set game and player data.

        :param client:
        :param player_id:
        :param game_info:
        :param game_data:
        :param realtime:
        :param base_build:
        """
        self.client: Client = client
        self.player_id: int = player_id
        self.game_info: GameInfo = game_info
        self.game_data: GameData = game_data
        self.realtime: bool = realtime
        self.base_build: int = base_build

        self.race: Race = Race(self.game_info.player_races[self.player_id])

        if len(self.game_info.player_races) == 2:
            # With two players the ids are 1 and 2, so the opponent has id '3 - own id'
            self.enemy_race: Race = Race(self.game_info.player_races[3 - self.player_id])

        self._distances_override_functions(self.distance_calculation_method)

    @final
    def _prepare_first_step(self) -> None:
        """First step extra preparations. Must not be called before _prepare_step."""
        if self.townhalls:
            self.game_info.player_start_location = self.townhalls.first.position
            # Calculate and cache expansion locations forever inside 'self._cache_expansion_locations', this is done to prevent a bug when this is run and cached later in the game
            self._find_expansion_locations()
        self.game_info.map_ramps, self.game_info.vision_blockers = self.game_info._find_ramps_and_vision_blockers()
        self._time_before_step: float = time.perf_counter()

    @final
    def _prepare_step(self, state, proto_game_info) -> None:
        """Set attributes from new state before on_step.

        :param state:
        :param proto_game_info:
        """
        self.state: GameState = state  # See game_state.py
        # update pathing grid, which unfortunately is in GameInfo instead of GameState
        self.game_info.pathing_grid = PixelMap(proto_game_info.game_info.start_raw.pathing_grid, in_bits=True)
        # Required for events, needs to be before self.units are initialized so the old units are stored
        self._units_previous_map: dict[int, Unit] = {unit.tag: unit for unit in self.units}
        self._structures_previous_map: dict[int, Unit] = {structure.tag: structure for structure in self.structures}
        self._enemy_units_previous_map: dict[int, Unit] = {unit.tag: unit for unit in self.enemy_units}
        self._enemy_structures_previous_map: dict[int, Unit] = {
            structure.tag: structure for structure in self.enemy_structures
        }
        self._all_units_previous_map: dict[int, Unit] = {unit.tag: unit for unit in self.all_units}

        self._prepare_units()
        self.minerals: int = state.common.minerals
        self.vespene: int = state.common.vespene
        self.supply_army: int = state.common.food_army
        self.supply_workers: int = state.common.food_workers  # Doesn't include workers in production
        self.supply_cap: int = state.common.food_cap
        self.supply_used: int = state.common.food_used
        self.supply_left: int = self.supply_cap - self.supply_used

        if self.race == Race.Zerg:
            # Workaround Zerg supply rounding bug
            self._correct_zerg_supply()
        elif self.race == Race.Protoss:
            self.warp_gate_count: int = state.common.warp_gate_count

        self.idle_worker_count: int = state.common.idle_worker_count
        self.army_count: int = state.common.army_count
        self._time_before_step: float = time.perf_counter()

        # Resolve a random enemy race as soon as the first enemy unit is visible
        if self.enemy_race == Race.Random and self.all_enemy_units:
            self.enemy_race = Race(self.all_enemy_units.first.race)

    @final
    def _prepare_units(self) -> None:
        """Rebuild all unit collections from the raw observation of 'self.state'."""
        # Set of enemy units detected by own sensor tower, as blips have less unit information than normal visible units
        self.blips: set[Blip] = set()
        self._reset_unit_collections()

        # Built once here instead of once per observed unit
        worker_types: set[UnitTypeId] = {UnitTypeId.DRONE, UnitTypeId.DRONEBURROWED, UnitTypeId.SCV, UnitTypeId.PROBE}
        techlab_types: set[UnitTypeId] = {
            UnitTypeId.TECHLAB,
            UnitTypeId.BARRACKSTECHLAB,
            UnitTypeId.FACTORYTECHLAB,
            UnitTypeId.STARPORTTECHLAB,
        }
        reactor_types: set[UnitTypeId] = {
            UnitTypeId.REACTOR,
            UnitTypeId.BARRACKSREACTOR,
            UnitTypeId.FACTORYREACTOR,
            UnitTypeId.STARPORTREACTOR,
        }

        # Position of the unit inside self.all_units, passed to each Unit as its distance_calculation_index
        index: int = 0
        for unit in self.state.observation_raw.units:
            if unit.is_blip:
                self.blips.add(Blip(unit))
                continue

            unit_type: int = unit.unit_type
            # Convert these units to effects: reaper grenade, parasitic bomb dummy, forcefield
            if unit_type in FakeEffectID:
                self.state.effects.add(EffectData(unit, fake=True))
                continue
            unit_obj = Unit(unit, self, distance_calculation_index=index, base_build=self.base_build)
            index += 1
            self.all_units.append(unit_obj)
            if unit.display_type == IS_PLACEHOLDER:
                self.placeholders.append(unit_obj)
                continue
            alliance = unit.alliance
            # Alliance.Neutral.value = 3
            if alliance == 3:
                # XELNAGATOWER = 149
                if unit_type == 149:
                    self.watchtowers.append(unit_obj)
                # mineral field enums
                elif unit_type in mineral_ids:
                    self.mineral_field.append(unit_obj)
                    self.resources.append(unit_obj)
                # geyser enums
                elif unit_type in geyser_ids:
                    self.vespene_geyser.append(unit_obj)
                    self.resources.append(unit_obj)
                # all destructable rocks
                else:
                    self.destructables.append(unit_obj)
            # Alliance.Self.value = 1
            elif alliance == 1:
                self.all_own_units.append(unit_obj)
                unit_id: UnitTypeId = unit_obj.type_id
                if unit_obj.is_structure:
                    self.structures.append(unit_obj)
                    if unit_id in race_townhalls[self.race]:
                        self.townhalls.append(unit_obj)
                    elif unit_id in ALL_GAS or unit_obj.vespene_contents:
                        # TODO: remove "or unit_obj.vespene_contents" when a new linux client newer than version 4.10.0 is released
                        self.gas_buildings.append(unit_obj)
                    elif unit_id in techlab_types:
                        self.techlab_tags.add(unit_obj.tag)
                    elif unit_id in reactor_types:
                        self.reactor_tags.add(unit_obj.tag)
                else:
                    self.units.append(unit_obj)
                    if unit_id in worker_types:
                        self.workers.append(unit_obj)
                    elif unit_id == UnitTypeId.LARVA:
                        self.larva.append(unit_obj)
            # Alliance.Enemy.value = 4
            elif alliance == 4:
                self.all_enemy_units.append(unit_obj)
                if unit_obj.is_structure:
                    self.enemy_structures.append(unit_obj)
                else:
                    self.enemy_units.append(unit_obj)

        # Force distance calculation and caching on all units using scipy pdist or cdist
        if self.distance_calculation_method == 1:
            _ = self._pdist
        elif self.distance_calculation_method in {2, 3}:
            _ = self._cdist

    @final
    async def _after_step(self) -> int:
        """Executed by main.py after each on_step function.

        :return: The game loop of the current state
        """
        # Keep track of the bot on_step duration
        self._time_after_step: float = time.perf_counter()
        step_duration = self._time_after_step - self._time_before_step
        self._min_step_time = min(step_duration, self._min_step_time)
        self._max_step_time = max(step_duration, self._max_step_time)
        self._last_step_step_time = step_duration
        self._total_time_in_on_step += step_duration
        self._total_steps_iterations += 1
        # Commit and clear bot actions
        if self.actions:
            await self._do_actions(self.actions)
            self.actions.clear()
        # Clear set of unit tags that were given an order this frame by self.do()
        self.unit_tags_received_action.clear()
        # Commit debug queries
        await self.client._send_debug()

        return self.state.game_loop

    @final
    async def _advance_steps(self, steps: int) -> None:
        """Advances the game loop by amount of 'steps'. This function is meant to be used as a debugging and testing tool only.
        If you are using this, please be aware of the consequences, e.g. 'self.units' will be filled with completely new data."""
        await self._after_step()
        # Advance simulation by exactly "steps" frames
        await self.client.step(steps)
        state = await self.client.observation()
        gs = GameState(state.observation)
        proto_game_info = await self.client._execute(game_info=sc_pb.RequestGameInfo())
        self._prepare_step(gs, proto_game_info)
        await self.issue_events()

    @final
    async def issue_events(self) -> None:
        """This function will be automatically run from main.py and triggers the following functions:
        - on_unit_created
        - on_unit_destroyed
        - on_unit_took_damage
        - on_unit_type_changed
        - on_building_construction_started
        - on_building_construction_complete
        - on_upgrade_complete
        - on_enemy_unit_entered_vision
        - on_enemy_unit_left_vision
        """
        await self._issue_unit_dead_events()
        await self._issue_unit_added_events()
        await self._issue_building_events()
        await self._issue_upgrade_events()
        await self._issue_vision_events()

    @final
    async def _issue_unit_added_events(self) -> None:
        """Trigger on_unit_created, on_unit_took_damage and on_unit_type_changed for own units."""
        for unit in self.units:
            if unit.tag not in self._units_previous_map and unit.tag not in self._unit_tags_seen_this_game:
                self._unit_tags_seen_this_game.add(unit.tag)
                self._units_created[unit.type_id] += 1
                await self.on_unit_created(unit)
            elif unit.tag in self._units_previous_map:
                previous_frame_unit: Unit = self._units_previous_map[unit.tag]
                # Check if a unit took damage this frame and then trigger event
                if unit.health < previous_frame_unit.health or unit.shield < previous_frame_unit.shield:
                    damage_amount = previous_frame_unit.health - unit.health + previous_frame_unit.shield - unit.shield
                    await self.on_unit_took_damage(unit, damage_amount)
                # Check if a unit type has changed
                if previous_frame_unit.type_id != unit.type_id:
                    await self.on_unit_type_changed(unit, previous_frame_unit.type_id)

    @final
    async def _issue_upgrade_events(self) -> None:
        """Trigger on_upgrade_complete for every upgrade that was not present in the previous state."""
        difference = self.state.upgrades - self._previous_upgrades
        for upgrade_completed in difference:
            await self.on_upgrade_complete(upgrade_completed)
        self._previous_upgrades = self.state.upgrades

    @final
    async def _issue_building_events(self) -> None:
        """Trigger construction started / complete, damage and type change events for own structures."""
        for structure in self.structures:
            if structure.tag not in self._structures_previous_map:
                if structure.build_progress < 1:
                    await self.on_building_construction_started(structure)
                else:
                    # Include starting townhall
                    self._units_created[structure.type_id] += 1
                    await self.on_building_construction_complete(structure)
            else:
                # Check if a structure took damage this frame and then trigger event
                previous_frame_structure: Unit = self._structures_previous_map[structure.tag]
                if (
                    structure.health < previous_frame_structure.health
                    or structure.shield < previous_frame_structure.shield
                ):
                    damage_amount = (
                        previous_frame_structure.health
                        - structure.health
                        + previous_frame_structure.shield
                        - structure.shield
                    )
                    await self.on_unit_took_damage(structure, damage_amount)
                # Check if a structure changed its type
                if previous_frame_structure.type_id != structure.type_id:
                    await self.on_unit_type_changed(structure, previous_frame_structure.type_id)
                # Check if structure completed
                if structure.build_progress == 1 and previous_frame_structure.build_progress < 1:
                    self._units_created[structure.type_id] += 1
                    await self.on_building_construction_complete(structure)

    @final
    async def _issue_vision_events(self) -> None:
        """Trigger on_enemy_unit_entered_vision and on_enemy_unit_left_vision for enemy units and structures."""
        # Call events for enemy unit entered vision
        for enemy_unit in self.enemy_units:
            if enemy_unit.tag not in self._enemy_units_previous_map:
                await self.on_enemy_unit_entered_vision(enemy_unit)
        for enemy_structure in self.enemy_structures:
            if enemy_structure.tag not in self._enemy_structures_previous_map:
                await self.on_enemy_unit_entered_vision(enemy_structure)

        # Call events for enemy unit left vision
        enemy_units_left_vision: set[int] = set(self._enemy_units_previous_map) - self.enemy_units.tags
        for enemy_unit_tag in enemy_units_left_vision:
            await self.on_enemy_unit_left_vision(enemy_unit_tag)
        enemy_structures_left_vision: set[int] = set(self._enemy_structures_previous_map) - self.enemy_structures.tags
        for enemy_structure_tag in enemy_structures_left_vision:
            await self.on_enemy_unit_left_vision(enemy_structure_tag)

    @final
    async def _issue_unit_dead_events(self) -> None:
        """Trigger on_unit_destroyed for every dead unit tag that was known in the previous frame."""
        for unit_tag in self.state.dead_units & set(self._all_units_previous_map):
            await self.on_unit_destroyed(unit_tag)

    # DISTANCE CALCULATION

    @final
    @property
    def _units_count(self) -> int:
        return len(self.all_units)

    @final
    @property
    def _pdist(self) -> np.ndarray:
        """As property, so it will be recalculated each time it is called, or return from cache if it is called multiple times in the same game_loop."""
        if self._generated_frame != self.state.game_loop:
            return self.calculate_distances()
        return self._cached_pdist

    @final
    @property
    def _cdist(self) -> np.ndarray:
        """As property, so it will be recalculated each time it is called, or return from cache if it is called multiple times in the same game_loop."""
        if self._generated_frame != self.state.game_loop:
            return self.calculate_distances()
        return self._cached_cdist

    @final
    def _calculate_distances_method1(self) -> np.ndarray:
        """Calculate and cache the condensed (1d) squared distance array between all units with scipy's pdist."""
        self._generated_frame = self.state.game_loop
        # Converts tuple [(1, 2), (3, 4)] to flat list like [1, 2, 3, 4]
        flat_positions = (coord for unit in self.all_units for coord in unit.position_tuple)
        # Converts to numpy array, then converts the flat array back to shape (n, 2): [[1, 2], [3, 4]]
        positions_array: np.ndarray = np.fromiter(
            flat_positions,
            dtype=float,
            count=2 * self._units_count,
        ).reshape((self._units_count, 2))
        assert len(positions_array) == self._units_count
        # See performance benchmarks
        self._cached_pdist = pdist(positions_array, "sqeuclidean")

        return self._cached_pdist

    @final
    def _calculate_distances_method2(self) -> np.ndarray:
        """Calculate and cache the square (2d) squared distance matrix between all units with scipy's cdist."""
        self._generated_frame = self.state.game_loop
        # Converts tuple [(1, 2), (3, 4)] to flat list like [1, 2, 3, 4]
        flat_positions = (coord for unit in self.all_units for coord in unit.position_tuple)
        # Converts to numpy array, then converts the flat array back to shape (n, 2): [[1, 2], [3, 4]]
        positions_array: np.ndarray = np.fromiter(
            flat_positions,
            dtype=float,
            count=2 * self._units_count,
        ).reshape((self._units_count, 2))
        assert len(positions_array) == self._units_count
        # See performance benchmarks
        self._cached_cdist = cdist(positions_array, positions_array, "sqeuclidean")

        return self._cached_cdist

    @final
    def _calculate_distances_method3(self) -> np.ndarray:
        """Nearly same as above, but without asserts"""
        self._generated_frame = self.state.game_loop
        flat_positions = (coord for unit in self.all_units for coord in unit.position_tuple)
        positions_array: np.ndarray = np.fromiter(
            flat_positions,
            dtype=float,
            count=2 * self._units_count,
        ).reshape((-1, 2))
        # See performance benchmarks
        self._cached_cdist = cdist(positions_array, positions_array, "sqeuclidean")

        return self._cached_cdist

    # Helper functions

    @final
    def square_to_condensed(self, i, j) -> int:
        """Convert the indices (i, j) of the square distance matrix to the index in the condensed pdist array.

        :param i: Row index in the square matrix
        :param j: Column index in the square matrix, must differ from i
        """
        # Converts indices of a square matrix to condensed matrix
        # https://stackoverflow.com/a/36867493/10882657
        assert i != j, "No diagonal elements in condensed matrix! Diagonal elements are zero"
        if i < j:
            i, j = j, i
        return self._units_count * j - j * (j + 1) // 2 + i - 1 - j

    @final
    @staticmethod
    def convert_tuple_to_numpy_array(pos: tuple[float, float]) -> np.ndarray:
        """Converts a single position to a 2d numpy array with 1 row and 2 columns."""
        return np.fromiter(pos, dtype=float, count=2).reshape((1, 2))

    # Fast and simple calculation functions

    @final
    @staticmethod
    def distance_math_hypot(
        p1: tuple[float, float] | Point2,
        p2: tuple[float, float] | Point2,
    ) -> float:
        """Return the euclidean distance between the two positions."""
        return math.hypot(p1[0] - p2[0], p1[1] - p2[1])

    @final
    @staticmethod
    def distance_math_hypot_squared(
        p1: tuple[float, float] | Point2,
        p2: tuple[float, float] | Point2,
    ) -> float:
        """Return the squared euclidean distance between the two positions."""
        return pow(p1[0] - p2[0], 2) + pow(p1[1] - p2[1], 2)

    @final
    def _distance_squared_unit_to_unit_method0(self, unit1: Unit, unit2: Unit) -> float:
        return self.distance_math_hypot_squared(unit1.position_tuple, unit2.position_tuple)

    # Distance calculation using the pre-calculated matrix above

    @final
    def _distance_squared_unit_to_unit_method1(self, unit1: Unit, unit2: Unit) -> float:
        # If checked on units if they have the same tag, return distance 0 as these are not in the 1 dimensional pdist array - would result in an error otherwise
        if unit1.tag == unit2.tag:
            return 0
        # Calculate index, needs to be after pdist has been calculated and cached
        condensed_index = self.square_to_condensed(unit1.distance_calculation_index, unit2.distance_calculation_index)
        # Read through the property once, so the cache is guaranteed to exist and to belong to the
        # current game loop before its length is checked and it is indexed
        pdist_array = self._pdist
        assert (
            condensed_index < len(pdist_array)
        ), f"Condensed index is larger than amount of calculated distances: {condensed_index} >= {len(pdist_array)}, units that caused the assert error: {unit1} and {unit2}"
        distance = pdist_array[condensed_index]
        return distance

    @final
    def _distance_squared_unit_to_unit_method2(self, unit1: Unit, unit2: Unit) -> float:
        # Calculate index, needs to be after cdist has been calculated and cached
        return self._cdist[unit1.distance_calculation_index, unit2.distance_calculation_index]

    # Distance calculation using the fastest distance calculation functions

    @final
    def _distance_pos_to_pos(
        self,
        pos1: tuple[float, float] | Point2,
        pos2: tuple[float, float] | Point2,
    ) -> float:
        return self.distance_math_hypot(pos1, pos2)

    @final
    def _distance_units_to_pos(
        self,
        units: Units,
        pos: tuple[float, float] | Point2,
    ) -> Generator[float, None, None]:
        """This function does not scale well, if len(units) > 100 it gets fairly slow"""
        return (self.distance_math_hypot(u.position_tuple, pos) for u in units)

    @final
    def _distance_unit_to_points(
        self,
        unit: Unit,
        points: Iterable[tuple[float, float]],
    ) -> Generator[float, None, None]:
        """This function does not scale well, if len(points) > 100 it gets fairly slow"""
        pos = unit.position_tuple
        return (self.distance_math_hypot(p, pos) for p in points)

    @final
    def _distances_override_functions(self, method: int = 0) -> None:
        """Overrides the internal distance calculation functions at game start in the self._prepare_start() function
        method 0: Use python's math.hypot
        The following methods calculate the distances between all units once:
        method 1: Use scipy's pdist condensed matrix (1d array)
        method 2: Use scipy's cdist square matrix (2d array)
        method 3: Use scipy's cdist square matrix (2d array) without asserts (careful: very weird error messages, but maybe slightly faster)"""
        assert 0 <= method <= 3, f"Selected method was: {method}"
        if method == 0:
            self._distance_squared_unit_to_unit = self._distance_squared_unit_to_unit_method0
        elif method == 1:
            self._distance_squared_unit_to_unit = self._distance_squared_unit_to_unit_method1
            self.calculate_distances = self._calculate_distances_method1
        elif method == 2:
            self._distance_squared_unit_to_unit = self._distance_squared_unit_to_unit_method2
            self.calculate_distances = self._calculate_distances_method2
        elif method == 3:
            self._distance_squared_unit_to_unit = self._distance_squared_unit_to_unit_method2
            self.calculate_distances = self._calculate_distances_method3