# Zero-filled buffer that receives one score per mutant; its shape is output_dim.
mutants_expanded = np.zeros(output_dim)
# All-zero vector of length output_dim[3], copied whenever a one-hot base
# vector has to be built.
empty_onehot = np.zeros(output_dim[3])
# An unspecified window bound falls back to the full extent of axis 2
# (presumably the sequence-position axis -- TODO confirm against the caller).
start_pos = 0 if start_pos is None else start_pos
end_pos = output_dim[2] if end_pos is None else end_pos
# Iterate through all samples and positions
# Mutagenize the input samples one at a time, scoring all of a sample's
# mutants with a single call to preact_function.
for sample_index in range(output_dim[0]):
    print("ISM: task:"+str(task_index)+" sample:"+str(sample_index))

    # Repeat this sample's wild-type logit over an
    # (output_dim[2], output_dim[3]) grid -- per the original comment,
    # sequence_length x num_bases -- and store it in the sample's slot.
    grid_shape = (output_dim[2], output_dim[3])
    wt_expanded[sample_index] = np.tile(wild_type_logits[sample_index], grid_shape)

    # Every (position, base) substitution inside [start_pos, end_pos).
    # No base is skipped, so the base already present at a position is
    # included as well.
    substitutions = [
        (base_pos, base_letter)
        for base_pos in range(start_pos, end_pos)
        for base_letter in range(output_dim[3])
    ]

    # Build one mutated copy of the sample per substitution.
    mutant_batch = []
    for base_pos, base_letter in substitutions:
        onehot = np.array(empty_onehot)
        onehot[base_letter] = 1
        # np.array copies, so X itself is never modified.
        mutant = np.array(X[sample_index])
        mutant[0, base_pos] = onehot
        mutant_batch.append(mutant)

    # Get the logits of the whole batch in one call.
    # NOTE(review): assumes preact_function returns one entry per mutant,
    # in batch order -- TODO confirm.
    # NOTE(review): task_index is only used in the progress message above;
    # presumably each returned entry is already the value for the task of
    # interest -- verify against preact_function.
    batch_logits = preact_function([mutant_batch])
    for (base_pos, base_letter), logit in zip(substitutions, batch_logits):
        mutants_expanded[sample_index, 0, base_pos, base_letter] = logit
After Change
# Substitute each base in turn at position base_pos of sample sample_index
# (both bound outside this loop, presumably by enclosing loops) and score
# each mutant with its own call to preact_function.
for base_letter in range(output_dim[3]):
    onehot = np.array(empty_onehot)
    onehot[base_letter] = 1
    # expand_dims adds a leading axis of length 1; the outer np.array makes
    # a copy, so X itself is never modified.
    mutant = np.array(np.expand_dims(X[sample_index], axis=0))
    mutant[0, 0, base_pos] = onehot
    # Drop the leading axis from the result, then keep only the entry for
    # task_index.
    mutant_logits = np.squeeze(preact_function(mutant), axis=0)
    mutants_expanded[sample_index, 0, base_pos, base_letter] = mutant_logits[task_index]
# ISM values: elementwise difference between the mutant scores and the
# expanded wild-type scores.
ism_vals = np.subtract(mutants_expanded, wt_expanded)
//For each position subtract the mean ISM score for that position from each of the 4 values