Source code for network_gym_client.envs.rmcat.adapter

#Copyright(C) 2023 Intel Corporation
#SPDX-License-Identifier: Apache-2.0
#File : adapter.py

import network_gym_client.adapter

import sys
from gymnasium import spaces
import numpy as np
import pandas as pd
import json
from pathlib import Path
import csv

class Adapter(network_gym_client.adapter.Adapter):
    """rmcat env adapter.

    Args:
        Adapter (network_gym_client.adapter.Adapter): base class.
    """
    def __init__(self, config_json):
        """Initialize the adapter.

        Args:
            config_json (json): the configuration file
        """
        super().__init__(config_json)
        self.env = Path(__file__).resolve().parent.name
        self.num_features = 3
        self.size_per_feature = int(self.config_json['env_config']['nada_flows'])
        if config_json['env_config']['env'] != self.env:
            sys.exit("[ERROR] wrong environment Adapter. Configured environment: "
                     + str(config_json['env_config']['env'])
                     + " != Launched environment: " + str(self.env))

        FILE_PATH = Path(__file__).parent

        # Append the bw_trace to the config json.
        data = {}

        # Open a csv reader called DictReader
        with open(FILE_PATH / config_json['env_config']['bw_trace_file'], encoding='utf-8') as csvf:
            csvReader = csv.DictReader(csvf)

            # Convert each row into a dictionary and add it to data
            for rows in csvReader:
                for key, value in rows.items():
                    if key not in data:
                        data[key] = [float(value)]
                    else:
                        data[key].append(float(value))

        config_json['env_config']['bw_trace'] = data

    def get_action_space(self):
        """Get action space for the rmcat env.

        Returns:
            spaces: action spaces
        """
        RMCAT_CC_DEFAULT_RMIN = 150000    # in bps: 150 Kbps
        RMCAT_CC_DEFAULT_RMAX = 1500000.  # in bps: 1.5 Mbps
        return spaces.Box(low=RMCAT_CC_DEFAULT_RMIN, high=RMCAT_CC_DEFAULT_RMAX,
                          shape=(self.size_per_feature,), dtype=np.float32)

    # consistent with the get_observation function.
    def get_observation_space(self):
        """Get the observation space for the rmcat env.

        Returns:
            spaces: observation spaces
        """
        return spaces.Box(low=0, high=1000,
                          shape=(self.num_features, self.size_per_feature), dtype=np.float32)

    def get_observation(self, df):
        """Prepare observation for the rmcat env.

        This function should return the same number of features defined in the
        :meth:`get_observation_space`.

        Args:
            df (pd.DataFrame): network stats measurement

        Returns:
            spaces: observation spaces
        """
        #print(df)
        row_loglen = None
        row_qdel = None
        row_rtt = None
        row_ploss = None
        row_plr = None
        row_xcurr = None
        row_rrate = None
        row_srate = None

        rtt_value = np.empty(self.size_per_feature, dtype=object)
        xcurr_value = np.empty(self.size_per_feature, dtype=object)
        rrate_value = np.empty(self.size_per_feature, dtype=object)

        for index, row in df.iterrows():
            if row['source'] == 'rmcat':
                if row['name'] == 'loglen':
                    row_loglen = row
                elif row['name'] == 'qdel':
                    row_qdel = row
                elif row['name'] == 'rtt':
                    row_rtt = row
                    rtt_value = row['value']
                    self.action_data_format = row
                elif row['name'] == 'ploss':
                    row_ploss = row
                elif row['name'] == 'plr':
                    row_plr = row
                elif row['name'] == 'xcurr':
                    row_xcurr = row
                    xcurr_value = row['value']
                elif row['name'] == 'rrate':
                    row_rrate = row
                    rrate_value = row['value']
                elif row['name'] == 'srate':
                    row_srate = row

        self.wandb_log_buffer_append(self.df_to_dict(row_loglen))
        self.wandb_log_buffer_append(self.df_to_dict(row_qdel))
        self.wandb_log_buffer_append(self.df_to_dict(row_rtt))
        self.wandb_log_buffer_append(self.df_to_dict(row_ploss))
        self.wandb_log_buffer_append(self.df_to_dict(row_plr))
        self.wandb_log_buffer_append(self.df_to_dict(row_xcurr))
        self.wandb_log_buffer_append(self.df_to_dict(row_rrate))
        self.wandb_log_buffer_append(self.df_to_dict(row_srate))

        observation = np.vstack([rtt_value, rrate_value, xcurr_value])
        print('Observation --> ' + str(observation))
        return observation

    def get_policy(self, action):
        """Prepare policy for the rmcat env.
        Args:
            action (spaces): action from the RL agent

        Returns:
            json: network policy
        """
        # you may also check other constraints for the action, e.g., min, max.
        policy1 = json.loads(self.action_data_format.to_json())
        policy1["name"] = "srate"
        policy1["value"] = action.tolist()
        print('Action --> ' + str(policy1))
        return policy1

    def get_reward(self, df):
        """Prepare reward for the rmcat env.

        Args:
            df (pd.DataFrame): network stats

        Returns:
            spaces: reward spaces
        """
        # TODO: add a reward function for your rmcat env
        reward = 0

        # send info to wandb
        self.wandb_log_buffer_append({"reward": reward})

        return reward
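

# get_reward above is intentionally left as a TODO. The helper below is a
# minimal, illustrative sketch of one possible reward (not part of the rmcat
# environment): it assumes the 3 x nada_flows observation layout
# [rtt, rrate, xcurr] stacked by Adapter.get_observation, and it assumes rtt
# is reported in ms and rrate in kbps, which should be verified against the
# environment's measurement output. The budget and scale constants are
# arbitrary placeholders. It reuses the numpy import at the top of this file.
def example_reward(observation, rtt_budget=100.0, rate_scale=1500.0):
    """Illustrative reward: favor receive rate, penalize RTT above a budget.

    Args:
        observation (np.ndarray): 3 x nada_flows array [rtt, rrate, xcurr].
        rtt_budget (float): assumed per-flow delay budget (placeholder value).
        rate_scale (float): assumed rate normalization constant (placeholder value).

    Returns:
        float: scalar reward averaged over flows.
    """
    rtt, rrate, _xcurr = np.asarray(observation, dtype=np.float64)
    rate_term = rrate / rate_scale                                # reward throughput
    delay_term = np.maximum(rtt - rtt_budget, 0.0) / rtt_budget   # penalize excess delay
    return float(np.mean(rate_term - delay_term))

# Inside get_reward, one could then compute, for example:
#     reward = example_reward(np.vstack([rtt_value, rrate_value, xcurr_value]))
# using the same per-flow rtt/rrate/xcurr values extracted in get_observation.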